From 252329d02aa6edd0832233af3a9e8dd081eaa9ff Mon Sep 17 00:00:00 2001 From: learning-chip Date: Mon, 27 Apr 2026 19:21:59 +0000 Subject: [PATCH 1/7] summarize known gap of python vs cpp (other than multi-push issue) --- examples/aot/flash_attention/known_gap.md | 89 +++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 examples/aot/flash_attention/known_gap.md diff --git a/examples/aot/flash_attention/known_gap.md b/examples/aot/flash_attention/known_gap.md new file mode 100644 index 00000000..83c31a7c --- /dev/null +++ b/examples/aot/flash_attention/known_gap.md @@ -0,0 +1,89 @@ +# Known gap: PTO Python DSL flash attention vs reference C++ + +This document compares the AOT flash-attention builders (`fa_builder.py`, `experimental/fa_builder.py`) and their `ptoas` output to the hand-written reference in `cpp_ref/naive_tpush/fa_kernel.cpp` (`runTFA` and helpers). It updates earlier notes on **macro parity** and **`--enable-insert-sync`**. + +## Revised summary + +### What the DSL already mirrors + +- High-level **software pipeline**: QK preload, steady-state interleaving of softmax (lookahead) with GU (current tile), and an **`exp_max` ping-pong ring** when `QK_PRELOAD == 2` (experimental) to match the reference’s “softmax ahead of GU” hazard story. +- **Multi-pipe** QK (cube→vec), P (vec→cube), PV (cube→vec) with GM-backed slots, analogous to the reference’s FIFO staging (different mechanism, same role). + +### Primary performance gaps (largest expected impact) + +1. **Cube tiling and S1 sub-tiling** + Reference: `CUBE_S0 = 128`, `CUBE_S1 = 128`, `TILE_S1 = 256`, so **`kTileFactor = 2`** (two 128-wide K slices per logical 256-wide tile). + DSL (experimental): **`S0 = 32`**, single **`S1_TILE = 256`** matmul per tile. Smaller **M** (32 vs 128) and a different K-splitting strategy typically dominate cube utilization and memory overlap versus the reference. + +2. **Real L1 double-buffering** + Reference: `kMatTNBuffers = 2`, `pMatTNBuffers = 2`, `vMatTNBuffers = 2`, plus vec-side ping-pong (`srcVecTNBuffers`, `xexpVecTNBuffers`, `outOTileNBuffers`). + DSL: cube **single-buffered** tiles with **aliased** `[k_mat_s, k_mat_s]` / same for P/V; **`QK_LOCAL_SLOT_NUM = 1`** on the QK pipe because deeper local slots overflow vec UB. That limits overlap of **TLOAD** with **TMATMUL** compared to the reference. + +3. **Preload depth** + Reference launch uses **`QK_PRELOAD = 4`** (`fa_kernel.cpp`). Experimental DSL uses **`QK_PRELOAD = 2`**. Shallower preload reduces cube/vec overlap on long S1. + +### Macro parity (item 5) — port to Python DSL, not “optional tuning” + +The reference’s hot path is not arbitrary `tile.*` soup; it goes through shared headers that should be **replicated in Python DSL** so lowering and scheduling stay aligned with the tuned C++ path: + +| Reference include | Role | +|-------------------|------| +| [`pto_macro_matmul.hpp`](../../../../pto-isa-master/kernels/manual/common/flash_atten/pto_macro_matmul.hpp) | Cube matmul with **`AccMode`** (e.g. `InitFinalSum` under `UF_ENABLE`), K tiling, and L0-oriented constraints. | +| [`pto_macro_fa_softmax.hpp`](../../../../pto-isa-master/kernels/manual/common/flash_atten/pto_macro_fa_softmax.hpp) | Streaming softmax: `softmax_opt_fa_init_impl` / `softmax_opt_fa_not_init_impl`, scale = `1/sqrt(HEAD)`, **TROWMAX**, **TROWEXPANDSUB**, **TEXP**, **TROWSUM**, **reshape + TCVT** to fp16, causal branches where applicable. | +| [`pto_macro_fa_gu.hpp`](../../../../pto-isa-master/kernels/manual/common/flash_atten/pto_macro_fa_gu.hpp) | **pto_macro_fa_gu** (`TROWEXPANDMUL` + `TADD`), **pto_macro_fa_gu_last** (+ `TROWEXPANDDIV` by `new_global_sum`), **pto_macro_fa_gu_single_and_last_tile**. | + +**Goal:** Express the same ordered primitive sequence (and the same init vs non-init / last-tile branching) in `ptodsl` `tile.*` / `pto.*` APIs—or add thin DSL helpers that document a 1:1 mapping to those macros—so the compiler stack can match the reference kernel’s numerics and fusion expectations. Today’s DSL uses composable `tile.row_max`, `tile.exp`, etc.; they must be **audited and aligned** macro-step by macro-step, not assumed equivalent. + +### `ptoas --enable-insert-sync` (item 6) + +**`--enable-insert-sync` is intentional:** it simplifies generated C++ by having the toolchain insert synchronization. Impact on performance is treated as **minor** relative to **tiling, real double-buffering, and preload depth**. + +Closing the gap should **not** rely on turning sync insertion off; it should rely on **geometry + buffering + macro-faithful lowering** (and any future DSL-level sync refinement if needed). + +### Secondary / structural differences + +- **GM / FFTS / CV:** Reference uses FFTS base, `TSync_Custom`, optional CV comm for many blocks; DSL uses `l2g2l_pipe` / `aic_initialize_pipe` / `aiv_initialize_pipe` and GM slot buffers. Functionally similar staging; details may diverge under high block counts. +- **Benchmark shape parity:** e.g. `experimental/run.py` uses `Q_ROWS` from the builder vs `naive_tpush/run.py` using `s0 = 128*24`; compare throughput with **matched** `Q_ROWS`, `HEAD`, `S1`, and tile counts when isolating kernel quality. + +--- + +## TODO: close the gap (Python DSL ↔ reference C++) + +Use this as a work backlog; order roughly reflects suggested priority (tiling/buffers first, then macro fidelity, then integration). + +### Tiling and cube schedule + +- [ ] **Match reference cube geometry:** `CUBE_S0=128`, `CUBE_S1=128`, `TILE_S1=256`, and **`kTileFactor`** loop (two K slices per 256-wide tile) in the DSL builder’s cube kernel, or justify an equivalent FLOP/memory contract with measurements. +- [ ] **Re-evaluate `S0=32`** (and non-experimental builder constants): target the same per-matmul **M** as the reference unless hardware constraints force otherwise. +- [ ] **Align `QK_PRELOAD`** with the reference launch (**4**) and extend the **`exp_max` / GU ring** logic (or equivalent hazard avoidance) for that depth; assert fifo and UB sizing. + +### Double-buffering and overlap + +- [ ] **Implement true L1 ping-pong** for **K**, **P**, and **V** cube tiles (separate physical buffers, not aliased `[x, x]`). +- [ ] **Vec UB layout:** budget space for **`QK_LOCAL_SLOT_NUM > 1`** if required to mirror reference QK pipe depth, without exceeding UB limits; coordinate with `SLOT_NUM` / GM stride math. +- [ ] **Vec tile banks:** mirror **`srcVecTNBuffers=2`**, **`xexpVecTNBuffers=2`**, **`outOTileNBuffers=2`** (or document why a smaller depth is equivalent). + +### Macro parity in Python (lowering contract) + +- [ ] **Matmul:** Port or wrap **`pto_macro_matmul`** semantics in DSL—especially **`AccMode`** (`Init` / `InitFinalSum` / partial vs final slices) and the **K-subslice** interaction with double-buffered K tiles. +- [ ] **Softmax:** Port **`softmax_opt_fa_init_impl`** and **`softmax_opt_fa_not_init_impl`** (and causal paths if needed) as an explicit sequence of DSL ops matching **TROWMAX → TROWEXPANDSUB → scale → TEXP → TROWSUM → reshape/TCVT** behavior in `pto_macro_fa_softmax.hpp`. +- [ ] **GU:** Port **`pto_macro_fa_gu`**, **`pto_macro_fa_gu_last`**, and **`pto_macro_fa_gu_single_and_last_tile`** as DSL sequences matching **TROWEXPANDMUL / TADD / TROWEXPANDDIV** usage in `pto_macro_fa_gu.hpp`. +- [ ] **Numerics test:** Keep **`torch.testing.assert_close`** (same `rtol`/`atol` as `run.py` / `experimental/run.py`) as the gate after each macro block port. + +### Toolchain and integration + +- [ ] **Keep `--enable-insert-sync`** in `compile.sh` / `experimental/compile.sh`; optimize kernel structure first; only revisit sync policy if profiling shows it dominates after tiling/buffer parity. +- [ ] **Optional:** Parity for **FFTS / CV / `TSync_*`** paths if multi-block or multi-core scaling diverges from reference after cube/vec parity work. +- [ ] **Docs:** When a builder variant reaches parity, record fixed constants (`CUBE_S0`, `CUBE_S1`, `TILE_S1`, `QK_PRELOAD`, buffer counts) next to `jit_util_flash.py` / `fa_kernel.cpp` launch parameters so drift is obvious in review. + +--- + +## File map + +| Artifact | Path | +|----------|------| +| Reference kernel | `cpp_ref/naive_tpush/fa_kernel.cpp` | +| JIT constants | `cpp_ref/naive_tpush/jit_util_flash.py` | +| Non-experimental builder | `fa_builder.py` | +| Experimental builder | `experimental/fa_builder.py` | +| PTO FA macros (ISA tree) | `pto-isa-master/kernels/manual/common/flash_atten/pto_macro_*.hpp` | From 93190372a1585f41a04307331bb1ec9bea47ec5a Mon Sep 17 00:00:00 2001 From: mirkodevita Date: Mon, 27 Apr 2026 23:42:12 +0200 Subject: [PATCH 2/7] unsuccessfully attempt --- .../experimental/fa_builder.py | 90 +++++-------------- examples/aot/flash_attention/known_gap.md | 21 ++++- 2 files changed, 40 insertions(+), 71 deletions(-) diff --git a/examples/aot/flash_attention/experimental/fa_builder.py b/examples/aot/flash_attention/experimental/fa_builder.py index 3b178b3e..310f5a33 100644 --- a/examples/aot/flash_attention/experimental/fa_builder.py +++ b/examples/aot/flash_attention/experimental/fa_builder.py @@ -4,20 +4,13 @@ # # This file mirrors the reference C++ scheduler: # -# constexpr int qkPreloadNum = 2; // warmup depth +# constexpr int qkPreloadNum = 2; // warmup depth (reference uses 4; UB-limited here) # -# /* Prologue: cube emits QK[0..QK_PRELOAD-1]; vec consumes them and -# pushes P[0..QK_PRELOAD-1]. No PV / gu yet. */ +# /* Prologue: cube emits QK[0..QK_PRELOAD-1]; vec softmaxes them. */ # # /* Steady state, tile_id 0..N-1: -# cube: if (t+QK_PRELOAD < N) compute_qk(t+QK_PRELOAD); -# compute_pv(tile_id); -# vec: if (t+QK_PRELOAD < N) compute_p(t+QK_PRELOAD); -# compute_gu(tile_id); -# so vec's softmax for the LOOK-AHEAD tile fills the QK consumption -# slot WHILE the cube is computing the current PV[t]. The cube -# stops being blocked on a freshly-pushed P (softmax of t+2 has -# already pushed P[t+2] into the FIFO by the time cube needs it). */ +# cube: compute QK[t+QK_PRELOAD] while draining PV[t]; +# vec: interleave gu(t) with softmax(t+QK_PRELOAD) using exp_max a/b. */ # # /* Epilogue: drain the last QK_PRELOAD tiles' PV / gu. */ # @@ -27,6 +20,8 @@ # need a 2-deep ring of `exp_max` tiles (`exp_max_a`, `exp_max_b`). We # implement the ring by unrolling the steady-state loop in pairs of 2 # iterations: even iters use `exp_max_a`, odd iters use `exp_max_b`. +# (Raising preload to 4 like the C++ launch requires a 4-slot ring and more +# VEC UB headroom; a prototype hit CCU address faults until layout is reworked.) # # Other state in the softmax (`new_global_max`, `new_global_sum`) does # NOT need a ring: it is monotonic accumulator state across all tiles @@ -64,11 +59,9 @@ Q_ROWS = 2048 NUM_Q_BLOCKS = Q_ROWS // S0 # 64 row-blocks -# QK preload depth — must be >= 1; reference uses 2. The vec pre-softmaxes -# tiles 0..QK_PRELOAD-1, then the steady-state loop interleaves softmax(t+QK_PRELOAD) -# with gu(t), and the epilogue drains the last QK_PRELOAD gu's. -# (NUM_TILES - QK_PRELOAD) must be even — steady state is pair-unrolled to -# ping-pong the exp_max ring (see below). +# QK preload depth — must be >= 1; reference launch uses 4, this builder +# keeps 2 for a smaller VEC exp_max ring (see header comment). +# (NUM_TILES - QK_PRELOAD) must be even — steady state is pair-unrolled. QK_PRELOAD = 2 assert ( NUM_TILES - QK_PRELOAD @@ -229,7 +222,6 @@ def cube_kernel( ) -> None: c0 = const(0) c1 = const(1) - c2 = const(2) cS0 = const(S0) cHEAD = const(HEAD) cS1_TILE = const(S1_TILE) @@ -297,11 +289,8 @@ def cube_kernel( nosplit=False, ) - # All cube tile-buffers are single-buffered. K and V share RIGHT - # storage: for HEAD=128, S1_TILE=256 each RIGHT tile is exactly - # 64 KB, and the schedule uses V for PV before moving K for QK. - # This mirrors the hand-written reference's explicit local-memory - # assignment style and avoids asking RIGHT for two full tiles. + # K and V share one RIGHT bank (sequential use); `[buf]` lists alias + # to the same physical tile for each role (schedule is still safe). right_base = const(RIGHT_KV_OFF, s.int64) q_mat = pto.alloc_tile(q_mat_ty, addr=const(MAT_Q_OFF, s.int64)) q_left = pto.alloc_tile(q_left_ty, addr=const(LEFT_Q_OFF, s.int64)) @@ -313,8 +302,6 @@ def cube_kernel( v_mat_s = pto.alloc_tile(v_mat_ty, addr=const(MAT_V_OFF, s.int64)) v_right_s = pto.alloc_tile(v_right_ty, addr=right_base) pv_acc_s = pto.alloc_tile(pv_acc_ty, addr=const(ACC_PV_OFF, s.int64)) - # Aliasing wrappers: keep the per-iteration `[buf]` indexing pattern - # in the body even though all slots currently point at one alloc. k_mat = [k_mat_s, k_mat_s] k_right = [k_right_s, k_right_s] qk_acc = [qk_acc_s, qk_acc_s] @@ -351,8 +338,6 @@ def cube_kernel( tile.mov(q_mat, q_left) # =================== Cube prologue: emit QK[0..QK_PRELOAD-1] =================== - # Each prologue QK uses its own k_mat / k_right / qk_acc slot - # so MTE2 load of K[1] overlaps the M of QK[0]. for k in range(QK_PRELOAD): k_off = const(k * S1_TILE) kt_view_k = pto.slice_view( @@ -376,16 +361,8 @@ def cube_kernel( pto.load(v_view_0, v_mat[0]) # =================== Cube steady state =================== - # Pair-unrolled. Iter t (parity = t%2 → buffer index `b`): - # * load K[next_qk = t+QK_PRELOAD] into k_mat[b] - # (next_qk parity equals t parity since QK_PRELOAD == 2) - # * pop / mov P[t] into p_left[b]; mov V[t] (in v_mat[b]) → v_right[b] - # * preload V[t+1] into v_mat[1-b] - # * matmul PV[t] into pv_acc[b]; push - # * matmul QK[next_qk] into qk_acc[b]; push - # Pair handler: + # Pair-unrolled; buffer index b = t % 2 (logical ping-pong). def emit_cube_step(t_idx, b): - # next_qk = t_idx + QK_PRELOAD (only used when in main range) next_qk = t_idx + const(QK_PRELOAD) kt_off = next_qk * cS1_TILE kt_view = pto.slice_view( @@ -418,18 +395,16 @@ def emit_cube_step(t_idx, b): pto.tpush(qk_acc[b], qk_pipe, SPLIT_UP_DOWN) assert (NUM_TILES - QK_PRELOAD) % 2 == 0 - for p in pto.range(c0, const((NUM_TILES - QK_PRELOAD) // 2), c1): + c2 = const(2) + for p in pto.range(c0, const(STEADY_PAIRS), c1): t_a = p * c2 emit_cube_step(t_a, 0) t_b = p * c2 + c1 emit_cube_step(t_b, 1) # =================== Cube epilogue: drain last QK_PRELOAD PVs =================== - # Tile_id range: NUM_TILES-QK_PRELOAD .. NUM_TILES-1. - # NUM_TILES is even and QK_PRELOAD is even, so the first epilogue - # tile has parity 0. v_mat[0] holds V[NUM_TILES-QK_PRELOAD] thanks - # to the last steady-state preload (it loaded V[t_b+1] = V[NUM_TILES-QK_PRELOAD] - # into v_mat[1-1]=v_mat[0]). + # Tile ids NUM_TILES-QK_PRELOAD .. NUM_TILES-1; v_mat parity matches + # the last steady-state V preload pattern (same as QK_PRELOAD==2 case). for k in range(QK_PRELOAD): b = k % 2 p_raw = pto.tpop_from_aiv(p_recv_ty, SPLIT_UP_DOWN, id=ID_P) @@ -461,7 +436,6 @@ def vector_kernel( ) -> None: c0 = const(0) c1 = const(1) - c2 = const(2) cS0 = const(S0) cS0_HALF = const(S0_HALF) cHEAD = const(HEAD) @@ -547,13 +521,7 @@ def vector_kernel( red_ty, addr=const(VEC_NEW_GLOBAL_SUM_OFF, s.int64) ) local_sum = pto.alloc_tile(red_ty, addr=const(VEC_LOCAL_SUM_OFF, s.int64)) - # Ring of QK_PRELOAD exp_max tiles. With QK_PRELOAD=2 we use a/b - # ping-pong: even-parity tiles use exp_max_a, odd-parity tiles use - # exp_max_b. softmax(t) writes the exp_max for tile t into the - # corresponding slot; gu(t) reads it from the same slot. Because - # softmax(t+QK_PRELOAD) and gu(t) hit the SAME slot (parity matches), - # the steady-state loop must do gu(t) BEFORE softmax(t+QK_PRELOAD) - # to avoid clobbering. + # Two exp_max tiles (a/b). Interleave gu and softmax to avoid clobber. assert QK_PRELOAD == 2, "exp_max ring is hard-coded to 2 tiles" exp_max_a = pto.alloc_tile(red_ty, addr=const(VEC_EXP_MAX_A_OFF, s.int64)) exp_max_b = pto.alloc_tile(red_ty, addr=const(VEC_EXP_MAX_B_OFF, s.int64)) @@ -625,29 +593,15 @@ def emit_gu_step(exp_max_slot, is_init): o_row_off = qb * cS0 # =================== Vec prologue: softmax(0..QK_PRELOAD-1) =================== - # softmax(0): is_init=True (writes exp_max_a, but exp_max_a for tile 0 - # is unused by gu(0) — gu(0) takes the init branch and just movs PV. - # Still we must compute it correctly; the init branch doesn't touch exp_max. emit_softmax_step(exp_max_a, is_init=True) - # softmax(1): is_init=False (writes exp_max_b) emit_softmax_step(exp_max_b, is_init=False) # =================== Vec steady state =================== - # Pair-unrolled: each `p` iteration handles tiles t_a = 2p, t_b = 2p+1. - # gu(t_a) reads exp_max_a (set by softmax(t_a) earlier); - # softmax(t_a+2) writes exp_max_a (matches parity). - # gu(t_b) reads exp_max_b; softmax(t_b+2) writes exp_max_b. - # CRITICAL: gu BEFORE softmax in same step to avoid clobbering. - # - # First pair (p=0, t_a=0, t_b=1) is Python-unrolled so we can - # take the `is_init=True` branch on gu(0) (which initializes - # o_tile via mov rather than rescale+add). - emit_gu_step(exp_max_a, is_init=True) # tile 0 - emit_softmax_step(exp_max_a, is_init=False) # tile 2 → exp_max_a - emit_gu_step(exp_max_b, is_init=False) # tile 1 - emit_softmax_step(exp_max_b, is_init=False) # tile 3 → exp_max_b - - # Remaining pairs (p=1..STEADY_PAIRS-1) inside a runtime loop. + emit_gu_step(exp_max_a, is_init=True) + emit_softmax_step(exp_max_a, is_init=False) + emit_gu_step(exp_max_b, is_init=False) + emit_softmax_step(exp_max_b, is_init=False) + for p in pto.range(c1, const(STEADY_PAIRS), c1): emit_gu_step(exp_max_a, is_init=False) emit_softmax_step(exp_max_a, is_init=False) diff --git a/examples/aot/flash_attention/known_gap.md b/examples/aot/flash_attention/known_gap.md index 83c31a7c..6f9ee33f 100644 --- a/examples/aot/flash_attention/known_gap.md +++ b/examples/aot/flash_attention/known_gap.md @@ -20,7 +20,7 @@ This document compares the AOT flash-attention builders (`fa_builder.py`, `exper DSL: cube **single-buffered** tiles with **aliased** `[k_mat_s, k_mat_s]` / same for P/V; **`QK_LOCAL_SLOT_NUM = 1`** on the QK pipe because deeper local slots overflow vec UB. That limits overlap of **TLOAD** with **TMATMUL** compared to the reference. 3. **Preload depth** - Reference launch uses **`QK_PRELOAD = 4`** (`fa_kernel.cpp`). Experimental DSL uses **`QK_PRELOAD = 2`**. Shallower preload reduces cube/vec overlap on long S1. + Reference launch uses **`QK_PRELOAD = 4`** (`fa_kernel.cpp`). Experimental DSL uses **`QK_PRELOAD = 2`**. Shallower preload reduces cube/vec overlap on long S1. A DSL port to 4 needs a **4-deep `exp_max` ring** and more VEC space; an attempt faulted until UB layout (recv scratch, `MAT_P_FIFO` tail) is redesigned. ### Macro parity (item 5) — port to Python DSL, not “optional tuning” @@ -47,14 +47,29 @@ Closing the gap should **not** rely on turning sync insertion off; it should rel --- +## Progress log (experimental `fa_builder.py`, Apr 2026) + +Measured on NPU via `experimental/run.py` (Q=2048, H=128, S1_TILE=256): kernel holds ~24–26 TFLOP/s vs ~60+ TFLOP/s for `torch_npu` fused ref on the same script; correctness (`assert_close` at `run.py:151`) remains the gate. + +| Change attempted | Result | +|------------------|--------| +| `QK_PRELOAD=4` + four `exp_max` slots + quad-unrolled vec/cube + true dual MAT banks for K/P/V | AICore CCU address fault (`mte`/`ccu`); likely VEC `tpop` scratch / expanded red region + dual `RIGHT` typing; **reverted**. | +| True L1 ping-pong (separate `MAT_K0`/`MAT_K1`, …) without preload-4 | Overlapped `MAT_P_FIFO` with tail tiles until `MAT_P_FIFO_OFF` was recomputed; still faulted with dual `RIGHT` `alloc_tile` until fully reverted to aliased single-buffer layout. | +| K-split: two `CUBE_S1=128` matmuls per tile via `tile.subview` on `qk_acc` | Builds and passes `assert_close`; **~7% slower** than one `S1_TILE=256` matmul on this target — **reverted**. | +| Reorder steady cube step (PV before K load) | **Slight regression** vs original order on sampled runs — **reverted**. | + +**Takeaway:** Dominant gap remains **cube M (`S0=32` vs 128)** and **preload/ring depth**; raising `S0` or `QK_PRELOAD` bumps VEC/CUBE footprint and needs a audited memory map (recv scratch ≥ one `qk_vec` tile after `exp_max` slots, `MAT_P_FIFO` after all MAT tiles). + +--- + ## TODO: close the gap (Python DSL ↔ reference C++) Use this as a work backlog; order roughly reflects suggested priority (tiling/buffers first, then macro fidelity, then integration). ### Tiling and cube schedule -- [ ] **Match reference cube geometry:** `CUBE_S0=128`, `CUBE_S1=128`, `TILE_S1=256`, and **`kTileFactor`** loop (two K slices per 256-wide tile) in the DSL builder’s cube kernel, or justify an equivalent FLOP/memory contract with measurements. -- [ ] **Re-evaluate `S0=32`** (and non-experimental builder constants): target the same per-matmul **M** as the reference unless hardware constraints force otherwise. +- [ ] **Match reference cube geometry:** `CUBE_S0=128`, `CUBE_S1=128`, `TILE_S1=256`, and **`kTileFactor`** loop (two K slices per 256-wide tile) in the DSL builder’s cube kernel, or justify an equivalent FLOP/memory contract with measurements. *(Prototype K-split only: numerics OK, throughput down on current NPU.)* +- [ ] **Re-evaluate `S0=32`** (and non-experimental builder constants): target the same per-matmul **M** as the reference unless hardware constraints force otherwise. *(VEC `SLOT_SIZE_QK` scales with `S0`; `S0=64` overflowed UB in a back-of-envelope layout.)* - [ ] **Align `QK_PRELOAD`** with the reference launch (**4**) and extend the **`exp_max` / GU ring** logic (or equivalent hazard avoidance) for that depth; assert fifo and UB sizing. ### Double-buffering and overlap From 385081465e778f72568f59d0be8b34506ac127d0 Mon Sep 17 00:00:00 2001 From: mirkodevita Date: Tue, 28 Apr 2026 00:31:03 +0200 Subject: [PATCH 3/7] ptoas requests --- .../experimental/fa_builder.py | 24 +++- examples/aot/flash_attention/known_gap.md | 29 +++- examples/aot/flash_attention/ptoas_request.md | 133 ++++++++++++++++++ 3 files changed, 176 insertions(+), 10 deletions(-) create mode 100644 examples/aot/flash_attention/ptoas_request.md diff --git a/examples/aot/flash_attention/experimental/fa_builder.py b/examples/aot/flash_attention/experimental/fa_builder.py index 310f5a33..664ecc46 100644 --- a/examples/aot/flash_attention/experimental/fa_builder.py +++ b/examples/aot/flash_attention/experimental/fa_builder.py @@ -44,8 +44,11 @@ # --------------------------------------------------------------------------- # Static shapes (must match run.py constants) # --------------------------------------------------------------------------- -S0 = 32 # Q rows per block -S0_HALF = S0 // 2 # rows per AIV sub-block +# Match reference `CUBE_S0 = 128` (`fa_kernel.cpp`). Vec UB grows with +# `S0` because the QK/PV pipe mirrors full tiles in UB; override with FA_S0 +# if a smaller block is needed while tooling catches up (see `known_gap.md`). +S0 = int(os.environ.get("FA_S0", "128")) +S0_HALF = S0 // 2 # rows per AIV sub-block (TILE_UP_DOWN split) HEAD = 128 # attention head dimension S1_TILE = 256 # K/V columns per tile # NUM_TILES is overridable via the FA_NUM_TILES env var so the same builder @@ -57,7 +60,7 @@ S1_TOTAL = S1_TILE * NUM_TILES Q_ROWS = 2048 -NUM_Q_BLOCKS = Q_ROWS // S0 # 64 row-blocks +NUM_Q_BLOCKS = Q_ROWS // S0 # e.g. 16 when Q_ROWS=2048 and S0=128 # QK preload depth — must be >= 1; reference launch uses 4, this builder # keeps 2 for a smaller VEC exp_max ring (see header comment). @@ -69,7 +72,8 @@ STEADY_PAIRS = (NUM_TILES - QK_PRELOAD) // 2 # Per-pipe slot sizes (bytes). -SLOT_SIZE_QK = S0 * S1_TILE * 4 # fp32 QK accumulator +# QK: fp32 on the wire (cube `tpush` only accepts ACC tiles today — see known_gap). +SLOT_SIZE_QK = S0 * S1_TILE * 4 SLOT_SIZE_PV = S0 * HEAD * 4 # fp32 PV accumulator SLOT_SIZE_P = S0 * S1_TILE * 2 # fp16 P matrix sent vec → cube @@ -99,7 +103,10 @@ MAT_K_OFF = MAT_Q_OFF + S0 * HEAD * 2 MAT_P_RECV_OFF = MAT_K_OFF + HEAD * S1_TILE * 2 MAT_V_OFF = MAT_P_RECV_OFF + S0 * S1_TILE * 2 -MAT_P_FIFO_OFF = 262144 +MAT_P_FIFO_OFF = MAT_V_OFF + S1_TILE * HEAD * 2 +# Pad past the last MAT-resident tile; bisheng is sensitive to overlap here. +if MAT_P_FIFO_OFF < 393216: + MAT_P_FIFO_OFF = 393216 LEFT_Q_OFF = 0 LEFT_P_OFF = LEFT_Q_OFF + S0 * HEAD * 2 @@ -116,13 +123,16 @@ VEC_P_FP16_OFF = VEC_P_FP32_OFF + S0_HALF * S1_TILE * 4 VEC_O_OFF = VEC_P_FP16_OFF + S0_HALF * S1_TILE * 2 VEC_RED_BASE_OFF = VEC_O_OFF + S0_HALF * HEAD * 4 -VEC_RED_STRIDE = 512 +# Tight packing for reduce / exp_max ring scalars (one column per logical row). +VEC_RED_STRIDE = ((S0_HALF * 4 + 127) // 128) * 128 VEC_NEW_GLOBAL_MAX_OFF = VEC_RED_BASE_OFF + 0 * VEC_RED_STRIDE VEC_LOCAL_MAX_OFF = VEC_RED_BASE_OFF + 1 * VEC_RED_STRIDE VEC_NEW_GLOBAL_SUM_OFF = VEC_RED_BASE_OFF + 2 * VEC_RED_STRIDE VEC_LOCAL_SUM_OFF = VEC_RED_BASE_OFF + 3 * VEC_RED_STRIDE VEC_EXP_MAX_A_OFF = VEC_RED_BASE_OFF + 4 * VEC_RED_STRIDE VEC_EXP_MAX_B_OFF = VEC_RED_BASE_OFF + 5 * VEC_RED_STRIDE +# Shared recv scratch: max(fp16 QK half-tile, fp32 PV half-tile) for tpop addr=. +_VEC_RECV_BYTES = max(S0_HALF * S1_TILE * 2, S0_HALF * HEAD * 4) VEC_RECV_OFF = VEC_RED_BASE_OFF + 6 * VEC_RED_STRIDE ID_QK = 10 # Cube → Vec, dir_mask = 1 (uses lower-level l2g2l) @@ -545,7 +555,7 @@ def emit_softmax_step(exp_max_slot, is_init): addr=const(VEC_RECV_OFF, s.int64), ) tile.muls(qk_recv, scale, qk_recv) - tile.row_max(qk_recv, tmp_tile, local_max) + tile.row_max(qk_recv, p_fp32, local_max) local_max_r = tile.reshape(red_row_ty, local_max) new_global_max_r = tile.reshape(red_row_ty, new_global_max) diff --git a/examples/aot/flash_attention/known_gap.md b/examples/aot/flash_attention/known_gap.md index 6f9ee33f..e58e722b 100644 --- a/examples/aot/flash_attention/known_gap.md +++ b/examples/aot/flash_attention/known_gap.md @@ -13,7 +13,7 @@ This document compares the AOT flash-attention builders (`fa_builder.py`, `exper 1. **Cube tiling and S1 sub-tiling** Reference: `CUBE_S0 = 128`, `CUBE_S1 = 128`, `TILE_S1 = 256`, so **`kTileFactor = 2`** (two 128-wide K slices per logical 256-wide tile). - DSL (experimental): **`S0 = 32`**, single **`S1_TILE = 256`** matmul per tile. Smaller **M** (32 vs 128) and a different K-splitting strategy typically dominate cube utilization and memory overlap versus the reference. + DSL (experimental): **`S0 = 128`** by default (env `FA_S0`), still a **single** **`S1_TILE = 256`** matmul per tile (no K-split). The row-block size now matches reference **M**; the remaining gap is **K micro-tiling / matmul overlap** versus the reference’s two `Cube_S1` passes per logical tile. 2. **Real L1 double-buffering** Reference: `kMatTNBuffers = 2`, `pMatTNBuffers = 2`, `vMatTNBuffers = 2`, plus vec-side ping-pong (`srcVecTNBuffers`, `xexpVecTNBuffers`, `outOTileNBuffers`). @@ -58,7 +58,30 @@ Measured on NPU via `experimental/run.py` (Q=2048, H=128, S1_TILE=256): kernel h | K-split: two `CUBE_S1=128` matmuls per tile via `tile.subview` on `qk_acc` | Builds and passes `assert_close`; **~7% slower** than one `S1_TILE=256` matmul on this target — **reverted**. | | Reorder steady cube step (PV before K load) | **Slight regression** vs original order on sampled runs — **reverted**. | -**Takeaway:** Dominant gap remains **cube M (`S0=32` vs 128)** and **preload/ring depth**; raising `S0` or `QK_PRELOAD` bumps VEC/CUBE footprint and needs a audited memory map (recv scratch ≥ one `qk_vec` tile after `exp_max` slots, `MAT_P_FIFO` after all MAT tiles). +**Takeaway:** With **`S0=128`** landed in the experimental builder, the largest remaining structural gaps versus the reference are **`kTileFactor` / `CUBE_S1` K-split**, **preload / ring depth (`QK_PRELOAD`, CV FIFO)**, and **vec working-tile geometry (`Vec_S0`)** relative to what `l2g2l_pipe` + `TILE_UP_DOWN` imply for UB. Raising `QK_PRELOAD` still needs more `exp_max` slots and recv/GM budget. + +| `S0=128` (Apr 2026) | Default `S0` raised to **128** (`FA_S0`). `pto.tpush` from cube still requires **ACC** tiles (PTO verifier: only `AddressSpace::ACC` maps to a producer pipe); staging QK as **fp16 on MAT/LEFT** before push was rejected at MLIR verify, so QK stays **fp32 on the wire** with full `SLOT_SIZE_QK`. Vec softmax reuses **`p_fp32` as `row_max` scratch** (same lifetime as before `row_expand_sub`) plus a **single shared `VEC_RECV_OFF`** sized for the larger of QK/PV half-tiles. `experimental/run.py` + `compile.sh` pass on NPU at ~24 TFLOP/s (unchanged order-of-magnitude vs fused ref). | + +--- + +## PTOAS / PTO dialect / Python binding — feature requests (algorithm parity) + +These are the main **toolchain** gaps noticed while aligning `experimental/fa_kernel` with `cpp_ref/naive_tpush/fa_kernel.cpp`. They are not criticisms of the hand-written reference; they are concrete asks so the **same algorithm config** (tiling, dtypes on wires, vec working set) can be expressed without fighting verifiers or UB. + +1. **C2V `pto.tpush` producer tiles beyond ACC** + Today `TPushOp::getPipe()` maps **only** `AddressSpace::ACC` → `PIPE_FIX` (see `PTOOps.td`); **MAT** and **LEFT** producers yield `PIPE_UNASSIGNED` and fail verification. The reference keeps QK in **fp32 in GM** (`qk_tile_fifo`) and uses **fp16** only inside vec macros (`TileDataH_T`, `TCVT`). A natural DSL port would **cvt** `TileAcc` → **`Tile`** and `tpush` that tile to halve **`slot_size`** / vec FIFO pressure. **Ask:** allow **fp16 (and/or LEFT/MAT) tiles** as legal C2V `tpush` sources when `slot_size` matches, or document the intended lowering (e.g. MTE path) so Python does not need ACC-only staging. + +2. **Decouple `slot_size` from “one full cube row tile” for vec UB accounting** + Reference **`Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor`** (e.g. **32** rows × **256** cols in vec UB) while GM still holds **`Cube_S0 × Tile_S1`** floats per logical tile, assembled from **`kTileFactor`** slices of **`Cube_S0 × Cube_S1`**. The DSL **`l2g2l_pipe`** ties **vec `reserve_buffer`** size to **`SLOT_SIZE_QK`** and **`tpop`** delivers **`S0_HALF × S1_TILE`** per subblock. **Ask:** first-class **“logical tile vs wire chunk”** (multi-slot per tile_id, or column-strip `tpop` into a fixed vec workspace) so vec UB tracks **`Vec_S0`** like the C++ launch, not **`Cube_S0/2`** per `TILE_UP_DOWN` alone. + +3. **`kTileFactor` / K-split + softmax without a single 64×256 vec tile** + Matching the reference requires **multiple `compute_p` / `row_slice` passes** per tile and **partial QK layout in GM** (`base_elems + row_offset * Cube_S1`). **Ask:** DSL helpers or ops for **GM strided views** + **event sync** equivalent to `TSync_Custom` / `qk2smSync`, or **documented** mapping from `initialize_l2g2l_pipe` + `tpop` to that pattern so cube can emit **128×128** stores while vec runs **32×256** softmax without holding a **64×256** `qk_vec` buffer per subblock. + +4. **`QK_PRELOAD = 4` and deeper CV FIFOs** + Reference uses **`qkPreloadNum = 4`** with **`l1_exp_max_ififo[qkp_tile_fifo_size]`**. DSL stays at **`QK_PRELOAD = 2`** for a smaller **`exp_max` ring**. **Ask:** either **lowered UB cost** for pipe rings (item 1–2) or **optional GM-backed vec inputs** so preload depth can match the C++ launch without manual byte arithmetic. + +5. **Python binding ergonomics** + **Ask:** optional **computed layout** (or static asserts) from tensor shapes for **MAT / VEC base offsets** so raising `S0` cannot silently overlap **`MAT_P_FIFO`** with cube tiles; and a **single knob** mirroring `runTFA` template parameters (`CUBE_S0`, `CUBE_S1`, `TILE_S1`, `QK_PRELOAD`, CV FIFO depth) mapped to **`S0`**, **`S1_TILE`**, **`QK_PRELOAD`**, and pipe **`slot_num` / `local_slot_num`**. --- @@ -69,7 +92,7 @@ Use this as a work backlog; order roughly reflects suggested priority (tiling/bu ### Tiling and cube schedule - [ ] **Match reference cube geometry:** `CUBE_S0=128`, `CUBE_S1=128`, `TILE_S1=256`, and **`kTileFactor`** loop (two K slices per 256-wide tile) in the DSL builder’s cube kernel, or justify an equivalent FLOP/memory contract with measurements. *(Prototype K-split only: numerics OK, throughput down on current NPU.)* -- [ ] **Re-evaluate `S0=32`** (and non-experimental builder constants): target the same per-matmul **M** as the reference unless hardware constraints force otherwise. *(VEC `SLOT_SIZE_QK` scales with `S0`; `S0=64` overflowed UB in a back-of-envelope layout.)* +- [x] **Match reference `CUBE_S0` (128) in experimental builder** — default `S0=128` via `FA_S0` (Apr 2026); UB layout was tightened (shared `tpop` recv sizing, `row_max` scratch reuse, smaller `VEC_RED_STRIDE`). Smaller blocks remain available with `FA_S0=32` etc. if needed. - [ ] **Align `QK_PRELOAD`** with the reference launch (**4**) and extend the **`exp_max` / GU ring** logic (or equivalent hazard avoidance) for that depth; assert fifo and UB sizing. ### Double-buffering and overlap diff --git a/examples/aot/flash_attention/ptoas_request.md b/examples/aot/flash_attention/ptoas_request.md new file mode 100644 index 00000000..039415bf --- /dev/null +++ b/examples/aot/flash_attention/ptoas_request.md @@ -0,0 +1,133 @@ +# PTOAS feature requests (PTO MLIR dialect + Python bindings) + +This document collects **actionable requests** for the PTOAS / PTO dialect stack so that **flash-attention–style kernels** written in Python (e.g. `examples/aot/flash_attention/experimental/fa_builder.py` via `ptodsl`) can **closely match** the hand-tuned reference in `examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp` (`runTFA`, `compute_qk`, `compute_p`, `compute_pv`, `compute_gu`). + +Upstream sources referenced below live under: + +`.agent/skills/translate_cpp2py/references/external_repo/PTOAS/` + +--- + +## 1. Allow cube-side `pto.tpush` from non-ACC tiles (C2V producer coverage) + +**Problem.** `TPushOp::getPipe()` only maps **`AddressSpace::ACC`** to a concrete pipe (`PIPE_FIX`). **`MAT`** and **`LEFT`** tiles map to **`PIPE_UNASSIGNED`**, so MLIR verification fails with *“tile type must map to a supported producer pipe”* when attempting to push an fp16 staging tile (e.g. post-`TCvt` from acc) over a C2V pipe. + +**Evidence.** `include/PTO/IR/PTOOps.td`, `TPushOp` `getPipe()` (lines ~1767–1792): only `ACC` and `VEC` branches; all other address spaces return `PIPE_UNASSIGNED`. + +**Motivation (ref FA).** The C++ reference keeps **fp32 QK in GM** and uses **fp16** inside vec macros for softmax output / P staging. A Python port naturally wants **cvt(acc f32 → mat/left f16) → tpush** to **halve `slot_size`** and vec FIFO pressure while keeping matmul in fp32. + +**Ask.** + +- Extend **`TPushOp`** (and verifier / lowering to EmitC) so **cube producers** can legally push **`TileBufType` in `MAT` and/or `LEFT`** with dtypes compatible with the pipe’s `slot_size`, **or** +- Document and implement an **official lowering path** (e.g. implicit MTE move acc→staging then push) so frontends do not need to guess unsupported combinations. + +--- + +## 2. Decouple `slot_size` (wire bytes) from producer/consumer tile element type + +**Problem.** `initialize_l2g2l_pipe` takes a single **`slot_size` (bytes)** while `tpush`/`tpop` tile types carry **dtype + shape**. Today authors must keep **manual consistency** between `SLOT_SIZE_QK`, cube `TileAcc`, and vec `Tile`; there is no first-class “**fp32 compute, fp16 wire**” contract. + +**Evidence.** `InitializeL2G2LPipeOp` in `PTOOps.td` (~1681–1712): `slot_size` is a plain `i32`; pipe init does not encode logical vs physical width. + +**Motivation (ref FA).** Reference layout uses **`sizeof(float)` × Cube_S0 × Tile_S1`** in GM for `qk_tile_fifo`, while vec tiles are **`Vec_S0 × Tile_S1`** with **`Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor`**. The toolchain should help express **logical tile**, **wire format**, and **vec working tile** without ad-hoc byte math in Python. + +**Ask.** + +- Optional attributes on **`initialize_l2g2l_pipe`** (or companion op) for **`wire_elem_type`**, **`logical_shape`**, and/or **`vec_slice_shape`**, validated against `slot_size`, **or** +- A small **tablegen-verified** bundle type for “pipe slot descriptor” consumed by both cube and vec builders. + +--- + +## 3. First-class **K-split** (`kTileFactor`) and **partial QK** delivery to vec + +**Problem.** The reference runs **`kTileFactor = Tile_S1 / Cube_S1`** cube passes (e.g. two **128×128** matmuls per **256**-wide logical tile), stores **`Cube_S0 × Cube_S1`** slices into GM, and vec **`compute_p`** performs **`kTileFactor`** **TLOAD**s of **`Vec_S0 × Cube_S1`** into a **`Vec_S0 × Tile_S1`** vec tile. The Python + `l2g2l_pipe` path instead tends toward **one full `Cube_S0 × Tile_S1` tpush** and a **`S0_HALF × S1_TILE` tpop**, which inflates **vec UB** versus **`Vec_S0 × Tile_S1`**. + +**Motivation (ref FA).** Matching **`CUBE_S1`**, **`kTileFactor`**, and **`Vec_S0`** is required for both **numerics/scheduling parity** and **UB parity** with `fa_kernel.cpp`. + +**Ask.** + +- Either **documented** lowering from “ref-style GM layout + sync” to **`initialize_l2g2l_pipe` + `tpush`/`tpop`**, **or** new ops / pipe modes for: + - **multiple ordered `tpush`es** per logical `tile_id` with **fixed GM packing** matching the reference’s `base_elems` formulas, and + - **vec-side assembly** (`tpop` into column sub-ranges of one vec tile, or explicit `tassign`/`subview` at UB addresses) without requiring a single oversized **`tpop`** result tile. + +--- + +## 4. Richer **`split`** / subblock model (beyond one `TILE_UP_DOWN` halving) + +**Problem.** `split` on `tpush`/`tpop` models a **single** split axis enum; reference logic combines **`get_subblockid()`**, **`row_slice`**, and **`kTileFactor`** to address **four** distinct **32-row** bands across **`Cube_S0 = 128`**. Expressing that with only **one** up/down split per op forces **larger per-core vec tiles** than the reference. + +**Evidence.** Design notes in `docs/designs/ptoas-tpush-tpop-design.md` (split semantics); reference `compute_p` row/col slicing in `fa_kernel.cpp`. + +**Ask.** + +- Consider **documented composition** of splits (e.g. nested phases) **or** additional split modes / **multi-phase tpop** that align with **`row_slice × subblock`** patterns used in FA macros. + +--- + +## 5. **`local_slot_num` / vec `reserve_buffer`** vs GM-only consumer patterns + +**Problem.** `local_slot_num` must be **> 0** and `local_addr` is mandatory for `initialize_l2g2l_pipe` (verifier in `PTO.cpp` / design doc §5.2). The reference often behaves like **“cube writes GM; vec reads GM after sync”** with **smaller vec-local FIFOs** (`srcVecTNBuffers`, etc.), not necessarily a full **local mirror** of every slot byte in UB. + +**Evidence.** `docs/designs/ptoas-tpush-tpop-design.md` (~318–361, ~759–761). + +**Ask.** + +- Optional **GM-primary consumer** mode: vec **`tpop`** semantics that **do not** require **`reserve_buffer(slot_size × local_slot_num)`** when the consumer only needs a **bounded scratch** (with **verified** max live bytes), **or** +- A **`tpop_from_gm` / `wait_slot` + `load`** pattern with **verified** cross-core ordering equivalent to **`TSync_Custom`** in the reference. + +--- + +## 6. Explicit **sync / event** surface in the dialect (parity with `TSync_Custom` / CV FIFO) + +**Problem.** Reference FA uses **`TSync_Custom`**, **`should_wait_consumption` / `should_notify_consumption`**, and optional **CV comm** for backpressure. Python builders today lean on **`--enable-insert-sync`** and pipe **`tfree`** ordering; there is no close 1:1 mapping to **named sync tokens** and **FIFO depth** parameters from `fa_kernel.cpp`. + +**Motivation (ref FA).** Tuning **`QK_PRELOAD`**, **`qkp_tile_fifo_size`**, and **`CV_FIFO_CONS_SYNC_PERIOD`** is central to the C++ launch. + +**Ask.** + +- Expose **optional** `record_event` / `wait_event` (or reuse existing async session ops if applicable) with **stable lowering** to the same primitives reference kernels use, **and/or** +- A **small FA template** in docs that maps **`runTFA` template parameters** → PTO ops + attrs. + +--- + +## 7. Python bindings: **ergonomics** beyond raw `mlir` ODS + +**Problem.** `python/pto/dialects/pto.py` is largely **generated ODS exports**; authors of large kernels still hand-roll **byte offsets**, **`slot_size`**, and **layout** in application code (`ptodsl` or otherwise), which is error-prone when **`S0`**, **`S1_TILE`**, or **`HEAD`** change. + +**Ask.** + +- **Optional** Python helpers (same package or `ptodsl`-side) for: + - **Pipe bundle construction** (`dir_mask`, `slot_size`, `slot_num`, `local_slot_num`) with **static consistency checks**, + - **UB layout** from a declarative map of **tile names → (space, dtype, shape)** with **overlap detection**, + - **“Reference FA preset”** constants: `CUBE_S0`, `CUBE_S1`, `TILE_S1`, `QK_PRELOAD`, FIFO depths — emitting the right **`initialize_l2g2l_pipe`** / legacy `*_initialize_pipe` combo. + +--- + +## 8. Documentation: **reference kernel ↔ PTO pipe** mapping + +**Ask.** Add a short chapter to `docs/designs/ptoas-tpush-tpop-design.md` (or a new doc under `docs/designs/`) that shows: + +1. How **`TSTORE(qkGlobalTile, qkAccTile)`** + **`TLOAD(qkVecSub, qkGlobalSub)`** in `fa_kernel.cpp` maps to **`initialize_l2g2l_pipe` + `tpush` + `tpop`** (including **GM stride** / **`kTileFactor`**). +2. Which **`split`** values approximate **`TileSplitAxis::TILE_UP_DOWN`** in the reference P headers. +3. **Known limitations** (e.g. **`TPushOp` producer address spaces** as of current `PTOOps.td`). + +--- + +## Priority (suggested for FA parity) + +| Priority | Item | Why | +|----------|------|-----| +| P0 | **1** (non-ACC `tpush`) + **2** (slot/dtype decouple) | Unblocks **fp16-on-wire** and smaller vec FIFO without losing fp32 matmul. | +| P0 | **3** (K-split / partial QK) | Matches **reference cube + vec geometry**; largest structural mismatch today. | +| P1 | **5** (GM-primary / smaller local ring) | Unlocks **`QK_PRELOAD = 4`**-class schedules without linear growth of vec **`reserve_buffer`**. | +| P1 | **6** (explicit sync) | Needed for **faithful** backpressure / CV parity when scaling blocks/cores. | +| P2 | **4** (richer split) | Reduces pressure on Python to fake **row_slice** with full-height tiles. | +| P2 | **7–8** (bindings + docs) | Reduces integration risk and documents the **intended** lowering contract. | + +--- + +## Related artifacts in this repo + +- Experimental FA builder: `examples/aot/flash_attention/experimental/fa_builder.py` +- Reference C++ kernel: `examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp` +- Broader gap narrative: `examples/aot/flash_attention/known_gap.md` From 3c71b71417ca00d811939980bdc88de21099830e Mon Sep 17 00:00:00 2001 From: mirkodevita Date: Tue, 28 Apr 2026 00:40:18 +0200 Subject: [PATCH 4/7] update ptoas requests with more concrete code --- examples/aot/flash_attention/ptoas_request.md | 302 +++++++++++++++++- 1 file changed, 299 insertions(+), 3 deletions(-) diff --git a/examples/aot/flash_attention/ptoas_request.md b/examples/aot/flash_attention/ptoas_request.md index 039415bf..29db2871 100644 --- a/examples/aot/flash_attention/ptoas_request.md +++ b/examples/aot/flash_attention/ptoas_request.md @@ -2,9 +2,7 @@ This document collects **actionable requests** for the PTOAS / PTO dialect stack so that **flash-attention–style kernels** written in Python (e.g. `examples/aot/flash_attention/experimental/fa_builder.py` via `ptodsl`) can **closely match** the hand-tuned reference in `examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp` (`runTFA`, `compute_qk`, `compute_p`, `compute_pv`, `compute_gu`). -Upstream sources referenced below live under: - -`.agent/skills/translate_cpp2py/references/external_repo/PTOAS/` +Upstream **PTOAS** sources below use paths **relative to the `PTOAS/` repository root** (same layout as a normal PTOAS checkout), e.g. `include/PTO/IR/PTOOps.td`, `docs/designs/ptoas-tpush-tpop-design.md`. --- @@ -113,6 +111,304 @@ Upstream sources referenced below live under: --- +## Concrete examples (reference C++ ↔ desired Python ↔ today) + +Each subsection ties **one reference pattern** to **what Python would ideally emit**, what **PTOAS / MLIR rejects or cannot express**, and what **`experimental/fa_builder.py` does instead**. + +### A. fp16 payload on the QK cube→vec path (requests **1** and **2**) + +**Reference (GM is fp32; vec uses narrower working tiles and fp16 for P).** Cube stores each **`Cube_S0 × Cube_S1`** QK slice as **float** in `qk_tile_fifo` (not a hardware `TPUSH` from vec’s perspective—MTE `TSTORE` to GM): + +```364:381:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + using GlobalDataQK = + GlobalTensor, pto::Stride<1, 1, 1, Cube_S1, 1>>; + const uint32_t buf_idx = static_cast(tile_id % QKP_CV_FIFO); + const size_t base_elems = + static_cast(buf_idx) * static_cast(kTileFactor) * static_cast(Cube_S0) * + static_cast(Cube_S1) + + static_cast(sub_tile_id) * static_cast(Cube_S0) * static_cast(Cube_S1); + GlobalDataQK qkGlobalTile(qk_tile_fifo + base_elems); + +#if UF_ENABLE + TSTORE(qkGlobalTile, qkAccTile); +#else + TSTORE(qkGlobalTile, qkAccTile); + set_flag(PIPE_FIX, PIPE_M, accTileEvtID); +#endif + + if (sub_tile_id == static_cast(kTileFactor) - 1) + qk2smSync.record(); // notify for QK produce data +``` + +The **P** pipe uses **`Cube_S0 * Cube_S1 * sizeof(half)`** slots (fp16 on the vec→cube wire): + +```820:823:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + constexpr uint32_t p_tile_fifo_slots = qkp_tile_fifo_size * kTileFactor; + using PPipe = + TPipe; + PPipe pPipe((__gm__ void *)p_tile_fifo_block, 0u, (uint32_t)(uint64_t)pMatTile[0].data()); +``` + +**Desired Python pattern (sketch).** After `tile.matmul(..., qk_acc)`, narrow the C2V **`slot_size`** while keeping matmul in fp32: + +```python +# Ideal: half wire, same logical tile id +tile.cvt(qk_acc, qk_half_tile, rmode="round") # TileBufType(..., dtype=f16, memory_space="MAT"|"LEFT") +pto.tpush(qk_half_tile, qk_pipe, SPLIT_UP_DOWN) +``` + +Vec would `tpop` into **`!pto.tile_buf`** and `tile.cvt` to fp32 before `row_max`. + +**Current failing behavior.** `TPushOp::getPipe()` in upstream PTO only treats **`ACC`** (and **`VEC`**) as having a real pipe id; **`MAT` / `LEFT` / …** fall through to **`PIPE_UNASSIGNED`**, so MLIR verification fails: + +```1767:1793:include/PTO/IR/PTOOps.td + ::mlir::pto::PIPE getPipe() { + auto getAddressSpace = [](Type ty) -> std::optional<::mlir::pto::AddressSpace> { + if (auto tb = ::mlir::dyn_cast<::mlir::pto::TileBufType>(ty)) { + if (auto as = ::mlir::dyn_cast_or_null<::mlir::pto::AddressSpaceAttr>( + tb.getMemorySpace())) + return as.getAddressSpace(); + return std::nullopt; + } + // ... + }; + + auto as = getAddressSpace(getTile().getType()); + if (!as) + return ::mlir::pto::PIPE::PIPE_UNASSIGNED; + if (*as == ::mlir::pto::AddressSpace::ACC) + return ::mlir::pto::PIPE::PIPE_FIX; + if (*as == ::mlir::pto::AddressSpace::VEC) + return ::mlir::pto::PIPE::PIPE_MTE3; + return ::mlir::pto::PIPE::PIPE_UNASSIGNED; + } +``` + +Typical diagnostic: **`'pto.tpush' op tile type must map to a supported producer pipe`**. + +**Un-optimal workaround in `fa_builder.py`.** Keep **`SLOT_SIZE_QK = S0 * S1_TILE * 4`** and push **only** the fp32 accumulator (legal **`ACC`** producer): + +```74:77:examples/aot/flash_attention/experimental/fa_builder.py +# Per-pipe slot sizes (bytes). +# QK: fp32 on the wire (cube `tpush` only accepts ACC tiles today — see known_gap). +SLOT_SIZE_QK = S0 * S1_TILE * 4 +SLOT_SIZE_PV = S0 * HEAD * 4 # fp32 PV accumulator +``` + +```361:362:examples/aot/flash_attention/experimental/fa_builder.py + tile.matmul(q_left, k_right[k], qk_acc[k]) + pto.tpush(qk_acc[k], qk_pipe, SPLIT_UP_DOWN) +``` + +That **doubles** ring bytes versus a half-precision wire format with the same logical geometry. + +--- + +### B. `kTileFactor` / `Vec_S0` vs one big matmul + one `tpop` (requests **3** and **4**) + +**Reference geometry.** `runTFA` fixes **`kTileFactor = Tile_S1 / Cube_S1`** and **`Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor`** (e.g. **32** row softmax tile height when **`Cube_S0 = 128`**, **`kTileFactor = 2`**): + +```680:691:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + constexpr uint32_t Cube_S0 = CUBE_S0; + uint32_t block_rows = s0 / Cube_S0; + constexpr uint32_t Cube_S1 = CUBE_S1; // per-tile S1 chunk + constexpr uint32_t Tile_S1 = TILE_S1; // logical tile along S1 + static_assert(Tile_S1 % Cube_S1 == 0, "TILE_S1 must be divisible by Cube_S1"); + constexpr uint32_t kTileFactor = Tile_S1 / Cube_S1; // sub-tiles per TILE_S1 + constexpr uint32_t Cube_HEAD = HEAD_SIZE; + constexpr uint32_t Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor; + constexpr uint32_t VecGuRows = Cube_S0 / VEC_CORES; + static_assert(Cube_S0 % (VEC_CORES * kTileFactor) == 0, "Vec rows must divide evenly across tile slices"); +``` + +Vec **softmax** tile type is **`Vec_S0 × Tile_S1`**, not **`Cube_S0/2 × Tile_S1`**: + +```747:751:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + // Define tile types for FA softmax P computation + // UB offsets for softmax tiles + // Define per-tile vector tiles sized to Cube_S1 + using TileDataF_T = Tile; + using TileDataH_T = Tile; +``` + +**Reference assembly of the wide QK tile from K-slices in GM** (`compute_p`): two **`TLOAD`**s of **`Vec_S0 × Cube_S1`** into column halves of **`qkVecTile`**: + +```500:517:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + const uint32_t buf_idx = static_cast(tile_id % QKP_CV_FIFO); + const size_t base_elems = static_cast(buf_idx) * static_cast(kTileFactor) * + static_cast(Cube_S0) * static_cast(Cube_S1); + __gm__ float *qk_ptr = qk_tile_fifo + base_elems + row_offset * static_cast(Cube_S1); + + using GlobalDataQK_Sub = + GlobalTensor, pto::Stride<1, 1, 1, Cube_S1, 1>>; + using TileDataF_Sub = Tile; + for (int sub_col = 0; sub_col < static_cast(kTileFactor); ++sub_col) { + __gm__ float *qk_ptr_sub = + qk_ptr + static_cast(sub_col) * static_cast(Cube_S0) * static_cast(Cube_S1); + GlobalDataQK_Sub qkGlobalSub(qk_ptr_sub); + + TileDataF_Sub qkVecSub; + const uint64_t col_byte_offset = static_cast(sub_col * Cube_S1 * sizeof(float)); + TASSIGN(qkVecSub, (uint64_t)qkVecTile.data() + col_byte_offset); + TLOAD(qkVecSub, qkGlobalSub); + } +``` + +**Desired Python pattern (sketch).** Mirror **`compute_qk`**’s **`sub_tile_id`** loop with **`AccMode`/`InitPartialSum`** semantics, **`slot_size`/`slot_num`** matching **`base_elems`**, and vec **`tpop`** / **`load`** into **`Vec_S0 × S1_TILE`** (or explicit subview column packing) instead of one **`S0_HALF × S1_TILE`** receive tile per hardware half. + +**Current behavior.** The Python builder performs **one** `matmul` over **`HEAD × S1_TILE`** per logical tile and **one** `tpush` of the full **`S0 × S1_TILE`** accumulator; vec uses **`qk_vec_ty` shape `[S0_HALF, S1_TILE]`** with **`TILE_UP_DOWN`**: + +```350:362:examples/aot/flash_attention/experimental/fa_builder.py + for k in range(QK_PRELOAD): + k_off = const(k * S1_TILE) + kt_view_k = pto.slice_view( + kt_sub_ty, + source=tv_k, + offsets=[c0, k_off], + sizes=[cHEAD, cS1_TILE], + ) + pto.load(kt_view_k, k_mat[k]) + tile.mov(k_mat[k], k_right[k]) + tile.matmul(q_left, k_right[k], qk_acc[k]) + pto.tpush(qk_acc[k], qk_pipe, SPLIT_UP_DOWN) +``` + +```550:556:examples/aot/flash_attention/experimental/fa_builder.py + def emit_softmax_step(exp_max_slot, is_init): + qk_recv = pto.tpop( + qk_vec_ty, + qk_pipe, + SPLIT_UP_DOWN, + addr=const(VEC_RECV_OFF, s.int64), + ) +``` + +There is **no** first-class equivalent to **`row_slice` × `sub_col` × `TASSIGN` column packing** in this path. + +**Un-optimal workarounds.** + +- **Omit `kTileFactor` on cube** (single large K tile): simpler schedule but **not** the reference’s **`CUBE_S1 = 128`** matmul shape / partial-sum story. +- **Accept `S0_HALF = S0 // 2` vec rows per `tpop`**: matches **`TILE_UP_DOWN`** hardware split, but **not** the reference’s **`Vec_S0 = Cube_S0 / (2 * kTileFactor)`** (e.g. **64×256** received per subblock vs ref **32×256** working tile). + +--- + +### C. `QK_PRELOAD == 4` and explicit producer/consumer sync (requests **5** and **6**) + +**Reference preload and sync objects.** The launch uses **`qkPreloadNum = QK_PRELOAD`** (template parameter), **`TSync_Custom`** between cube **`TSTORE`** and vec **`TLOAD`**, and nested **`kTileFactor`** loops so cube and vec each run **`kTileFactor`** steps per logical preload tile: + +```817:874:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + constexpr TSync_Custom qk2smSync = {BUF0_QK_READY}; + constexpr TSync_Custom pv2guSync = {UPDATE_READY}; + // ... + for (int preload_tile = 0; preload_tile < static_cast(qkPreloadNum) && preload_tile < num_tiles_s1; + ++preload_tile) { + if constexpr (DAV_CUBE) { + for (int sub_tile = 0; sub_tile < static_cast(kTileFactor); ++sub_tile) { + qkAccTileEvtID = assign_running_acc_tile(qkAccTile); + compute_qk(preload_tile, sub_tile, q_block, k, qk_tile_fifo_block, + qMatTile[0], kMatTile[k_src_pingpong_id % kMatTNBuffers], + qkAccTile, k_src_pingpong_id % kMatTNBuffers, + qkAccTileEvtID, qk2smSync, block_idx); + k_src_pingpong_id++; + } + } + if constexpr (DAV_VEC) { + for (int row_slice = 0; row_slice < static_cast(kTileFactor); ++row_slice) { + compute_p( + preload_tile, row_slice, qk_tile_fifo_block, exp_max_ififo_block, global_sum_block, exp_max_block, + qkVecTile[p_gu_src_pingpong_id % srcVecTNBuffers], x_expT[p_gu_src_pingpong_id % xexpVecTNBuffers], + x_expPushT, + input_reduce_tmp, m1_local_max, l1_local_sum, m2_global_max, l2_global_sum, + l1_exp_max_ififo[preload_tile % qkp_tile_fifo_size], triu, p_gu_src_pingpong_id % xexpVecTNBuffers, + qk2smSync, pPipe, block_idx); + p_gu_src_pingpong_id++; + } + } + } +``` + +Inside **`compute_p`**, vec **`qk2smSync.wait()`** / **`free()`** bracket GM visibility: + +```496:520:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + wait_flag(PIPE_V, PIPE_MTE2, pTileEventId); + if (row_slice == 0) + qk2smSync.wait(); // wait for QK produce data + // ... TLOAD assembly ... + if (row_slice == static_cast(kTileFactor) - 1 && should_notify_consume) + qk2smSync.free(); // notify for SM consume data +``` + +**Desired Python pattern.** Set **`QK_PRELOAD = 4`**, model **`l1_exp_max_ififo[qkp_tile_fifo_size]`**, and emit **named sync** (or dialect ops that lower like **`TSync_Custom`**) aligned with **`should_wait_consumption` / `should_notify_consumption`** from the reference. + +**Current behavior / limits.** + +- **`initialize_l2g2l_pipe`** requires **`local_slot_num > 0`** and a **peer `reserve_buffer`**; vec UB scales with **`FIFO_BYTES_QK ≈ slot_size × local_slot_num`**. There is **no** built-in “vec only **`TLOAD`** from `gm_addr` after cube slot closes” mode with **zero** local ring bytes. +- Python FA relies on **`ptoas --enable-insert-sync`** for cross-kernel ordering instead of **`TSync_Custom`**-style explicit tokens. + +**Un-optimal workaround in `fa_builder.py`.** Hard-code **`QK_PRELOAD = 2`** and a **two-tile `exp_max` ring**; document that raising preload needs more UB and/or hazard rework: + +```65:68:examples/aot/flash_attention/experimental/fa_builder.py +# QK preload depth — must be >= 1; reference launch uses 4, this builder +# keeps 2 for a smaller VEC exp_max ring (see header comment). +# (NUM_TILES - QK_PRELOAD) must be even — steady state is pair-unrolled. +QK_PRELOAD = 2 +``` + +--- + +### D. Manual UB / MAT layout bookkeeping (request **7**) + +**Reference** hides much of this behind **`allocate_cube_tile_buffers` / `allocate_vec_tile_buffers`** templates (`runTFA`). + +**Desired Python / bindings.** Declarative **“tile name → (space, dtype, shape)”** map that **checks overlaps** and derives **`MAT_P_FIFO_OFF`**, vec FIFO bases, and recv scratch **automatically** when `Cube_S0`, `Tile_S1`, or `HEAD` change. + +**Current behavior.** Authors must align **`reserve_buffer`**, **`import_reserved_buffer`**, **`alloc_tile`**, and **`initialize_l2g2l_pipe.slot_size`** by hand. + +**Concrete workarounds in `fa_builder.py` today.** + +- **Pad `MAT_P_FIFO_OFF`** so the P V2C FIFO cannot overlap growing MAT tiles when **`S0`** increases: + +```100:109:examples/aot/flash_attention/experimental/fa_builder.py +# Explicit local-memory layout used when compiling with --pto-level=level3. +# Offsets are byte offsets within each independent local address space. +MAT_Q_OFF = 0 +MAT_K_OFF = MAT_Q_OFF + S0 * HEAD * 2 +MAT_P_RECV_OFF = MAT_K_OFF + HEAD * S1_TILE * 2 +MAT_V_OFF = MAT_P_RECV_OFF + S0 * S1_TILE * 2 +MAT_P_FIFO_OFF = MAT_V_OFF + S1_TILE * HEAD * 2 +# Pad past the last MAT-resident tile; bisheng is sensitive to overlap here. +if MAT_P_FIFO_OFF < 393216: + MAT_P_FIFO_OFF = 393216 +``` + +- **Shrink reduce-tile ring footprint** with a computed stride instead of a fixed **512** bytes per slot: + +```126:127:examples/aot/flash_attention/experimental/fa_builder.py +# Tight packing for reduce / exp_max ring scalars (one column per logical row). +VEC_RED_STRIDE = ((S0_HALF * 4 + 127) // 128) * 128 +``` + +- **Reuse `p_fp32` as `row_max` scratch** so `tmp_tile` remains free for **`row_sum`**, saving one large vec buffer’s worth of peak live data: + +```557:558:examples/aot/flash_attention/experimental/fa_builder.py + tile.muls(qk_recv, scale, qk_recv) + tile.row_max(qk_recv, p_fp32, local_max) +``` + +- **Single `VEC_RECV_OFF` scratch** sized for the max of half-tile QK (if ever narrowed) and half-tile PV: + +```134:136:examples/aot/flash_attention/experimental/fa_builder.py +# Shared recv scratch: max(fp16 QK half-tile, fp32 PV half-tile) for tpop addr=. +_VEC_RECV_BYTES = max(S0_HALF * S1_TILE * 2, S0_HALF * HEAD * 4) +VEC_RECV_OFF = VEC_RED_BASE_OFF + 6 * VEC_RED_STRIDE +``` + +These are **correctness-preserving micro-optimizations**; they do **not** replace dialect support for **§A–§C**. + +--- + ## Priority (suggested for FA parity) | Priority | Item | Why | From c49a45f34e12958085ad058695221f00bab64525 Mon Sep 17 00:00:00 2001 From: mirkodevita Date: Tue, 28 Apr 2026 01:05:52 +0200 Subject: [PATCH 5/7] remove unreasonable ptoas requests --- examples/aot/flash_attention/known_gap.md | 25 +- examples/aot/flash_attention/ptoas_request.md | 388 +++--------------- 2 files changed, 60 insertions(+), 353 deletions(-) diff --git a/examples/aot/flash_attention/known_gap.md b/examples/aot/flash_attention/known_gap.md index e58e722b..8b965fab 100644 --- a/examples/aot/flash_attention/known_gap.md +++ b/examples/aot/flash_attention/known_gap.md @@ -29,7 +29,7 @@ The reference’s hot path is not arbitrary `tile.*` soup; it goes through share | Reference include | Role | |-------------------|------| | [`pto_macro_matmul.hpp`](../../../../pto-isa-master/kernels/manual/common/flash_atten/pto_macro_matmul.hpp) | Cube matmul with **`AccMode`** (e.g. `InitFinalSum` under `UF_ENABLE`), K tiling, and L0-oriented constraints. | -| [`pto_macro_fa_softmax.hpp`](../../../../pto-isa-master/kernels/manual/common/flash_atten/pto_macro_fa_softmax.hpp) | Streaming softmax: `softmax_opt_fa_init_impl` / `softmax_opt_fa_not_init_impl`, scale = `1/sqrt(HEAD)`, **TROWMAX**, **TROWEXPANDSUB**, **TEXP**, **TROWSUM**, **reshape + TCVT** to fp16, causal branches where applicable. | +| [`pto_macro_fa_softmax.hpp`](../../../../pto-isa-master/kernels/manual/common/flash_atten/pto_macro_fa_softmax.hpp) | Streaming softmax: `softmax_opt_fa_init_impl` / `softmax_opt_fa_not_init_impl`, scale = `1/sqrt(HEAD)`, **TROWMAX**, **TROWEXPANDSUB**, **TEXP**, **TROWSUM**, and (inside the macro) the sequence that feeds **P**—including **TCVT** where the **macro** emits half for the V2C pipe. **QK** in the reference kernel is still **fp32 in GM** via **`TSTORE`** and **`TLOAD`** in `compute_qk` / `compute_p`; do not conflate macro-internal P conversion with a separate invented “fp16 QK wire”. | | [`pto_macro_fa_gu.hpp`](../../../../pto-isa-master/kernels/manual/common/flash_atten/pto_macro_fa_gu.hpp) | **pto_macro_fa_gu** (`TROWEXPANDMUL` + `TADD`), **pto_macro_fa_gu_last** (+ `TROWEXPANDDIV` by `new_global_sum`), **pto_macro_fa_gu_single_and_last_tile**. | **Goal:** Express the same ordered primitive sequence (and the same init vs non-init / last-tile branching) in `ptodsl` `tile.*` / `pto.*` APIs—or add thin DSL helpers that document a 1:1 mapping to those macros—so the compiler stack can match the reference kernel’s numerics and fusion expectations. Today’s DSL uses composable `tile.row_max`, `tile.exp`, etc.; they must be **audited and aligned** macro-step by macro-step, not assumed equivalent. @@ -60,28 +60,15 @@ Measured on NPU via `experimental/run.py` (Q=2048, H=128, S1_TILE=256): kernel h **Takeaway:** With **`S0=128`** landed in the experimental builder, the largest remaining structural gaps versus the reference are **`kTileFactor` / `CUBE_S1` K-split**, **preload / ring depth (`QK_PRELOAD`, CV FIFO)**, and **vec working-tile geometry (`Vec_S0`)** relative to what `l2g2l_pipe` + `TILE_UP_DOWN` imply for UB. Raising `QK_PRELOAD` still needs more `exp_max` slots and recv/GM budget. -| `S0=128` (Apr 2026) | Default `S0` raised to **128** (`FA_S0`). `pto.tpush` from cube still requires **ACC** tiles (PTO verifier: only `AddressSpace::ACC` maps to a producer pipe); staging QK as **fp16 on MAT/LEFT** before push was rejected at MLIR verify, so QK stays **fp32 on the wire** with full `SLOT_SIZE_QK`. Vec softmax reuses **`p_fp32` as `row_max` scratch** (same lifetime as before `row_expand_sub`) plus a **single shared `VEC_RECV_OFF`** sized for the larger of QK/PV half-tiles. `experimental/run.py` + `compile.sh` pass on NPU at ~24 TFLOP/s (unchanged order-of-magnitude vs fused ref). | +| `S0=128` (Apr 2026) | Default `S0` raised to **128** (`FA_S0`). Cube **`pto.tpush`** uses **`ACC`** QK tiles into `l2g2l_pipe` (same as today’s supported producer); the **reference** path is **`TSTORE`** from acc to **fp32 GM**, not MAT/LEFT→GM—**`PIPE_UNASSIGNED` for MAT/LEFT `tpush` is expected**, not a toolchain defect to “fix” for FA. Builder keeps **fp32 `SLOT_SIZE_QK`**, vec **`p_fp32` as `row_max` scratch**, and a **shared `VEC_RECV_OFF`** for half-tile **`tpop`** scratch. `experimental/run.py` + `compile.sh` pass on NPU at ~24 TFLOP/s (order-of-magnitude below fused ref). | --- -## PTOAS / PTO dialect / Python binding — feature requests (algorithm parity) +## PTOAS / PTO dialect / Python binding — reasonable asks (see `ptoas_request.md`) -These are the main **toolchain** gaps noticed while aligning `experimental/fa_kernel` with `cpp_ref/naive_tpush/fa_kernel.cpp`. They are not criticisms of the hand-written reference; they are concrete asks so the **same algorithm config** (tiling, dtypes on wires, vec working set) can be expressed without fighting verifiers or UB. +Feature requests must **mirror what `fa_kernel.cpp` already does**, not invent paths the reference does not use (e.g. **no** MAT/LEFT→GM for QK, **no** extra **`tile.cvt`** pipeline on QK beyond what the **macros** imply for **P**). FP32 QK in the ref is **`TSTORE`/`TLOAD`** to/from **fp32 GM**; **`ptoas --enable-insert-sync`** remains the **intentional** allowed divergence from hand-placed `TSync_Custom`. -1. **C2V `pto.tpush` producer tiles beyond ACC** - Today `TPushOp::getPipe()` maps **only** `AddressSpace::ACC` → `PIPE_FIX` (see `PTOOps.td`); **MAT** and **LEFT** producers yield `PIPE_UNASSIGNED` and fail verification. The reference keeps QK in **fp32 in GM** (`qk_tile_fifo`) and uses **fp16** only inside vec macros (`TileDataH_T`, `TCVT`). A natural DSL port would **cvt** `TileAcc` → **`Tile`** and `tpush` that tile to halve **`slot_size`** / vec FIFO pressure. **Ask:** allow **fp16 (and/or LEFT/MAT) tiles** as legal C2V `tpush` sources when `slot_size` matches, or document the intended lowering (e.g. MTE path) so Python does not need ACC-only staging. - -2. **Decouple `slot_size` from “one full cube row tile” for vec UB accounting** - Reference **`Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor`** (e.g. **32** rows × **256** cols in vec UB) while GM still holds **`Cube_S0 × Tile_S1`** floats per logical tile, assembled from **`kTileFactor`** slices of **`Cube_S0 × Cube_S1`**. The DSL **`l2g2l_pipe`** ties **vec `reserve_buffer`** size to **`SLOT_SIZE_QK`** and **`tpop`** delivers **`S0_HALF × S1_TILE`** per subblock. **Ask:** first-class **“logical tile vs wire chunk”** (multi-slot per tile_id, or column-strip `tpop` into a fixed vec workspace) so vec UB tracks **`Vec_S0`** like the C++ launch, not **`Cube_S0/2`** per `TILE_UP_DOWN` alone. - -3. **`kTileFactor` / K-split + softmax without a single 64×256 vec tile** - Matching the reference requires **multiple `compute_p` / `row_slice` passes** per tile and **partial QK layout in GM** (`base_elems + row_offset * Cube_S1`). **Ask:** DSL helpers or ops for **GM strided views** + **event sync** equivalent to `TSync_Custom` / `qk2smSync`, or **documented** mapping from `initialize_l2g2l_pipe` + `tpop` to that pattern so cube can emit **128×128** stores while vec runs **32×256** softmax without holding a **64×256** `qk_vec` buffer per subblock. - -4. **`QK_PRELOAD = 4` and deeper CV FIFOs** - Reference uses **`qkPreloadNum = 4`** with **`l1_exp_max_ififo[qkp_tile_fifo_size]`**. DSL stays at **`QK_PRELOAD = 2`** for a smaller **`exp_max` ring**. **Ask:** either **lowered UB cost** for pipe rings (item 1–2) or **optional GM-backed vec inputs** so preload depth can match the C++ launch without manual byte arithmetic. - -5. **Python binding ergonomics** - **Ask:** optional **computed layout** (or static asserts) from tensor shapes for **MAT / VEC base offsets** so raising `S0` cannot silently overlap **`MAT_P_FIFO`** with cube tiles; and a **single knob** mirroring `runTFA` template parameters (`CUBE_S0`, `CUBE_S1`, `TILE_S1`, `QK_PRELOAD`, CV FIFO depth) mapped to **`S0`**, **`S1_TILE`**, **`QK_PRELOAD`**, and pipe **`slot_num` / `local_slot_num`**. +A maintained list of **documentation-first** and **parity** asks (GM packing vs `l2g2l_pipe`, `kTileFactor`/`Vec_S0`, sync equivalence, ergonomics) lives in **`examples/aot/flash_attention/ptoas_request.md`**. --- @@ -104,7 +91,7 @@ Use this as a work backlog; order roughly reflects suggested priority (tiling/bu ### Macro parity in Python (lowering contract) - [ ] **Matmul:** Port or wrap **`pto_macro_matmul`** semantics in DSL—especially **`AccMode`** (`Init` / `InitFinalSum` / partial vs final slices) and the **K-subslice** interaction with double-buffered K tiles. -- [ ] **Softmax:** Port **`softmax_opt_fa_init_impl`** and **`softmax_opt_fa_not_init_impl`** (and causal paths if needed) as an explicit sequence of DSL ops matching **TROWMAX → TROWEXPANDSUB → scale → TEXP → TROWSUM → reshape/TCVT** behavior in `pto_macro_fa_softmax.hpp`. +- [ ] **Softmax:** Port **`softmax_opt_fa_init_impl`** and **`softmax_opt_fa_not_init_impl`** (and causal paths if needed) as an explicit sequence of DSL ops matching the **macro** ordering in `pto_macro_fa_softmax.hpp` (including whatever the macro uses to produce **half** for **P**—e.g. **`TCVT`** inside the macro—not a separate invented QK cast). - [ ] **GU:** Port **`pto_macro_fa_gu`**, **`pto_macro_fa_gu_last`**, and **`pto_macro_fa_gu_single_and_last_tile`** as DSL sequences matching **TROWEXPANDMUL / TADD / TROWEXPANDDIV** usage in `pto_macro_fa_gu.hpp`. - [ ] **Numerics test:** Keep **`torch.testing.assert_close`** (same `rtol`/`atol` as `run.py` / `experimental/run.py`) as the gate after each macro block port. diff --git a/examples/aot/flash_attention/ptoas_request.md b/examples/aot/flash_attention/ptoas_request.md index 29db2871..4ee89700 100644 --- a/examples/aot/flash_attention/ptoas_request.md +++ b/examples/aot/flash_attention/ptoas_request.md @@ -1,125 +1,90 @@ # PTOAS feature requests (PTO MLIR dialect + Python bindings) -This document collects **actionable requests** for the PTOAS / PTO dialect stack so that **flash-attention–style kernels** written in Python (e.g. `examples/aot/flash_attention/experimental/fa_builder.py` via `ptodsl`) can **closely match** the hand-tuned reference in `examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp` (`runTFA`, `compute_qk`, `compute_p`, `compute_pv`, `compute_gu`). +This document lists **reasonable** asks for PTOAS / the PTO dialect so Python FA (`experimental/fa_builder.py` via `ptodsl`) can **mirror** `examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp` (`runTFA`, `compute_qk`, `compute_p`, `compute_pv`, `compute_gu`). -Upstream **PTOAS** sources below use paths **relative to the `PTOAS/` repository root** (same layout as a normal PTOAS checkout), e.g. `include/PTO/IR/PTOOps.td`, `docs/designs/ptoas-tpush-tpop-design.md`. +**Ground rules (read first)** ---- - -## 1. Allow cube-side `pto.tpush` from non-ACC tiles (C2V producer coverage) - -**Problem.** `TPushOp::getPipe()` only maps **`AddressSpace::ACC`** to a concrete pipe (`PIPE_FIX`). **`MAT`** and **`LEFT`** tiles map to **`PIPE_UNASSIGNED`**, so MLIR verification fails with *“tile type must map to a supported producer pipe”* when attempting to push an fp16 staging tile (e.g. post-`TCvt` from acc) over a C2V pipe. - -**Evidence.** `include/PTO/IR/PTOOps.td`, `TPushOp` `getPipe()` (lines ~1767–1792): only `ACC` and `VEC` branches; all other address spaces return `PIPE_UNASSIGNED`. - -**Motivation (ref FA).** The C++ reference keeps **fp32 QK in GM** and uses **fp16** inside vec macros for softmax output / P staging. A Python port naturally wants **cvt(acc f32 → mat/left f16) → tpush** to **halve `slot_size`** and vec FIFO pressure while keeping matmul in fp32. - -**Ask.** +1. **No invented algorithms.** PTOAS should support a **1:1** mapping to the reference’s data path and control flow, not new “more clever” patterns the C++ kernel does not use. +2. **QK path in the reference is not `TCvt` on tiles then `TPUSH` from MAT/LEFT.** Cube writes QK to GM with **`TSTORE`** from the **accumulator** (`compute_qk`); vec reads with **`TLOAD`** into vec UB (`compute_p`). There is **no** MAT/LEFT→GM push for QK. Where fp16 appears for **P**, the reference uses the **V2C `TPipe`** / macro path (e.g. `sizeof(half)` slot size in `runTFA`), not a fabricated “fp16 QK wire”. +3. **`TPushOp` accepts ACC (and VEC for the other direction) by design** (`include/PTO/IR/PTOOps.td`). **`PIPE_UNASSIGNED` for MAT/LEFT as `tpush` sources is expected**: there is no direct MAT/LEFT→global path like `TSTORE`. That is **not** a bug to “fix” by widening `tpush` to MAT/LEFT for FA—doing so would **diverge** from the reference. +4. **Allowed toolchain divergence** from hand-written C++: **`ptoas --enable-insert-sync`** (and similar passes) that insert synchronization; everything else should aim at **reference parity**, not new semantics. -- Extend **`TPushOp`** (and verifier / lowering to EmitC) so **cube producers** can legally push **`TileBufType` in `MAT` and/or `LEFT`** with dtypes compatible with the pipe’s `slot_size`, **or** -- Document and implement an **official lowering path** (e.g. implicit MTE move acc→staging then push) so frontends do not need to guess unsupported combinations. +Upstream **PTOAS** paths below are **relative to the `PTOAS/` repository root**. --- -## 2. Decouple `slot_size` (wire bytes) from producer/consumer tile element type - -**Problem.** `initialize_l2g2l_pipe` takes a single **`slot_size` (bytes)** while `tpush`/`tpop` tile types carry **dtype + shape**. Today authors must keep **manual consistency** between `SLOT_SIZE_QK`, cube `TileAcc`, and vec `Tile`; there is no first-class “**fp32 compute, fp16 wire**” contract. +## 1. Documentation: reference **QK** path ↔ **`l2g2l_pipe` + ACC `tpush` / `tpop`** -**Evidence.** `InitializeL2G2LPipeOp` in `PTOOps.td` (~1681–1712): `slot_size` is a plain `i32`; pipe init does not encode logical vs physical width. +**What the reference does.** Cube **`TSTORE`** of **`float`** tiles shaped **`Cube_S0 × Cube_S1`** into `qk_tile_fifo` with the `base_elems` packing (`fa_kernel.cpp`, `compute_qk`). Vec **`TLOAD`**s **`Vec_S0 × Cube_S1`** slices and assembles **`Vec_S0 × Tile_S1`** in UB (`compute_p`, loop over `sub_col`). -**Motivation (ref FA).** Reference layout uses **`sizeof(float)` × Cube_S0 × Tile_S1`** in GM for `qk_tile_fifo`, while vec tiles are **`Vec_S0 × Tile_S1`** with **`Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor`**. The toolchain should help express **logical tile**, **wire format**, and **vec working tile** without ad-hoc byte math in Python. +**What the Python builder does today.** Uses **`pto.tpush(qk_acc, qk_pipe)`** with **`slot_size = S0 * S1_TILE * sizeof(fp32)`** and GM backing—**analogous** to getting QK to GM + vec visibility, but **not** identical to the reference’s per-`sub_tile_id` **`Cube_S0×Cube_S1`** store layout when `kTileFactor > 1`. -**Ask.** +**Ask (PTOAS / docs only).** -- Optional attributes on **`initialize_l2g2l_pipe`** (or companion op) for **`wire_elem_type`**, **`logical_shape`**, and/or **`vec_slice_shape`**, validated against `slot_size`, **or** -- A small **tablegen-verified** bundle type for “pipe slot descriptor” consumed by both cube and vec builders. +- In `docs/designs/ptoas-tpush-tpop-design.md` (or a small FA note), add a **side-by-side**: reference **`TSTORE`/`TLOAD` + `base_elems`** vs recommended **`slot_size` / `slot_num` / GM pointer math** for `initialize_l2g2l_pipe` so Python authors can **match the reference packing** without guessing. +- Clarify explicitly: **cube `tpush` from ACC** is the supported producer; **do not** document MAT/LEFT `tpush` to GM as a FA workaround—that path **does not exist** in the reference. --- -## 3. First-class **K-split** (`kTileFactor`) and **partial QK** delivery to vec +## 2. **`kTileFactor` / `Cube_S1` K-split and vec `Vec_S0` (reference-only geometry)** -**Problem.** The reference runs **`kTileFactor = Tile_S1 / Cube_S1`** cube passes (e.g. two **128×128** matmuls per **256**-wide logical tile), stores **`Cube_S0 × Cube_S1`** slices into GM, and vec **`compute_p`** performs **`kTileFactor`** **TLOAD**s of **`Vec_S0 × Cube_S1`** into a **`Vec_S0 × Tile_S1`** vec tile. The Python + `l2g2l_pipe` path instead tends toward **one full `Cube_S0 × Tile_S1` tpush** and a **`S0_HALF × S1_TILE` tpop**, which inflates **vec UB** versus **`Vec_S0 × Tile_S1`**. +**What the reference does.** `kTileFactor = Tile_S1 / Cube_S1`; two matmul passes per logical tile; vec softmax tile **`Vec_S0 × Tile_S1`** with **`Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor`** (`runTFA`); **`compute_p`** loads **`kTileFactor`** column strips from GM. -**Motivation (ref FA).** Matching **`CUBE_S1`**, **`kTileFactor`**, and **`Vec_S0`** is required for both **numerics/scheduling parity** and **UB parity** with `fa_kernel.cpp`. +**What the Python builder does today.** One **`HEAD × S1_TILE`** matmul and one full **`S0 × S1_TILE`** `tpush`; vec **`tpop`** **`S0_HALF × S1_TILE`** with **`TILE_UP_DOWN`**. **Ask.** -- Either **documented** lowering from “ref-style GM layout + sync” to **`initialize_l2g2l_pipe` + `tpush`/`tpop`**, **or** new ops / pipe modes for: - - **multiple ordered `tpush`es** per logical `tile_id` with **fixed GM packing** matching the reference’s `base_elems` formulas, and - - **vec-side assembly** (`tpop` into column sub-ranges of one vec tile, or explicit `tassign`/`subview` at UB addresses) without requiring a single oversized **`tpop`** result tile. +- **Documented** recipes (and, if useful, **ptodsl** helpers only—no new PTO ops required) for: cube **`AccMode`/`InitPartialSum`/`AccPartialSum`**-style sequences matching `compute_qk` + `compute_pv`, and vec **`pto.load`** / **`slice_view`** patterns matching **`TLOAD`** + **`TASSIGN`** column offsets in `compute_p`. +- If something in **MLIR verification** blocks a **literal** ref-shaped **`pto.store`**/`load` schedule that is otherwise valid, file that as a **narrow bugfix** with a ref citation—not a new feature. --- -## 4. Richer **`split`** / subblock model (beyond one `TILE_UP_DOWN` halving) +## 3. **Software row / subblock indexing (`row_slice`, `get_subblockid`)** -**Problem.** `split` on `tpush`/`tpop` models a **single** split axis enum; reference logic combines **`get_subblockid()`**, **`row_slice`**, and **`kTileFactor`** to address **four** distinct **32-row** bands across **`Cube_S0 = 128`**. Expressing that with only **one** up/down split per op forces **larger per-core vec tiles** than the reference. - -**Evidence.** Design notes in `docs/designs/ptoas-tpush-tpop-design.md` (split semantics); reference `compute_p` row/col slicing in `fa_kernel.cpp`. +**What the reference does.** `row_offset = subblock_base_rows + row_slice * Vec_S0` and reduce-tile **`TASSIGN`** byte offsets (`compute_p`)—**software** decomposition, not a request for new hardware **`split`** enum values unless the ISA already exposes them for the same pattern. **Ask.** -- Consider **documented composition** of splits (e.g. nested phases) **or** additional split modes / **multi-phase tpop** that align with **`row_slice × subblock`** patterns used in FA macros. +- **Examples in docs** showing how to express the same indexing with **existing** Python/PTO constructs (`get_subblock_idx`, scalar offsets, multiple `load`s), so authors do not conflate **`TILE_UP_DOWN`** alone with the reference’s **`row_slice × kTileFactor`** schedule. --- -## 5. **`local_slot_num` / vec `reserve_buffer`** vs GM-only consumer patterns +## 4. **Sync: `TSync_Custom`, CV FIFO depth, `QK_PRELOAD` (exists in reference)** -**Problem.** `local_slot_num` must be **> 0** and `local_addr` is mandatory for `initialize_l2g2l_pipe` (verifier in `PTO.cpp` / design doc §5.2). The reference often behaves like **“cube writes GM; vec reads GM after sync”** with **smaller vec-local FIFOs** (`srcVecTNBuffers`, etc.), not necessarily a full **local mirror** of every slot byte in UB. +**What the reference does.** `TSync_Custom<…>` around QK produce / vec consume (`qk2smSync`); `should_wait_consumption` / `should_notify_consumption`; template **`QK_PRELOAD`**, **`qkp_tile_fifo_size`**, **`CV_FIFO_CONS_SYNC_PERIOD`** (`runTFA`, `compute_p`, `compute_qk`). -**Evidence.** `docs/designs/ptoas-tpush-tpop-design.md` (~318–361, ~759–761). +**What the Python stack does today.** **`--enable-insert-sync`** plus pipe **`tfree`** / implicit ordering—**intentionally** different from hand-placed `TSync_Custom`. -**Ask.** +**Ask (reasonable).** -- Optional **GM-primary consumer** mode: vec **`tpop`** semantics that **do not** require **`reserve_buffer(slot_size × local_slot_num)`** when the consumer only needs a **bounded scratch** (with **verified** max live bytes), **or** -- A **`tpop_from_gm` / `wait_slot` + `load`** pattern with **verified** cross-core ordering equivalent to **`TSync_Custom`** in the reference. +- **Optional** dialect or pass hooks that lower to the **same** sync primitives the reference uses, **or** a documented **equivalence table**: “ref `qk2smSync.wait()` ↔ inserted barrier X after `ptoas` version Y”. +- This is **not** a license to invent new memory paths; it is **parity** for **control** the reference already has. --- -## 6. Explicit **sync / event** surface in the dialect (parity with `TSync_Custom` / CV FIFO) - -**Problem.** Reference FA uses **`TSync_Custom`**, **`should_wait_consumption` / `should_notify_consumption`**, and optional **CV comm** for backpressure. Python builders today lean on **`--enable-insert-sync`** and pipe **`tfree`** ordering; there is no close 1:1 mapping to **named sync tokens** and **FIFO depth** parameters from `fa_kernel.cpp`. +## 5. **Python / `ptodsl` ergonomics (optional; ref-shaped constants)** -**Motivation (ref FA).** Tuning **`QK_PRELOAD`**, **`qkp_tile_fifo_size`**, and **`CV_FIFO_CONS_SYNC_PERIOD`** is central to the C++ launch. +**Problem.** Layout math (`GM_*_OFF_F32`, `MAT_*_OFF`, vec FIFO bytes) is hand-rolled and easy to break when `Cube_S0` / `Tile_S1` / `HEAD` change—**the reference avoids some of this** with template parameters and allocator helpers. **Ask.** -- Expose **optional** `record_event` / `wait_event` (or reuse existing async session ops if applicable) with **stable lowering** to the same primitives reference kernels use, **and/or** -- A **small FA template** in docs that maps **`runTFA` template parameters** → PTO ops + attrs. +- **Optional** helpers or tables driven only by **`runTFA`-style constants** (`CUBE_S0`, `CUBE_S1`, `TILE_S1`, `QK_PRELOAD`, FIFO sizes) to generate **GM strides matching `base_elems`**, **MAT/VEC base offsets**, and **static overlap checks**—without introducing new runtime algorithms. --- -## 7. Python bindings: **ergonomics** beyond raw `mlir` ODS +## 6. **Documentation cross-link (macros)** -**Problem.** `python/pto/dialects/pto.py` is largely **generated ODS exports**; authors of large kernels still hand-roll **byte offsets**, **`slot_size`**, and **layout** in application code (`ptodsl` or otherwise), which is error-prone when **`S0`**, **`S1_TILE`**, or **`HEAD`** change. - -**Ask.** - -- **Optional** Python helpers (same package or `ptodsl`-side) for: - - **Pipe bundle construction** (`dir_mask`, `slot_size`, `slot_num`, `local_slot_num`) with **static consistency checks**, - - **UB layout** from a declarative map of **tile names → (space, dtype, shape)** with **overlap detection**, - - **“Reference FA preset”** constants: `CUBE_S0`, `CUBE_S1`, `TILE_S1`, `QK_PRELOAD`, FIFO depths — emitting the right **`initialize_l2g2l_pipe`** / legacy `*_initialize_pipe` combo. +**Ask.** In PTOAS or `ptodsl` docs, link **`pto_macro_matmul` / `pto_macro_fa_softmax` / `pto_macro_fa_gu`** sequences to the **minimal** `tile.*` / `pto.*` sequences needed for lowering parity—**macro-internal** ops (including whatever the macro uses for P) are **in scope**; **invented** pre-pipeline **`tile.cvt`** on QK to fake a dtype the reference does not use on that path is **out of scope**. --- -## 8. Documentation: **reference kernel ↔ PTO pipe** mapping +## Concrete examples (reference ↔ Python today) -**Ask.** Add a short chapter to `docs/designs/ptoas-tpush-tpop-design.md` (or a new doc under `docs/designs/`) that shows: - -1. How **`TSTORE(qkGlobalTile, qkAccTile)`** + **`TLOAD(qkVecSub, qkGlobalSub)`** in `fa_kernel.cpp` maps to **`initialize_l2g2l_pipe` + `tpush` + `tpop`** (including **GM stride** / **`kTileFactor`**). -2. Which **`split`** values approximate **`TileSplitAxis::TILE_UP_DOWN`** in the reference P headers. -3. **Known limitations** (e.g. **`TPushOp` producer address spaces** as of current `PTOOps.td`). - ---- +### A. QK: **`TSTORE`/`TLOAD`** (ref) vs **`tpush`/`tpop`** + ACC (Python) -## Concrete examples (reference C++ ↔ desired Python ↔ today) +**Reference — cube writes fp32 QK slices to GM** (`examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp`): -Each subsection ties **one reference pattern** to **what Python would ideally emit**, what **PTOAS / MLIR rejects or cannot express**, and what **`experimental/fa_builder.py` does instead**. - -### A. fp16 payload on the QK cube→vec path (requests **1** and **2**) - -**Reference (GM is fp32; vec uses narrower working tiles and fp16 for P).** Cube stores each **`Cube_S0 × Cube_S1`** QK slice as **float** in `qk_tile_fifo` (not a hardware `TPUSH` from vec’s perspective—MTE `TSTORE` to GM): - -```364:381:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp +```364:376:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp using GlobalDataQK = GlobalTensor, pto::Stride<1, 1, 1, Cube_S1, 1>>; const uint32_t buf_idx = static_cast(tile_id % QKP_CV_FIFO); @@ -128,120 +93,12 @@ Each subsection ties **one reference pattern** to **what Python would ideally em static_cast(Cube_S1) + static_cast(sub_tile_id) * static_cast(Cube_S0) * static_cast(Cube_S1); GlobalDataQK qkGlobalTile(qk_tile_fifo + base_elems); - -#if UF_ENABLE - TSTORE(qkGlobalTile, qkAccTile); -#else TSTORE(qkGlobalTile, qkAccTile); - set_flag(PIPE_FIX, PIPE_M, accTileEvtID); -#endif - - if (sub_tile_id == static_cast(kTileFactor) - 1) - qk2smSync.record(); // notify for QK produce data ``` -The **P** pipe uses **`Cube_S0 * Cube_S1 * sizeof(half)`** slots (fp16 on the vec→cube wire): - -```820:823:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp - constexpr uint32_t p_tile_fifo_slots = qkp_tile_fifo_size * kTileFactor; - using PPipe = - TPipe; - PPipe pPipe((__gm__ void *)p_tile_fifo_block, 0u, (uint32_t)(uint64_t)pMatTile[0].data()); -``` - -**Desired Python pattern (sketch).** After `tile.matmul(..., qk_acc)`, narrow the C2V **`slot_size`** while keeping matmul in fp32: - -```python -# Ideal: half wire, same logical tile id -tile.cvt(qk_acc, qk_half_tile, rmode="round") # TileBufType(..., dtype=f16, memory_space="MAT"|"LEFT") -pto.tpush(qk_half_tile, qk_pipe, SPLIT_UP_DOWN) -``` - -Vec would `tpop` into **`!pto.tile_buf`** and `tile.cvt` to fp32 before `row_max`. - -**Current failing behavior.** `TPushOp::getPipe()` in upstream PTO only treats **`ACC`** (and **`VEC`**) as having a real pipe id; **`MAT` / `LEFT` / …** fall through to **`PIPE_UNASSIGNED`**, so MLIR verification fails: - -```1767:1793:include/PTO/IR/PTOOps.td - ::mlir::pto::PIPE getPipe() { - auto getAddressSpace = [](Type ty) -> std::optional<::mlir::pto::AddressSpace> { - if (auto tb = ::mlir::dyn_cast<::mlir::pto::TileBufType>(ty)) { - if (auto as = ::mlir::dyn_cast_or_null<::mlir::pto::AddressSpaceAttr>( - tb.getMemorySpace())) - return as.getAddressSpace(); - return std::nullopt; - } - // ... - }; - - auto as = getAddressSpace(getTile().getType()); - if (!as) - return ::mlir::pto::PIPE::PIPE_UNASSIGNED; - if (*as == ::mlir::pto::AddressSpace::ACC) - return ::mlir::pto::PIPE::PIPE_FIX; - if (*as == ::mlir::pto::AddressSpace::VEC) - return ::mlir::pto::PIPE::PIPE_MTE3; - return ::mlir::pto::PIPE::PIPE_UNASSIGNED; - } -``` - -Typical diagnostic: **`'pto.tpush' op tile type must map to a supported producer pipe`**. - -**Un-optimal workaround in `fa_builder.py`.** Keep **`SLOT_SIZE_QK = S0 * S1_TILE * 4`** and push **only** the fp32 accumulator (legal **`ACC`** producer): - -```74:77:examples/aot/flash_attention/experimental/fa_builder.py -# Per-pipe slot sizes (bytes). -# QK: fp32 on the wire (cube `tpush` only accepts ACC tiles today — see known_gap). -SLOT_SIZE_QK = S0 * S1_TILE * 4 -SLOT_SIZE_PV = S0 * HEAD * 4 # fp32 PV accumulator -``` - -```361:362:examples/aot/flash_attention/experimental/fa_builder.py - tile.matmul(q_left, k_right[k], qk_acc[k]) - pto.tpush(qk_acc[k], qk_pipe, SPLIT_UP_DOWN) -``` - -That **doubles** ring bytes versus a half-precision wire format with the same logical geometry. - ---- +**Reference — vec reads fp32 from GM into column strips of `qkVecTile`** (same file, `compute_p`): -### B. `kTileFactor` / `Vec_S0` vs one big matmul + one `tpop` (requests **3** and **4**) - -**Reference geometry.** `runTFA` fixes **`kTileFactor = Tile_S1 / Cube_S1`** and **`Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor`** (e.g. **32** row softmax tile height when **`Cube_S0 = 128`**, **`kTileFactor = 2`**): - -```680:691:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp - constexpr uint32_t Cube_S0 = CUBE_S0; - uint32_t block_rows = s0 / Cube_S0; - constexpr uint32_t Cube_S1 = CUBE_S1; // per-tile S1 chunk - constexpr uint32_t Tile_S1 = TILE_S1; // logical tile along S1 - static_assert(Tile_S1 % Cube_S1 == 0, "TILE_S1 must be divisible by Cube_S1"); - constexpr uint32_t kTileFactor = Tile_S1 / Cube_S1; // sub-tiles per TILE_S1 - constexpr uint32_t Cube_HEAD = HEAD_SIZE; - constexpr uint32_t Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor; - constexpr uint32_t VecGuRows = Cube_S0 / VEC_CORES; - static_assert(Cube_S0 % (VEC_CORES * kTileFactor) == 0, "Vec rows must divide evenly across tile slices"); -``` - -Vec **softmax** tile type is **`Vec_S0 × Tile_S1`**, not **`Cube_S0/2 × Tile_S1`**: - -```747:751:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp - // Define tile types for FA softmax P computation - // UB offsets for softmax tiles - // Define per-tile vector tiles sized to Cube_S1 - using TileDataF_T = Tile; - using TileDataH_T = Tile; -``` - -**Reference assembly of the wide QK tile from K-slices in GM** (`compute_p`): two **`TLOAD`**s of **`Vec_S0 × Cube_S1`** into column halves of **`qkVecTile`**: - -```500:517:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp - const uint32_t buf_idx = static_cast(tile_id % QKP_CV_FIFO); - const size_t base_elems = static_cast(buf_idx) * static_cast(kTileFactor) * - static_cast(Cube_S0) * static_cast(Cube_S1); - __gm__ float *qk_ptr = qk_tile_fifo + base_elems + row_offset * static_cast(Cube_S1); - - using GlobalDataQK_Sub = - GlobalTensor, pto::Stride<1, 1, 1, Cube_S1, 1>>; - using TileDataF_Sub = Tile; +```505:517:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp for (int sub_col = 0; sub_col < static_cast(kTileFactor); ++sub_col) { __gm__ float *qk_ptr_sub = qk_ptr + static_cast(sub_col) * static_cast(Cube_S0) * static_cast(Cube_S1); @@ -254,171 +111,34 @@ Vec **softmax** tile type is **`Vec_S0 × Tile_S1`**, not **`Cube_S0/2 × Tile_S } ``` -**Desired Python pattern (sketch).** Mirror **`compute_qk`**’s **`sub_tile_id`** loop with **`AccMode`/`InitPartialSum`** semantics, **`slot_size`/`slot_num`** matching **`base_elems`**, and vec **`tpop`** / **`load`** into **`Vec_S0 × S1_TILE`** (or explicit subview column packing) instead of one **`S0_HALF × S1_TILE`** receive tile per hardware half. - -**Current behavior.** The Python builder performs **one** `matmul` over **`HEAD × S1_TILE`** per logical tile and **one** `tpush` of the full **`S0 × S1_TILE`** accumulator; vec uses **`qk_vec_ty` shape `[S0_HALF, S1_TILE]`** with **`TILE_UP_DOWN`**: - -```350:362:examples/aot/flash_attention/experimental/fa_builder.py - for k in range(QK_PRELOAD): - k_off = const(k * S1_TILE) - kt_view_k = pto.slice_view( - kt_sub_ty, - source=tv_k, - offsets=[c0, k_off], - sizes=[cHEAD, cS1_TILE], - ) - pto.load(kt_view_k, k_mat[k]) - tile.mov(k_mat[k], k_right[k]) - tile.matmul(q_left, k_right[k], qk_acc[k]) - pto.tpush(qk_acc[k], qk_pipe, SPLIT_UP_DOWN) -``` - -```550:556:examples/aot/flash_attention/experimental/fa_builder.py - def emit_softmax_step(exp_max_slot, is_init): - qk_recv = pto.tpop( - qk_vec_ty, - qk_pipe, - SPLIT_UP_DOWN, - addr=const(VEC_RECV_OFF, s.int64), - ) -``` - -There is **no** first-class equivalent to **`row_slice` × `sub_col` × `TASSIGN` column packing** in this path. - -**Un-optimal workarounds.** - -- **Omit `kTileFactor` on cube** (single large K tile): simpler schedule but **not** the reference’s **`CUBE_S1 = 128`** matmul shape / partial-sum story. -- **Accept `S0_HALF = S0 // 2` vec rows per `tpop`**: matches **`TILE_UP_DOWN`** hardware split, but **not** the reference’s **`Vec_S0 = Cube_S0 / (2 * kTileFactor)`** (e.g. **64×256** received per subblock vs ref **32×256** working tile). - ---- - -### C. `QK_PRELOAD == 4` and explicit producer/consumer sync (requests **5** and **6**) - -**Reference preload and sync objects.** The launch uses **`qkPreloadNum = QK_PRELOAD`** (template parameter), **`TSync_Custom`** between cube **`TSTORE`** and vec **`TLOAD`**, and nested **`kTileFactor`** loops so cube and vec each run **`kTileFactor`** steps per logical preload tile: - -```817:874:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp - constexpr TSync_Custom qk2smSync = {BUF0_QK_READY}; - constexpr TSync_Custom pv2guSync = {UPDATE_READY}; - // ... - for (int preload_tile = 0; preload_tile < static_cast(qkPreloadNum) && preload_tile < num_tiles_s1; - ++preload_tile) { - if constexpr (DAV_CUBE) { - for (int sub_tile = 0; sub_tile < static_cast(kTileFactor); ++sub_tile) { - qkAccTileEvtID = assign_running_acc_tile(qkAccTile); - compute_qk(preload_tile, sub_tile, q_block, k, qk_tile_fifo_block, - qMatTile[0], kMatTile[k_src_pingpong_id % kMatTNBuffers], - qkAccTile, k_src_pingpong_id % kMatTNBuffers, - qkAccTileEvtID, qk2smSync, block_idx); - k_src_pingpong_id++; - } - } - if constexpr (DAV_VEC) { - for (int row_slice = 0; row_slice < static_cast(kTileFactor); ++row_slice) { - compute_p( - preload_tile, row_slice, qk_tile_fifo_block, exp_max_ififo_block, global_sum_block, exp_max_block, - qkVecTile[p_gu_src_pingpong_id % srcVecTNBuffers], x_expT[p_gu_src_pingpong_id % xexpVecTNBuffers], - x_expPushT, - input_reduce_tmp, m1_local_max, l1_local_sum, m2_global_max, l2_global_sum, - l1_exp_max_ififo[preload_tile % qkp_tile_fifo_size], triu, p_gu_src_pingpong_id % xexpVecTNBuffers, - qk2smSync, pPipe, block_idx); - p_gu_src_pingpong_id++; - } - } - } -``` - -Inside **`compute_p`**, vec **`qk2smSync.wait()`** / **`free()`** bracket GM visibility: - -```496:520:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp - wait_flag(PIPE_V, PIPE_MTE2, pTileEventId); - if (row_slice == 0) - qk2smSync.wait(); // wait for QK produce data - // ... TLOAD assembly ... - if (row_slice == static_cast(kTileFactor) - 1 && should_notify_consume) - qk2smSync.free(); // notify for SM consume data -``` - -**Desired Python pattern.** Set **`QK_PRELOAD = 4`**, model **`l1_exp_max_ififo[qkp_tile_fifo_size]`**, and emit **named sync** (or dialect ops that lower like **`TSync_Custom`**) aligned with **`should_wait_consumption` / `should_notify_consumption`** from the reference. - -**Current behavior / limits.** - -- **`initialize_l2g2l_pipe`** requires **`local_slot_num > 0`** and a **peer `reserve_buffer`**; vec UB scales with **`FIFO_BYTES_QK ≈ slot_size × local_slot_num`**. There is **no** built-in “vec only **`TLOAD`** from `gm_addr` after cube slot closes” mode with **zero** local ring bytes. -- Python FA relies on **`ptoas --enable-insert-sync`** for cross-kernel ordering instead of **`TSync_Custom`**-style explicit tokens. - -**Un-optimal workaround in `fa_builder.py`.** Hard-code **`QK_PRELOAD = 2`** and a **two-tile `exp_max` ring**; document that raising preload needs more UB and/or hazard rework: - -```65:68:examples/aot/flash_attention/experimental/fa_builder.py -# QK preload depth — must be >= 1; reference launch uses 4, this builder -# keeps 2 for a smaller VEC exp_max ring (see header comment). -# (NUM_TILES - QK_PRELOAD) must be even — steady state is pair-unrolled. -QK_PRELOAD = 2 -``` - ---- - -### D. Manual UB / MAT layout bookkeeping (request **7**) - -**Reference** hides much of this behind **`allocate_cube_tile_buffers` / `allocate_vec_tile_buffers`** templates (`runTFA`). +**Reasonable Python direction (still 1:1 with ref).** Express the same **GM layout + slice loads** using **`pto.store`** / **`pto.load`** + **`slice_view`** on `__gm__` tensors (or keep **`l2g2l_pipe`** but size **`slot_size`** and **`gm_addr`** offsets to match **`base_elems`** / **`kTileFactor`**). **Do not** introduce **`tile.cvt(qk_acc → mat/left fp16)` + `tpush`** for QK: that is **not** in the reference, and MAT/LEFT **`tpush`** is not a supported producer anyway. -**Desired Python / bindings.** Declarative **“tile name → (space, dtype, shape)”** map that **checks overlaps** and derives **`MAT_P_FIFO_OFF`**, vec FIFO bases, and recv scratch **automatically** when `Cube_S0`, `Tile_S1`, or `HEAD` change. +**Python builder today** (`experimental/fa_builder.py`): one matmul over full `S1_TILE`, **`pto.tpush(qk_acc[k], qk_pipe)`**, vec **`pto.tpop(qk_vec_ty, …)`** — simpler than ref **`kTileFactor`** loop; **documented** in §1–2 as the gap to close **using ref-shaped stores/loads or matching pipe packing**. -**Current behavior.** Authors must align **`reserve_buffer`**, **`import_reserved_buffer`**, **`alloc_tile`**, and **`initialize_l2g2l_pipe.slot_size`** by hand. +### B. P: **V2C `TPipe` with `sizeof(half)`** (ref) vs vec `tile.cvt` + `tpush_to_aic` (Python) -**Concrete workarounds in `fa_builder.py` today.** +**Reference — P FIFO slot is fp16-sized, cube pops into MAT** (`fa_kernel.cpp`): -- **Pad `MAT_P_FIFO_OFF`** so the P V2C FIFO cannot overlap growing MAT tiles when **`S0`** increases: - -```100:109:examples/aot/flash_attention/experimental/fa_builder.py -# Explicit local-memory layout used when compiling with --pto-level=level3. -# Offsets are byte offsets within each independent local address space. -MAT_Q_OFF = 0 -MAT_K_OFF = MAT_Q_OFF + S0 * HEAD * 2 -MAT_P_RECV_OFF = MAT_K_OFF + HEAD * S1_TILE * 2 -MAT_V_OFF = MAT_P_RECV_OFF + S0 * S1_TILE * 2 -MAT_P_FIFO_OFF = MAT_V_OFF + S1_TILE * HEAD * 2 -# Pad past the last MAT-resident tile; bisheng is sensitive to overlap here. -if MAT_P_FIFO_OFF < 393216: - MAT_P_FIFO_OFF = 393216 -``` - -- **Shrink reduce-tile ring footprint** with a computed stride instead of a fixed **512** bytes per slot: - -```126:127:examples/aot/flash_attention/experimental/fa_builder.py -# Tight packing for reduce / exp_max ring scalars (one column per logical row). -VEC_RED_STRIDE = ((S0_HALF * 4 + 127) // 128) * 128 -``` - -- **Reuse `p_fp32` as `row_max` scratch** so `tmp_tile` remains free for **`row_sum`**, saving one large vec buffer’s worth of peak live data: - -```557:558:examples/aot/flash_attention/experimental/fa_builder.py - tile.muls(qk_recv, scale, qk_recv) - tile.row_max(qk_recv, p_fp32, local_max) +```820:823:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + using PPipe = + TPipe; ``` -- **Single `VEC_RECV_OFF` scratch** sized for the max of half-tile QK (if ever narrowed) and half-tile PV: +**Python today** uses **`tile.cvt(p_fp32, p_fp16)`** then **`pto.tpush_to_aic(p_fp16, …)`** — that is a **porting of the vec→cube half path**, not the QK cube→vec path. Any **`TCvt`-like behavior for P** should stay aligned with **`pto_macro_fa_softmax`** / reference packing, not used as a precedent to invent QK dtype tricks. -```134:136:examples/aot/flash_attention/experimental/fa_builder.py -# Shared recv scratch: max(fp16 QK half-tile, fp32 PV half-tile) for tpop addr=. -_VEC_RECV_BYTES = max(S0_HALF * S1_TILE * 2, S0_HALF * HEAD * 4) -VEC_RECV_OFF = VEC_RED_BASE_OFF + 6 * VEC_RED_STRIDE -``` +### C. **`PIPE_UNASSIGNED` on non-ACC `tpush`** -These are **correctness-preserving micro-optimizations**; they do **not** replace dialect support for **§A–§C**. +If someone tries **`pto.tpush`** from a **MAT** or **LEFT** tile, verification fails (`include/PTO/IR/PTOOps.td`, `TPushOp::getPipe`). That is **consistent** with there being **no** ref-style **`TSTORE`** from MAT/LEFT to GM for QK. **Not a PTOAS feature request for FA**—use **ACC → `tpush`** (pipe) or **`pto.store`** from acc/global views like **`TSTORE`**. --- -## Priority (suggested for FA parity) +## Priority (suggested) -| Priority | Item | Why | -|----------|------|-----| -| P0 | **1** (non-ACC `tpush`) + **2** (slot/dtype decouple) | Unblocks **fp16-on-wire** and smaller vec FIFO without losing fp32 matmul. | -| P0 | **3** (K-split / partial QK) | Matches **reference cube + vec geometry**; largest structural mismatch today. | -| P1 | **5** (GM-primary / smaller local ring) | Unlocks **`QK_PRELOAD = 4`**-class schedules without linear growth of vec **`reserve_buffer`**. | -| P1 | **6** (explicit sync) | Needed for **faithful** backpressure / CV parity when scaling blocks/cores. | -| P2 | **4** (richer split) | Reduces pressure on Python to fake **row_slice** with full-height tiles. | -| P2 | **7–8** (bindings + docs) | Reduces integration risk and documents the **intended** lowering contract. | +| Priority | Item | Rationale | +|----------|------|-----------| +| P0 | **§1–2** (docs + ref-shaped K/`Vec_S0` lowering recipes) | Largest **semantic** gap vs ref **without** inventing ops. | +| P1 | **§4** (sync parity / equivalence vs `TSync_Custom`) | Exists in ref; **`--enable-insert-sync`** is the only deliberate deviation today. | +| P2 | **§3, §5–6** (indexing examples, ergonomics, macro cross-links) | Reduces author error; no new hardware paths. | --- From 4e71a1504b57cef614b1f4ec07476c18518c8118 Mon Sep 17 00:00:00 2001 From: mirkodevita Date: Tue, 28 Apr 2026 08:57:24 +0200 Subject: [PATCH 6/7] update ptoas gap --- .../experimental/fa_builder.py | 79 ++++++++++++------- examples/aot/flash_attention/known_gap.md | 17 ++-- examples/aot/flash_attention/ptoas_request.md | 6 +- 3 files changed, 63 insertions(+), 39 deletions(-) diff --git a/examples/aot/flash_attention/experimental/fa_builder.py b/examples/aot/flash_attention/experimental/fa_builder.py index 664ecc46..e398cb9b 100644 --- a/examples/aot/flash_attention/experimental/fa_builder.py +++ b/examples/aot/flash_attention/experimental/fa_builder.py @@ -6,6 +6,15 @@ # # constexpr int qkPreloadNum = 2; // warmup depth (reference uses 4; UB-limited here) # +# Cube K tiling matches ref `Cube_S1` / `kTileFactor`: two matmuls per logical +# `S1_TILE` tile into `qk_acc` column subviews; K MAT uses ping-pong buffers +# (`kMatTNBuffers = 2` in ref). +# +# PV stays one `p_left × v_right` matmul: `tmatmul` requires **LEFT** lhs, but +# boxed RowMajor **LEFT** forbids P column `tile.subview` (`keep full cols`); +# MAT P allows column subviews yet cannot be `tmatmul` lhs — ref `compute_pv` +# K-strip schedule needs a future layout/op story in PTO DSL. +# # /* Prologue: cube emits QK[0..QK_PRELOAD-1]; vec softmaxes them. */ # # /* Steady state, tile_id 0..N-1: @@ -50,7 +59,11 @@ S0 = int(os.environ.get("FA_S0", "128")) S0_HALF = S0 // 2 # rows per AIV sub-block (TILE_UP_DOWN split) HEAD = 128 # attention head dimension -S1_TILE = 256 # K/V columns per tile +S1_TILE = 256 # K/V columns per tile (logical `Tile_S1`; ref `fa_kernel.cpp`) +# Reference: `Cube_S1` K micro-tile and `kTileFactor = Tile_S1 / Cube_S1` matmul passes. +CUBE_S1 = int(os.environ.get("FA_CUBE_S1", "128")) +assert S1_TILE % CUBE_S1 == 0, "S1_TILE must be divisible by FA_CUBE_S1" +K_TILE_FACTOR = S1_TILE // CUBE_S1 # NUM_TILES is overridable via the FA_NUM_TILES env var so the same builder # can produce kernels for different sequence lengths # (S1_TOTAL = S1_TILE * NUM_TILES). @@ -100,8 +113,10 @@ # Explicit local-memory layout used when compiling with --pto-level=level3. # Offsets are byte offsets within each independent local address space. MAT_Q_OFF = 0 -MAT_K_OFF = MAT_Q_OFF + S0 * HEAD * 2 -MAT_P_RECV_OFF = MAT_K_OFF + HEAD * S1_TILE * 2 +# Two physical K tiles (ref `kMatTNBuffers = 2`) for ping-pong loads. +MAT_K0_OFF = MAT_Q_OFF + S0 * HEAD * 2 +MAT_K1_OFF = MAT_K0_OFF + HEAD * CUBE_S1 * 2 +MAT_P_RECV_OFF = MAT_K1_OFF + HEAD * CUBE_S1 * 2 MAT_V_OFF = MAT_P_RECV_OFF + S0 * S1_TILE * 2 MAT_P_FIFO_OFF = MAT_V_OFF + S1_TILE * HEAD * 2 # Pad past the last MAT-resident tile; bisheng is sensitive to overlap here. @@ -131,8 +146,8 @@ VEC_LOCAL_SUM_OFF = VEC_RED_BASE_OFF + 3 * VEC_RED_STRIDE VEC_EXP_MAX_A_OFF = VEC_RED_BASE_OFF + 4 * VEC_RED_STRIDE VEC_EXP_MAX_B_OFF = VEC_RED_BASE_OFF + 5 * VEC_RED_STRIDE -# Shared recv scratch: max(fp16 QK half-tile, fp32 PV half-tile) for tpop addr=. -_VEC_RECV_BYTES = max(S0_HALF * S1_TILE * 2, S0_HALF * HEAD * 4) +# Shared recv scratch: max(fp32 QK half-tile, fp32 PV half-tile) for tpop addr=. +_VEC_RECV_BYTES = max(S0_HALF * S1_TILE * 4, S0_HALF * HEAD * 4) VEC_RECV_OFF = VEC_RED_BASE_OFF + 6 * VEC_RED_STRIDE ID_QK = 10 # Cube → Vec, dir_mask = 1 (uses lower-level l2g2l) @@ -158,6 +173,7 @@ def meta_data(): q_sub_ty = pto.SubTensorType(shape=[S0, HEAD], dtype=fp16) kt_sub_ty = pto.SubTensorType(shape=[HEAD, S1_TILE], dtype=fp16) + kt_sub_slice_ty = pto.SubTensorType(shape=[HEAD, CUBE_S1], dtype=fp16) v_sub_ty = pto.SubTensorType(shape=[S1_TILE, HEAD], dtype=fp16) o_sub_ty = pto.SubTensorType(shape=[S0, HEAD], dtype=fp32) o_sub_half_ty = pto.SubTensorType(shape=[S0_HALF, HEAD], dtype=fp32) @@ -166,13 +182,13 @@ def meta_data(): q_mat_ty = pto.TileBufType(shape=[S0, HEAD], dtype=fp16, memory_space="MAT") q_left_ty = pto.TileBufType(shape=[S0, HEAD], dtype=fp16, memory_space="LEFT") k_mat_ty = pto.TileBufType( - shape=[HEAD, S1_TILE], + shape=[HEAD, CUBE_S1], dtype=fp16, memory_space="MAT", config=pto.TileBufConfig(blayout="RowMajor", slayout="ColMajor"), ) k_right_ty = pto.TileBufType( - shape=[HEAD, S1_TILE], dtype=fp16, memory_space="RIGHT" + shape=[HEAD, CUBE_S1], dtype=fp16, memory_space="RIGHT" ) qk_acc_ty = pto.TileBufType(shape=[S0, S1_TILE], dtype=fp32, memory_space="ACC") p_recv_ty = pto.TileBufType( @@ -236,6 +252,7 @@ def cube_kernel( cHEAD = const(HEAD) cS1_TILE = const(S1_TILE) cS1_TOTAL = const(S1_TOTAL) + cCUBE_S1 = const(CUBE_S1) cNUM_TILES = const(NUM_TILES) cNUM_Q_BLOCKS = const(NUM_Q_BLOCKS) @@ -304,7 +321,8 @@ def cube_kernel( right_base = const(RIGHT_KV_OFF, s.int64) q_mat = pto.alloc_tile(q_mat_ty, addr=const(MAT_Q_OFF, s.int64)) q_left = pto.alloc_tile(q_left_ty, addr=const(LEFT_Q_OFF, s.int64)) - k_mat_s = pto.alloc_tile(k_mat_ty, addr=const(MAT_K_OFF, s.int64)) + k_mat_0 = pto.alloc_tile(k_mat_ty, addr=const(MAT_K0_OFF, s.int64)) + k_mat_1 = pto.alloc_tile(k_mat_ty, addr=const(MAT_K1_OFF, s.int64)) k_right_s = pto.alloc_tile(k_right_ty, addr=right_base) qk_acc_s = pto.alloc_tile(qk_acc_ty, addr=const(ACC_QK_OFF, s.int64)) p_recv_s = pto.alloc_tile(p_recv_ty, addr=const(MAT_P_RECV_OFF, s.int64)) @@ -312,7 +330,7 @@ def cube_kernel( v_mat_s = pto.alloc_tile(v_mat_ty, addr=const(MAT_V_OFF, s.int64)) v_right_s = pto.alloc_tile(v_right_ty, addr=right_base) pv_acc_s = pto.alloc_tile(pv_acc_ty, addr=const(ACC_PV_OFF, s.int64)) - k_mat = [k_mat_s, k_mat_s] + k_mat = [k_mat_0, k_mat_1] k_right = [k_right_s, k_right_s] qk_acc = [qk_acc_s, qk_acc_s] p_recv = [p_recv_s, p_recv_s] @@ -338,6 +356,25 @@ def cube_kernel( strides=[cHEAD, c1], ) + # Two `Cube_S1`-wide matmuls per logical tile (ref `compute_qk` / `kTileFactor`). + def accumulate_qk_for_tile(k_s1_offset, qk_buf, k_mat_buf_idx): + for sc in range(K_TILE_FACTOR): + k_col = k_s1_offset + const(sc * CUBE_S1) + kt_view = pto.slice_view( + kt_sub_slice_ty, + source=tv_k, + offsets=[c0, k_col], + sizes=[cHEAD, cCUBE_S1], + ) + pto.load(kt_view, k_mat[k_mat_buf_idx]) + tile.mov(k_mat[k_mat_buf_idx], k_right[k_mat_buf_idx]) + qk_part = tile.subview( + qk_acc[qk_buf], + [c0, const(sc * CUBE_S1)], + [S0, CUBE_S1], + ) + tile.matmul(q_left, k_right[k_mat_buf_idx], qk_part) + for qb in pto.range(qb_start, qb_end, c1): q_row_off = qb * cS0 @@ -349,16 +386,8 @@ def cube_kernel( # =================== Cube prologue: emit QK[0..QK_PRELOAD-1] =================== for k in range(QK_PRELOAD): - k_off = const(k * S1_TILE) - kt_view_k = pto.slice_view( - kt_sub_ty, - source=tv_k, - offsets=[c0, k_off], - sizes=[cHEAD, cS1_TILE], - ) - pto.load(kt_view_k, k_mat[k]) - tile.mov(k_mat[k], k_right[k]) - tile.matmul(q_left, k_right[k], qk_acc[k]) + k_s1_off = const(k * S1_TILE) + accumulate_qk_for_tile(k_s1_off, k, k % 2) pto.tpush(qk_acc[k], qk_pipe, SPLIT_UP_DOWN) # Preload V[0] for the very first PV. @@ -374,14 +403,7 @@ def cube_kernel( # Pair-unrolled; buffer index b = t % 2 (logical ping-pong). def emit_cube_step(t_idx, b): next_qk = t_idx + const(QK_PRELOAD) - kt_off = next_qk * cS1_TILE - kt_view = pto.slice_view( - kt_sub_ty, - source=tv_k, - offsets=[c0, kt_off], - sizes=[cHEAD, cS1_TILE], - ) - pto.load(kt_view, k_mat[b]) + k_s1_off = next_qk * cS1_TILE p_raw = pto.tpop_from_aiv(p_recv_ty, SPLIT_UP_DOWN, id=ID_P) tile.mov(p_raw, p_left[b]) @@ -400,8 +422,7 @@ def emit_cube_step(t_idx, b): tile.matmul(p_left[b], v_right[b], pv_acc[b]) pto.tpush(pv_acc[b], pv_pipe, SPLIT_UP_DOWN) - tile.mov(k_mat[b], k_right[b]) - tile.matmul(q_left, k_right[b], qk_acc[b]) + accumulate_qk_for_tile(k_s1_off, b, b) pto.tpush(qk_acc[b], qk_pipe, SPLIT_UP_DOWN) assert (NUM_TILES - QK_PRELOAD) % 2 == 0 diff --git a/examples/aot/flash_attention/known_gap.md b/examples/aot/flash_attention/known_gap.md index 8b965fab..4e60b885 100644 --- a/examples/aot/flash_attention/known_gap.md +++ b/examples/aot/flash_attention/known_gap.md @@ -11,9 +11,9 @@ This document compares the AOT flash-attention builders (`fa_builder.py`, `exper ### Primary performance gaps (largest expected impact) -1. **Cube tiling and S1 sub-tiling** - Reference: `CUBE_S0 = 128`, `CUBE_S1 = 128`, `TILE_S1 = 256`, so **`kTileFactor = 2`** (two 128-wide K slices per logical 256-wide tile). - DSL (experimental): **`S0 = 128`** by default (env `FA_S0`), still a **single** **`S1_TILE = 256`** matmul per tile (no K-split). The row-block size now matches reference **M**; the remaining gap is **K micro-tiling / matmul overlap** versus the reference’s two `Cube_S1` passes per logical tile. +1. **Cube tiling and S1 sub-tiling (QK vs PV)** + Reference: `CUBE_S0 = 128`, `CUBE_S1 = 128`, `TILE_S1 = 256`, **`kTileFactor = 2`** on **both** **`compute_qk`** and **`compute_pv`** with **`AccMode`** across subtiles. + DSL (experimental): **`S0 = 128`** (env `FA_S0`), **`FA_CUBE_S1`** default **128** — **QK** now runs **`kTileFactor`** matmuls into **`qk_acc`** column **`tile.subview`s** (matches ref **`compute_qk`** geometry). **PV** is still **one** full **`S1_TILE`**-wide matmul per step (no ref-style **`compute_pv`** K-striping yet). Remaining throughput gap vs fused ref is mostly **PV / buffer depth / preload**, not “missing QK K-split”. 2. **Real L1 double-buffering** Reference: `kMatTNBuffers = 2`, `pMatTNBuffers = 2`, `vMatTNBuffers = 2`, plus vec-side ping-pong (`srcVecTNBuffers`, `xexpVecTNBuffers`, `outOTileNBuffers`). @@ -49,18 +49,18 @@ Closing the gap should **not** rely on turning sync insertion off; it should rel ## Progress log (experimental `fa_builder.py`, Apr 2026) -Measured on NPU via `experimental/run.py` (Q=2048, H=128, S1_TILE=256): kernel holds ~24–26 TFLOP/s vs ~60+ TFLOP/s for `torch_npu` fused ref on the same script; correctness (`assert_close` at `run.py:151`) remains the gate. +Measured on NPU via `experimental/run.py` (Q=2048, H=128, S1_TILE=256): with QK **`kTileFactor=2`**, **~25 TFLOP/s** at S1=8192 (Apr 2026; run-to-run variance ±~1 TFLOP/s) vs **~61 TFLOP/s** for `torch_npu` fused ref on the same script; correctness (`assert_close`) remains the gate. | Change attempted | Result | |------------------|--------| | `QK_PRELOAD=4` + four `exp_max` slots + quad-unrolled vec/cube + true dual MAT banks for K/P/V | AICore CCU address fault (`mte`/`ccu`); likely VEC `tpop` scratch / expanded red region + dual `RIGHT` typing; **reverted**. | | True L1 ping-pong (separate `MAT_K0`/`MAT_K1`, …) without preload-4 | Overlapped `MAT_P_FIFO` with tail tiles until `MAT_P_FIFO_OFF` was recomputed; still faulted with dual `RIGHT` `alloc_tile` until fully reverted to aliased single-buffer layout. | -| K-split: two `CUBE_S1=128` matmuls per tile via `tile.subview` on `qk_acc` | Builds and passes `assert_close`; **~7% slower** than one `S1_TILE=256` matmul on this target — **reverted**. | +| QK K-split: two `CUBE_S1=128` matmuls per tile via `tile.subview` on `qk_acc` | **Landed** for ref **`compute_qk`** parity; `tile.subview` **`sizes`** must be Python **`int`**, not `s.const` (MLIR `I64ArrayAttr`). **PV** ref-style **`kTileFactor`** is **blocked**: **`tmatmul` lhs must be LEFT**, but LEFT P cannot column-**`subview`**; MAT P can **`subview`** but cannot be **`tmatmul` lhs** (see `ptoas_request.md`). | | Reorder steady cube step (PV before K load) | **Slight regression** vs original order on sampled runs — **reverted**. | -**Takeaway:** With **`S0=128`** landed in the experimental builder, the largest remaining structural gaps versus the reference are **`kTileFactor` / `CUBE_S1` K-split**, **preload / ring depth (`QK_PRELOAD`, CV FIFO)**, and **vec working-tile geometry (`Vec_S0`)** relative to what `l2g2l_pipe` + `TILE_UP_DOWN` imply for UB. Raising `QK_PRELOAD` still needs more `exp_max` slots and recv/GM budget. +**Takeaway:** With **`S0=128`** and **QK `kTileFactor`** aligned to the reference, the largest remaining structural gaps versus the reference are **`compute_pv` K-split / `AccMode`**, **preload / ring depth (`QK_PRELOAD`, CV FIFO)**, **true L1 ping-pong**, and **vec working-tile geometry (`Vec_S0`)** relative to what `l2g2l_pipe` + `TILE_UP_DOWN` imply for UB. Raising `QK_PRELOAD` still needs more `exp_max` slots and recv/GM budget. -| `S0=128` (Apr 2026) | Default `S0` raised to **128** (`FA_S0`). Cube **`pto.tpush`** uses **`ACC`** QK tiles into `l2g2l_pipe` (same as today’s supported producer); the **reference** path is **`TSTORE`** from acc to **fp32 GM**, not MAT/LEFT→GM—**`PIPE_UNASSIGNED` for MAT/LEFT `tpush` is expected**, not a toolchain defect to “fix” for FA. Builder keeps **fp32 `SLOT_SIZE_QK`**, vec **`p_fp32` as `row_max` scratch**, and a **shared `VEC_RECV_OFF`** for half-tile **`tpop`** scratch. `experimental/run.py` + `compile.sh` pass on NPU at ~24 TFLOP/s (order-of-magnitude below fused ref). | +| `S0=128` + QK `kTileFactor` (Apr 2026) | Default `S0` **128** (`FA_S0`), **`FA_CUBE_S1`** default **128**, QK via **`tile.subview`** acc columns. Cube **`pto.tpush`** uses **`ACC`** into `l2g2l_pipe`; ref uses **`TSTORE`** to fp32 GM—**`PIPE_UNASSIGNED` for MAT/LEFT `tpush` is expected**. **`compile.sh`** + **`run.py`** pass on NPU; **~25 TFLOP/s** at 8k vs fused ref **~61 TFLOP/s**. | --- @@ -78,7 +78,8 @@ Use this as a work backlog; order roughly reflects suggested priority (tiling/bu ### Tiling and cube schedule -- [ ] **Match reference cube geometry:** `CUBE_S0=128`, `CUBE_S1=128`, `TILE_S1=256`, and **`kTileFactor`** loop (two K slices per 256-wide tile) in the DSL builder’s cube kernel, or justify an equivalent FLOP/memory contract with measurements. *(Prototype K-split only: numerics OK, throughput down on current NPU.)* +- [x] **QK: match reference `kTileFactor` loop** — two **`Cube_S1`** K slices per logical tile into **`qk_acc`** (`experimental/fa_builder.py`, env **`FA_CUBE_S1`**). +- [ ] **PV: match reference `compute_pv`** — **`kTileFactor`** partial matmuls / **`AccMode`**. Blocked in Python today: **`tmatmul` requires LEFT lhs**; LEFT P rejects column **`tile.subview`**; MAT P allows **`subview`** but is not a legal **`tmatmul` lhs** (see **`ptoas_request.md`** §2). - [x] **Match reference `CUBE_S0` (128) in experimental builder** — default `S0=128` via `FA_S0` (Apr 2026); UB layout was tightened (shared `tpop` recv sizing, `row_max` scratch reuse, smaller `VEC_RED_STRIDE`). Smaller blocks remain available with `FA_S0=32` etc. if needed. - [ ] **Align `QK_PRELOAD`** with the reference launch (**4**) and extend the **`exp_max` / GU ring** logic (or equivalent hazard avoidance) for that depth; assert fifo and UB sizing. diff --git a/examples/aot/flash_attention/ptoas_request.md b/examples/aot/flash_attention/ptoas_request.md index 4ee89700..36af8b42 100644 --- a/examples/aot/flash_attention/ptoas_request.md +++ b/examples/aot/flash_attention/ptoas_request.md @@ -30,11 +30,13 @@ Upstream **PTOAS** paths below are **relative to the `PTOAS/` repository root**. **What the reference does.** `kTileFactor = Tile_S1 / Cube_S1`; two matmul passes per logical tile; vec softmax tile **`Vec_S0 × Tile_S1`** with **`Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor`** (`runTFA`); **`compute_p`** loads **`kTileFactor`** column strips from GM. -**What the Python builder does today.** One **`HEAD × S1_TILE`** matmul and one full **`S0 × S1_TILE`** `tpush`; vec **`tpop`** **`S0_HALF × S1_TILE`** with **`TILE_UP_DOWN`**. +**What the Python builder does today (experimental).** **`compute_qk`**: **`kTileFactor`** passes (**`HEAD × Cube_S1`** K slices, **`tile.subview`** on **`qk_acc`** for **`S0 × Cube_S1`** acc columns), then one full **`S0 × S1_TILE`** `tpush`. **`compute_pv`**: still a **single** **`p_left × v_right → pv_acc`** matmul per tile (no K-split / **`AccMode`** striping like the reference yet). Vec **`tpop`** **`S0_HALF × S1_TILE`** with **`TILE_UP_DOWN`**. **Ask.** -- **Documented** recipes (and, if useful, **ptodsl** helpers only—no new PTO ops required) for: cube **`AccMode`/`InitPartialSum`/`AccPartialSum`**-style sequences matching `compute_qk` + `compute_pv`, and vec **`pto.load`** / **`slice_view`** patterns matching **`TLOAD`** + **`TASSIGN`** column offsets in `compute_p`. +- **Documented** recipes (and, if useful, **ptodsl** helpers only—no new PTO ops required) for: cube **`AccMode`/`InitPartialSum`/`AccPartialSum`**-style sequences matching **`compute_pv`** (and any edge cases in **`compute_qk`**), and vec **`pto.load`** / **`slice_view`** patterns matching **`TLOAD`** + **`TASSIGN`** column offsets in `compute_p`. +- **`tile.subview`** Python API: **`sizes`** must be plain **`int`** (they are forwarded to MLIR `I64ArrayAttr`); **`const()`** wrappers currently fail at build time—document or unwrap in **`tile.subview`**. +- **`compute_pv` K-split in Python:** ref uses **`kTileFactor`** partial matmuls with **P** column strips. **`pto.tmatmul`** requires **`lhs` in LEFT**, but **LEFT** `tile.subview` on default boxed RowMajor **P** rejects column strips (`boxed RowMajor subview must keep full cols`). **MAT** **`p_recv`** allows the same column **`subview`** pattern as **`qk_acc`**, but **MAT cannot be `tmatmul` lhs** today. Reasonable ask: either **document** a supported recipe (e.g. **`TMOV`** staging + layout) or **relax verifier** / **extend `tmatmul`** only where it matches the reference matmul macro contract—not an invented new algorithm. - If something in **MLIR verification** blocks a **literal** ref-shaped **`pto.store`**/`load` schedule that is otherwise valid, file that as a **narrow bugfix** with a ref citation—not a new feature. --- From de19b211b5ae4ca5f9d20358f2590451c6a63441 Mon Sep 17 00:00:00 2001 From: mirkodevita Date: Tue, 28 Apr 2026 09:04:20 +0200 Subject: [PATCH 7/7] update ptoas requests --- examples/aot/flash_attention/ptoas_request.md | 270 +++++++++++++++++- 1 file changed, 263 insertions(+), 7 deletions(-) diff --git a/examples/aot/flash_attention/ptoas_request.md b/examples/aot/flash_attention/ptoas_request.md index 36af8b42..d19dbc52 100644 --- a/examples/aot/flash_attention/ptoas_request.md +++ b/examples/aot/flash_attention/ptoas_request.md @@ -82,9 +82,11 @@ Upstream **PTOAS** paths below are **relative to the `PTOAS/` repository root**. ## Concrete examples (reference ↔ Python today) +These snippets are for **PTOAS / `ptodsl` authors** mapping hand-written C++ (`fa_kernel.cpp`) to Python builders. Line citations use this repo’s paths. + ### A. QK: **`TSTORE`/`TLOAD`** (ref) vs **`tpush`/`tpop`** + ACC (Python) -**Reference — cube writes fp32 QK slices to GM** (`examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp`): +**Reference — cube writes one `Cube_S0 × Cube_S1` fp32 strip per `sub_tile_id` to GM** (`compute_qk`): ```364:376:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp using GlobalDataQK = @@ -98,7 +100,7 @@ Upstream **PTOAS** paths below are **relative to the `PTOAS/` repository root**. TSTORE(qkGlobalTile, qkAccTile); ``` -**Reference — vec reads fp32 from GM into column strips of `qkVecTile`** (same file, `compute_p`): +**Reference — vec reads `kTileFactor` GM strips and packs them into one wide `qkVecTile`** (`compute_p`): ```505:517:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp for (int sub_col = 0; sub_col < static_cast(kTileFactor); ++sub_col) { @@ -113,11 +115,104 @@ Upstream **PTOAS** paths below are **relative to the `PTOAS/` repository root**. } ``` -**Reasonable Python direction (still 1:1 with ref).** Express the same **GM layout + slice loads** using **`pto.store`** / **`pto.load`** + **`slice_view`** on `__gm__` tensors (or keep **`l2g2l_pipe`** but size **`slot_size`** and **`gm_addr`** offsets to match **`base_elems`** / **`kTileFactor`**). **Do not** introduce **`tile.cvt(qk_acc → mat/left fp16)` + `tpush`** for QK: that is **not** in the reference, and MAT/LEFT **`tpush`** is not a supported producer anyway. +**Python builder today — cube:** `kTileFactor` **inner** matmuls match the ref’s **K** tiling, but the **GM path** is still one **`ACC`** tile **`tpush`** per logical tile (full `S0 × S1_TILE` fp32), not **`kTileFactor`** separate **`TSTORE`** slots. + +`accumulate_qk_for_tile` (K strips → **`qk_acc`** column subviews): + +```359:376:examples/aot/flash_attention/experimental/fa_builder.py + def accumulate_qk_for_tile(k_s1_offset, qk_buf, k_mat_buf_idx): + for sc in range(K_TILE_FACTOR): + k_col = k_s1_offset + const(sc * CUBE_S1) + kt_view = pto.slice_view( + kt_sub_slice_ty, + source=tv_k, + offsets=[c0, k_col], + sizes=[cHEAD, cCUBE_S1], + ) + pto.load(kt_view, k_mat[k_mat_buf_idx]) + tile.mov(k_mat[k_mat_buf_idx], k_right[k_mat_buf_idx]) + qk_part = tile.subview( + qk_acc[qk_buf], + [c0, const(sc * CUBE_S1)], + [S0, CUBE_S1], + ) + tile.matmul(q_left, k_right[k_mat_buf_idx], qk_part) +``` + +Prologue: one **`tpush`** per logical tile after **`accumulate_qk_for_tile`**: + +```387:391:examples/aot/flash_attention/experimental/fa_builder.py + for k in range(QK_PRELOAD): + k_s1_off = const(k * S1_TILE) + accumulate_qk_for_tile(k_s1_off, k, k % 2) + pto.tpush(qk_acc[k], qk_pipe, SPLIT_UP_DOWN) +``` + +**Python builder today — vec:** one **`tpop`** of a **half-tile** `qk_vec_ty` (`S0_HALF × S1_TILE` fp32) per softmax step, not **`Vec_S0 × Cube_S1`** repeated **`TLOAD`**s: + +```571:605:examples/aot/flash_attention/experimental/fa_builder.py + def emit_softmax_step(exp_max_slot, is_init): + qk_recv = pto.tpop( + qk_vec_ty, + qk_pipe, + SPLIT_UP_DOWN, + addr=const(VEC_RECV_OFF, s.int64), + ) + tile.muls(qk_recv, scale, qk_recv) + tile.row_max(qk_recv, p_fp32, local_max) + # ... softmax body ... + tile.cvt(p_fp32, p_fp16) + pto.tpush_to_aic(p_fp16, SPLIT_UP_DOWN, id=ID_P) + pto.tfree(qk_pipe, SPLIT_UP_DOWN) +``` + +**Still missing vs ref (PTOAS / docs ask).** Either **document** how **`slot_size` / GM offsets** should reproduce **`base_elems` + `sub_tile_id`** packing, or show **`pto.store`/`pto.load`** that mirror **`TSTORE`/`TLOAD`** exactly. **Do not** document **`tile.cvt` on QK + `tpush` from MAT/LEFT`**; that is not the reference QK path. + +### A2. Schedule: **`runTFA`’s `sub_tile` loop** (C++) vs **fused Python steps** (same semantics, different shape) -**Python builder today** (`experimental/fa_builder.py`): one matmul over full `S1_TILE`, **`pto.tpush(qk_acc[k], qk_pipe)`**, vec **`pto.tpop(qk_vec_ty, …)`** — simpler than ref **`kTileFactor`** loop; **documented** in §1–2 as the gap to close **using ref-shaped stores/loads or matching pipe packing**. +The reference **interleaves** cube **`compute_qk`** and vec **`compute_p`** at **`kTileFactor`** granularity inside the steady loop: -### B. P: **V2C `TPipe` with `sizeof(half)`** (ref) vs vec `tile.cvt` + `tpush_to_aic` (Python) +```848:920:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + for (int preload_tile = 0; preload_tile < static_cast(qkPreloadNum) && preload_tile < num_tiles_s1; + ++preload_tile) { + if constexpr (DAV_CUBE) { + for (int sub_tile = 0; sub_tile < static_cast(kTileFactor); ++sub_tile) { + compute_qk<...>(preload_tile, sub_tile, ...); + } + } + if constexpr (DAV_VEC) { + for (int row_slice = 0; row_slice < static_cast(kTileFactor); ++row_slice) { + compute_p<...>(preload_tile, row_slice, ...); + } + } + } + + for (int tile_id = 0; tile_id < num_tiles_s1; ++tile_id) { + // ... + for (int sub_tile = 0; sub_tile < static_cast(kTileFactor); ++sub_tile) { + if constexpr (DAV_CUBE) { + if (next_qk_tile != -1) { + compute_qk<...>(next_qk_tile, sub_tile, ...); + } + } + if constexpr (DAV_VEC) { + if (next_qk_tile != -1) { + compute_p<...>(next_qk_tile, sub_tile, ...); + } + } + if constexpr (DAV_CUBE) { + compute_pv<...>(tile_id, sub_tile, ...); + } + } + if constexpr (DAV_VEC) { + compute_gu<...>(tile_id, ...); + } + } +``` + +**Python** (`experimental/fa_builder.py`) **fuses** all **`kTileFactor`** QK matmuls **inside** `accumulate_qk_for_tile` before **`tpush`**, and vec **`emit_softmax_step`** does **not** take explicit **`row_slice`** or **`sub_col`** arguments—**`TILE_UP_DOWN`** replaces part of the ref’s **`row_slice × Vec_S0`** story. A **PTOAS-facing doc** should spell out: “**`row_slice` loop** ↔ **`get_subblock_idx()` + fixed `S0_HALF`**” and “**`sub_col` GM loads** ↔ **either** replicated **`slice_view`** **or** one wide **`tpop`**,” so reviewers do not assume bit-identical control flow. + +### B. P: **V2C `TPipe` with `sizeof(half)`** (ref) vs vec **`tile.cvt`** + **`tpush_to_aic`** (Python) **Reference — P FIFO slot is fp16-sized, cube pops into MAT** (`fa_kernel.cpp`): @@ -126,11 +221,172 @@ Upstream **PTOAS** paths below are **relative to the `PTOAS/` repository root**. TPipe; ``` -**Python today** uses **`tile.cvt(p_fp32, p_fp16)`** then **`pto.tpush_to_aic(p_fp16, …)`** — that is a **porting of the vec→cube half path**, not the QK cube→vec path. Any **`TCvt`-like behavior for P** should stay aligned with **`pto_macro_fa_softmax`** / reference packing, not used as a precedent to invent QK dtype tricks. +**Python — same logical edge (vec → cube), different primitive names:** + +```603:604:examples/aot/flash_attention/experimental/fa_builder.py + tile.cvt(p_fp32, p_fp16) + pto.tpush_to_aic(p_fp16, SPLIT_UP_DOWN, id=ID_P) +``` + +**Cube consumer** still **`tpop`**s into **`p_recv_ty`** (MAT), then **`mov`** to **`p_left`** for **`matmul`**—analogous to **`TPOP`** into **`pMatTile`**. The **ask for PTOAS** is macro-order parity (**`pto_macro_fa_softmax`**) and **slot byte size** documentation, not new QK **`TCvt`** ideas. ### C. **`PIPE_UNASSIGNED` on non-ACC `tpush`** -If someone tries **`pto.tpush`** from a **MAT** or **LEFT** tile, verification fails (`include/PTO/IR/PTOOps.td`, `TPushOp::getPipe`). That is **consistent** with there being **no** ref-style **`TSTORE`** from MAT/LEFT to GM for QK. **Not a PTOAS feature request for FA**—use **ACC → `tpush`** (pipe) or **`pto.store`** from acc/global views like **`TSTORE`**. +**What fails today (do not suggest as FA “fix”):** + +```python +# Hypothetical / INVALID for QK in current PTO lowering: +pto.tpush(some_left_tile, qk_pipe, split=1) # LEFT → pipe: not a supported QK producer +pto.tpush(some_mat_tile, qk_pipe, split=1) # MAT → pipe: verifier / PIPE_UNASSIGNED +``` + +**What the Python FA builder uses instead (supported):** + +```python +pto.tpush(qk_acc_tile, qk_pipe, SPLIT_UP_DOWN) # ACC → l2g2l pipe (supported producer) +``` + +That matches the **intent** of ref **`TSTORE`** from **accumulator** data to a **visibility** buffer, even though the **packing** still differs from per-**`sub_tile_id`** **`TSTORE`** (§A). + +### D. PV: **`AccMode` + `kTileFactor`** in C++ vs **single `matmul`** in Python (verifier gap) + +**Reference — outer caller invokes `compute_pv` once per `sub_tile`** (`runTFA`): + +```912:918:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + if constexpr (DAV_CUBE) { + compute_pv( + tile_id, sub_tile, v, pv_tile_fifo_block, pMatTile[pv_src_pingpong_id % pMatTNBuffers], + vMatTile[pv_src_pingpong_id % vMatTNBuffers], pvAccTile, + pv_src_pingpong_id % vMatTNBuffers + PV_EVENT_ID0, pvAccTileEvtID, pPipe, pv2guSync); + pv_src_pingpong_id++; + } +``` + +**Reference — inner `AccMode` chooses init vs accumulate across K strips** (`compute_pv`): + +```400:433:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + const int s1_index = tile_id * static_cast(Tile_S1) + sub_tile_id * static_cast(Cube_S1); + // ... + GlobalVT vLoad((__gm__ half *)(v + s1_index * HEAD_SIZE)); + TLOAD(vMatTile, vLoad); + + TPOP(pPipe, pMatTile); + + const AccMode accMode = (sub_tile_id == 0) ? + (is_last_subtile ? AccMode::InitFinalSum : AccMode::InitPartialSum) : + (is_last_subtile ? AccMode::AccFinalSum : AccMode::AccPartialSum); + pto_macro_matmul(pMatTile, vMatTile, pvAccTile, accMode); +``` + +**What a literal Python port would look like** (conceptual; **not** all verifiable today): + +```python +# Desired shape: P is MAT, V is MAT, pv is ACC — ref uses pMatTile strips × v strips. +for sc in range(K_TILE_FACTOR): + p_sub = tile.subview(p_mat, [0, sc * CUBE_S1], [S0, CUBE_S1]) # column strip of P on MAT + v_sub = tile.subview(v_right, [sc * CUBE_S1, 0], [CUBE_S1, HEAD]) # row strip of V on RIGHT + if sc == 0: + tile.matmul(p_sub, v_sub, pv_acc) + else: + tile.matmul_acc(pv_acc, p_sub, v_sub, pv_acc) +``` + +**What actually blocks this in `ptodsl` today:** + +- **`tile.subview(p_left, …, [S0, CUBE_S1])`** (column strip on default boxed **LEFT** RowMajor **P**) → MLIR: **`boxed RowMajor subview must keep full cols`**. +- **`tile.matmul(p_sub_mat, v_sub, …)`** with **`p_sub_mat`** on **MAT** → MLIR: **`tmatmul` expects lhs in LEFT`**. + +So the **missing PTOAS / dialect / doc** piece is a **supported** way to express **ref `pto_macro_matmul` + `AccMode`** for **PV**—either **documented staging** (**`TMOV`** MAT→LEFT strips with a legal layout, or **`tmatmul`** generalization **only** where it matches the macro contract), not a new FA algorithm. + +**Python steady-state PV today** (one full matmul, ref-equivalent **math**, different **micro-schedule**): + +```408:423:examples/aot/flash_attention/experimental/fa_builder.py + p_raw = pto.tpop_from_aiv(p_recv_ty, SPLIT_UP_DOWN, id=ID_P) + tile.mov(p_raw, p_left[b]) + pto.tfree_from_aiv(SPLIT_UP_DOWN, id=ID_P) + tile.mov(v_mat[b], v_right[b]) + # ... prefetch next V into v_mat[1 - b] ... + tile.matmul(p_left[b], v_right[b], pv_acc[b]) + pto.tpush(pv_acc[b], pv_pipe, SPLIT_UP_DOWN) +``` + +### E. **`tile.subview` sizes: Python `int` vs `s.const` (MLIR `I64ArrayAttr`)** + +**Fails at Python build time** (`TypeError` from `IntegerAttr.get`): + +```python +cS0 = const(S0) +cCUBE_S1 = const(CUBE_S1) +qk_part = tile.subview(qk_acc, [c0, const(sc * CUBE_S1)], [cS0, cCUBE_S1]) # BAD: sizes are Value wrappers +``` + +**Works** (sizes are plain integers; offsets can stay dynamic via **`const`** / scalars as today): + +```python +qk_part = tile.subview(qk_acc, [c0, const(sc * CUBE_S1)], [S0, CUBE_S1]) +``` + +**Ask:** either **unwrap** **`const`** in **`ptodsl.api.tile.subview`** for **`sizes`**, or **document** this in PTOAS / `ptodsl` API reference so FA-style builders do not rediscover it via traceback. + +### F. Sync: **`TSync_Custom` + `wait`/`allocate`/`record`/`free`** (C++) vs **pipes + `tfree`** (Python) + +**Reference — explicit producer/consumer pairing on the QK path** (`compute_qk` / `compute_p`): + +```361:381:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + if (sub_tile_id == 0 && should_wait_consume) + qk2smSync.allocate(); // wait for SM consume data + // ... + TSTORE(qkGlobalTile, qkAccTile); + // ... + if (sub_tile_id == static_cast(kTileFactor) - 1) + qk2smSync.record(); // notify for QK produce data +``` + +```496:520:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + wait_flag(PIPE_V, PIPE_MTE2, pTileEventId); + if (row_slice == 0) + qk2smSync.wait(); // wait for QK produce data + // ... TLOAD kTileFactor strips ... + if (row_slice == static_cast(kTileFactor) - 1 && should_notify_consume) + qk2smSync.free(); // notify for SM consume data +``` + +**Sync object type** (template parameter baked into `runTFA`): + +```817:817:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + constexpr TSync_Custom qk2smSync = {BUF0_QK_READY}; +``` + +**Python — no named `TSync_Custom`; ordering relies on pipe ops + toolchain-inserted sync** (`--enable-insert-sync` in `experimental/compile.sh`): + +```279:281:examples/aot/flash_attention/experimental/fa_builder.py + qk_pipe = pto.initialize_l2g2l_pipe( + dir_mask=1, + slot_size=SLOT_SIZE_QK, +``` + +**Ask for PTOAS / docs:** a small **equivalence table** row per ref hook, e.g. **`qk2smSync.record()` after last `TSTORE` of a tile** ↔ **which `tpush` / GM completion barrier** in generated C++ after `ptoas` version *X*, so performance regressions can be bisected without reading all inserted barriers. + +### G. Row / column indexing: **`row_slice` + `Vec_S0`** (C++) vs **`get_subblock_idx` + `S0_HALF`** (Python) + +**Reference — software row origin per vec core and `row_slice`** (`compute_p`): + +```488:492:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + const size_t subblock_base_rows = + static_cast(Cube_S0 / VEC_CORES) * static_cast(get_subblockid()); + const size_t row_offset = subblock_base_rows + static_cast(row_slice * Vec_S0); + const int s0_index = blk_idx * Cube_S0 + row_offset; +``` + +**Python — subblock row origin** (half of **`Cube_S0`** per AIV sub-block for **`TILE_UP_DOWN`**): + +```540:541:examples/aot/flash_attention/experimental/fa_builder.py + sb_idx = s.index_cast(pto.get_subblock_idx()) + row_off_sb = sb_idx * cS0_HALF +``` + +There is **no** Python **`for row_slice in range(kTileFactor):`** around softmax; **`row_slice × Vec_S0`** from the ref is only partially reflected via **`S0_HALF`** recv tiles + pipe **`split`**. The **missing documentation** is a **direct mapping table** (`row_slice`, `Vec_S0`, `kTileFactor`) ↔ (`TILE_UP_DOWN`, `S0_HALF`, `pto.tpop` tile types), so authors do not treat **`split=1`** as a complete substitute for **ref softmax tile decomposition** without proof. ---