diff --git a/examples/aot/flash_attention/experimental/fa_builder.py b/examples/aot/flash_attention/experimental/fa_builder.py index 3b178b3e..e398cb9b 100644 --- a/examples/aot/flash_attention/experimental/fa_builder.py +++ b/examples/aot/flash_attention/experimental/fa_builder.py @@ -4,20 +4,22 @@ # # This file mirrors the reference C++ scheduler: # -# constexpr int qkPreloadNum = 2; // warmup depth +# constexpr int qkPreloadNum = 2; // warmup depth (reference uses 4; UB-limited here) # -# /* Prologue: cube emits QK[0..QK_PRELOAD-1]; vec consumes them and -# pushes P[0..QK_PRELOAD-1]. No PV / gu yet. */ +# Cube K tiling matches ref `Cube_S1` / `kTileFactor`: two matmuls per logical +# `S1_TILE` tile into `qk_acc` column subviews; K MAT uses ping-pong buffers +# (`kMatTNBuffers = 2` in ref). +# +# PV stays one `p_left × v_right` matmul: `tmatmul` requires **LEFT** lhs, but +# boxed RowMajor **LEFT** forbids P column `tile.subview` (`keep full cols`); +# MAT P allows column subviews yet cannot be `tmatmul` lhs — ref `compute_pv` +# K-strip schedule needs a future layout/op story in PTO DSL. +# +# /* Prologue: cube emits QK[0..QK_PRELOAD-1]; vec softmaxes them. */ # # /* Steady state, tile_id 0..N-1: -# cube: if (t+QK_PRELOAD < N) compute_qk(t+QK_PRELOAD); -# compute_pv(tile_id); -# vec: if (t+QK_PRELOAD < N) compute_p(t+QK_PRELOAD); -# compute_gu(tile_id); -# so vec's softmax for the LOOK-AHEAD tile fills the QK consumption -# slot WHILE the cube is computing the current PV[t]. The cube -# stops being blocked on a freshly-pushed P (softmax of t+2 has -# already pushed P[t+2] into the FIFO by the time cube needs it). */ +# cube: compute QK[t+QK_PRELOAD] while draining PV[t]; +# vec: interleave gu(t) with softmax(t+QK_PRELOAD) using exp_max a/b. */ # # /* Epilogue: drain the last QK_PRELOAD tiles' PV / gu. */ # @@ -27,6 +29,8 @@ # need a 2-deep ring of `exp_max` tiles (`exp_max_a`, `exp_max_b`). We # implement the ring by unrolling the steady-state loop in pairs of 2 # iterations: even iters use `exp_max_a`, odd iters use `exp_max_b`. +# (Raising preload to 4 like the C++ launch requires a 4-slot ring and more +# VEC UB headroom; a prototype hit CCU address faults until layout is reworked.) # # Other state in the softmax (`new_global_max`, `new_global_sum`) does # NOT need a ring: it is monotonic accumulator state across all tiles @@ -49,10 +53,17 @@ # --------------------------------------------------------------------------- # Static shapes (must match run.py constants) # --------------------------------------------------------------------------- -S0 = 32 # Q rows per block -S0_HALF = S0 // 2 # rows per AIV sub-block +# Match reference `CUBE_S0 = 128` (`fa_kernel.cpp`). Vec UB grows with +# `S0` because the QK/PV pipe mirrors full tiles in UB; override with FA_S0 +# if a smaller block is needed while tooling catches up (see `known_gap.md`). +S0 = int(os.environ.get("FA_S0", "128")) +S0_HALF = S0 // 2 # rows per AIV sub-block (TILE_UP_DOWN split) HEAD = 128 # attention head dimension -S1_TILE = 256 # K/V columns per tile +S1_TILE = 256 # K/V columns per tile (logical `Tile_S1`; ref `fa_kernel.cpp`) +# Reference: `Cube_S1` K micro-tile and `kTileFactor = Tile_S1 / Cube_S1` matmul passes. +CUBE_S1 = int(os.environ.get("FA_CUBE_S1", "128")) +assert S1_TILE % CUBE_S1 == 0, "S1_TILE must be divisible by FA_CUBE_S1" +K_TILE_FACTOR = S1_TILE // CUBE_S1 # NUM_TILES is overridable via the FA_NUM_TILES env var so the same builder # can produce kernels for different sequence lengths # (S1_TOTAL = S1_TILE * NUM_TILES). @@ -62,13 +73,11 @@ S1_TOTAL = S1_TILE * NUM_TILES Q_ROWS = 2048 -NUM_Q_BLOCKS = Q_ROWS // S0 # 64 row-blocks +NUM_Q_BLOCKS = Q_ROWS // S0 # e.g. 16 when Q_ROWS=2048 and S0=128 -# QK preload depth — must be >= 1; reference uses 2. The vec pre-softmaxes -# tiles 0..QK_PRELOAD-1, then the steady-state loop interleaves softmax(t+QK_PRELOAD) -# with gu(t), and the epilogue drains the last QK_PRELOAD gu's. -# (NUM_TILES - QK_PRELOAD) must be even — steady state is pair-unrolled to -# ping-pong the exp_max ring (see below). +# QK preload depth — must be >= 1; reference launch uses 4, this builder +# keeps 2 for a smaller VEC exp_max ring (see header comment). +# (NUM_TILES - QK_PRELOAD) must be even — steady state is pair-unrolled. QK_PRELOAD = 2 assert ( NUM_TILES - QK_PRELOAD @@ -76,7 +85,8 @@ STEADY_PAIRS = (NUM_TILES - QK_PRELOAD) // 2 # Per-pipe slot sizes (bytes). -SLOT_SIZE_QK = S0 * S1_TILE * 4 # fp32 QK accumulator +# QK: fp32 on the wire (cube `tpush` only accepts ACC tiles today — see known_gap). +SLOT_SIZE_QK = S0 * S1_TILE * 4 SLOT_SIZE_PV = S0 * HEAD * 4 # fp32 PV accumulator SLOT_SIZE_P = S0 * S1_TILE * 2 # fp16 P matrix sent vec → cube @@ -103,10 +113,15 @@ # Explicit local-memory layout used when compiling with --pto-level=level3. # Offsets are byte offsets within each independent local address space. MAT_Q_OFF = 0 -MAT_K_OFF = MAT_Q_OFF + S0 * HEAD * 2 -MAT_P_RECV_OFF = MAT_K_OFF + HEAD * S1_TILE * 2 +# Two physical K tiles (ref `kMatTNBuffers = 2`) for ping-pong loads. +MAT_K0_OFF = MAT_Q_OFF + S0 * HEAD * 2 +MAT_K1_OFF = MAT_K0_OFF + HEAD * CUBE_S1 * 2 +MAT_P_RECV_OFF = MAT_K1_OFF + HEAD * CUBE_S1 * 2 MAT_V_OFF = MAT_P_RECV_OFF + S0 * S1_TILE * 2 -MAT_P_FIFO_OFF = 262144 +MAT_P_FIFO_OFF = MAT_V_OFF + S1_TILE * HEAD * 2 +# Pad past the last MAT-resident tile; bisheng is sensitive to overlap here. +if MAT_P_FIFO_OFF < 393216: + MAT_P_FIFO_OFF = 393216 LEFT_Q_OFF = 0 LEFT_P_OFF = LEFT_Q_OFF + S0 * HEAD * 2 @@ -123,13 +138,16 @@ VEC_P_FP16_OFF = VEC_P_FP32_OFF + S0_HALF * S1_TILE * 4 VEC_O_OFF = VEC_P_FP16_OFF + S0_HALF * S1_TILE * 2 VEC_RED_BASE_OFF = VEC_O_OFF + S0_HALF * HEAD * 4 -VEC_RED_STRIDE = 512 +# Tight packing for reduce / exp_max ring scalars (one column per logical row). +VEC_RED_STRIDE = ((S0_HALF * 4 + 127) // 128) * 128 VEC_NEW_GLOBAL_MAX_OFF = VEC_RED_BASE_OFF + 0 * VEC_RED_STRIDE VEC_LOCAL_MAX_OFF = VEC_RED_BASE_OFF + 1 * VEC_RED_STRIDE VEC_NEW_GLOBAL_SUM_OFF = VEC_RED_BASE_OFF + 2 * VEC_RED_STRIDE VEC_LOCAL_SUM_OFF = VEC_RED_BASE_OFF + 3 * VEC_RED_STRIDE VEC_EXP_MAX_A_OFF = VEC_RED_BASE_OFF + 4 * VEC_RED_STRIDE VEC_EXP_MAX_B_OFF = VEC_RED_BASE_OFF + 5 * VEC_RED_STRIDE +# Shared recv scratch: max(fp32 QK half-tile, fp32 PV half-tile) for tpop addr=. +_VEC_RECV_BYTES = max(S0_HALF * S1_TILE * 4, S0_HALF * HEAD * 4) VEC_RECV_OFF = VEC_RED_BASE_OFF + 6 * VEC_RED_STRIDE ID_QK = 10 # Cube → Vec, dir_mask = 1 (uses lower-level l2g2l) @@ -155,6 +173,7 @@ def meta_data(): q_sub_ty = pto.SubTensorType(shape=[S0, HEAD], dtype=fp16) kt_sub_ty = pto.SubTensorType(shape=[HEAD, S1_TILE], dtype=fp16) + kt_sub_slice_ty = pto.SubTensorType(shape=[HEAD, CUBE_S1], dtype=fp16) v_sub_ty = pto.SubTensorType(shape=[S1_TILE, HEAD], dtype=fp16) o_sub_ty = pto.SubTensorType(shape=[S0, HEAD], dtype=fp32) o_sub_half_ty = pto.SubTensorType(shape=[S0_HALF, HEAD], dtype=fp32) @@ -163,13 +182,13 @@ def meta_data(): q_mat_ty = pto.TileBufType(shape=[S0, HEAD], dtype=fp16, memory_space="MAT") q_left_ty = pto.TileBufType(shape=[S0, HEAD], dtype=fp16, memory_space="LEFT") k_mat_ty = pto.TileBufType( - shape=[HEAD, S1_TILE], + shape=[HEAD, CUBE_S1], dtype=fp16, memory_space="MAT", config=pto.TileBufConfig(blayout="RowMajor", slayout="ColMajor"), ) k_right_ty = pto.TileBufType( - shape=[HEAD, S1_TILE], dtype=fp16, memory_space="RIGHT" + shape=[HEAD, CUBE_S1], dtype=fp16, memory_space="RIGHT" ) qk_acc_ty = pto.TileBufType(shape=[S0, S1_TILE], dtype=fp32, memory_space="ACC") p_recv_ty = pto.TileBufType( @@ -229,11 +248,11 @@ def cube_kernel( ) -> None: c0 = const(0) c1 = const(1) - c2 = const(2) cS0 = const(S0) cHEAD = const(HEAD) cS1_TILE = const(S1_TILE) cS1_TOTAL = const(S1_TOTAL) + cCUBE_S1 = const(CUBE_S1) cNUM_TILES = const(NUM_TILES) cNUM_Q_BLOCKS = const(NUM_Q_BLOCKS) @@ -297,15 +316,13 @@ def cube_kernel( nosplit=False, ) - # All cube tile-buffers are single-buffered. K and V share RIGHT - # storage: for HEAD=128, S1_TILE=256 each RIGHT tile is exactly - # 64 KB, and the schedule uses V for PV before moving K for QK. - # This mirrors the hand-written reference's explicit local-memory - # assignment style and avoids asking RIGHT for two full tiles. + # K and V share one RIGHT bank (sequential use); `[buf]` lists alias + # to the same physical tile for each role (schedule is still safe). right_base = const(RIGHT_KV_OFF, s.int64) q_mat = pto.alloc_tile(q_mat_ty, addr=const(MAT_Q_OFF, s.int64)) q_left = pto.alloc_tile(q_left_ty, addr=const(LEFT_Q_OFF, s.int64)) - k_mat_s = pto.alloc_tile(k_mat_ty, addr=const(MAT_K_OFF, s.int64)) + k_mat_0 = pto.alloc_tile(k_mat_ty, addr=const(MAT_K0_OFF, s.int64)) + k_mat_1 = pto.alloc_tile(k_mat_ty, addr=const(MAT_K1_OFF, s.int64)) k_right_s = pto.alloc_tile(k_right_ty, addr=right_base) qk_acc_s = pto.alloc_tile(qk_acc_ty, addr=const(ACC_QK_OFF, s.int64)) p_recv_s = pto.alloc_tile(p_recv_ty, addr=const(MAT_P_RECV_OFF, s.int64)) @@ -313,9 +330,7 @@ def cube_kernel( v_mat_s = pto.alloc_tile(v_mat_ty, addr=const(MAT_V_OFF, s.int64)) v_right_s = pto.alloc_tile(v_right_ty, addr=right_base) pv_acc_s = pto.alloc_tile(pv_acc_ty, addr=const(ACC_PV_OFF, s.int64)) - # Aliasing wrappers: keep the per-iteration `[buf]` indexing pattern - # in the body even though all slots currently point at one alloc. - k_mat = [k_mat_s, k_mat_s] + k_mat = [k_mat_0, k_mat_1] k_right = [k_right_s, k_right_s] qk_acc = [qk_acc_s, qk_acc_s] p_recv = [p_recv_s, p_recv_s] @@ -341,6 +356,25 @@ def cube_kernel( strides=[cHEAD, c1], ) + # Two `Cube_S1`-wide matmuls per logical tile (ref `compute_qk` / `kTileFactor`). + def accumulate_qk_for_tile(k_s1_offset, qk_buf, k_mat_buf_idx): + for sc in range(K_TILE_FACTOR): + k_col = k_s1_offset + const(sc * CUBE_S1) + kt_view = pto.slice_view( + kt_sub_slice_ty, + source=tv_k, + offsets=[c0, k_col], + sizes=[cHEAD, cCUBE_S1], + ) + pto.load(kt_view, k_mat[k_mat_buf_idx]) + tile.mov(k_mat[k_mat_buf_idx], k_right[k_mat_buf_idx]) + qk_part = tile.subview( + qk_acc[qk_buf], + [c0, const(sc * CUBE_S1)], + [S0, CUBE_S1], + ) + tile.matmul(q_left, k_right[k_mat_buf_idx], qk_part) + for qb in pto.range(qb_start, qb_end, c1): q_row_off = qb * cS0 @@ -351,19 +385,9 @@ def cube_kernel( tile.mov(q_mat, q_left) # =================== Cube prologue: emit QK[0..QK_PRELOAD-1] =================== - # Each prologue QK uses its own k_mat / k_right / qk_acc slot - # so MTE2 load of K[1] overlaps the M of QK[0]. for k in range(QK_PRELOAD): - k_off = const(k * S1_TILE) - kt_view_k = pto.slice_view( - kt_sub_ty, - source=tv_k, - offsets=[c0, k_off], - sizes=[cHEAD, cS1_TILE], - ) - pto.load(kt_view_k, k_mat[k]) - tile.mov(k_mat[k], k_right[k]) - tile.matmul(q_left, k_right[k], qk_acc[k]) + k_s1_off = const(k * S1_TILE) + accumulate_qk_for_tile(k_s1_off, k, k % 2) pto.tpush(qk_acc[k], qk_pipe, SPLIT_UP_DOWN) # Preload V[0] for the very first PV. @@ -376,25 +400,10 @@ def cube_kernel( pto.load(v_view_0, v_mat[0]) # =================== Cube steady state =================== - # Pair-unrolled. Iter t (parity = t%2 → buffer index `b`): - # * load K[next_qk = t+QK_PRELOAD] into k_mat[b] - # (next_qk parity equals t parity since QK_PRELOAD == 2) - # * pop / mov P[t] into p_left[b]; mov V[t] (in v_mat[b]) → v_right[b] - # * preload V[t+1] into v_mat[1-b] - # * matmul PV[t] into pv_acc[b]; push - # * matmul QK[next_qk] into qk_acc[b]; push - # Pair handler: + # Pair-unrolled; buffer index b = t % 2 (logical ping-pong). def emit_cube_step(t_idx, b): - # next_qk = t_idx + QK_PRELOAD (only used when in main range) next_qk = t_idx + const(QK_PRELOAD) - kt_off = next_qk * cS1_TILE - kt_view = pto.slice_view( - kt_sub_ty, - source=tv_k, - offsets=[c0, kt_off], - sizes=[cHEAD, cS1_TILE], - ) - pto.load(kt_view, k_mat[b]) + k_s1_off = next_qk * cS1_TILE p_raw = pto.tpop_from_aiv(p_recv_ty, SPLIT_UP_DOWN, id=ID_P) tile.mov(p_raw, p_left[b]) @@ -413,23 +422,20 @@ def emit_cube_step(t_idx, b): tile.matmul(p_left[b], v_right[b], pv_acc[b]) pto.tpush(pv_acc[b], pv_pipe, SPLIT_UP_DOWN) - tile.mov(k_mat[b], k_right[b]) - tile.matmul(q_left, k_right[b], qk_acc[b]) + accumulate_qk_for_tile(k_s1_off, b, b) pto.tpush(qk_acc[b], qk_pipe, SPLIT_UP_DOWN) assert (NUM_TILES - QK_PRELOAD) % 2 == 0 - for p in pto.range(c0, const((NUM_TILES - QK_PRELOAD) // 2), c1): + c2 = const(2) + for p in pto.range(c0, const(STEADY_PAIRS), c1): t_a = p * c2 emit_cube_step(t_a, 0) t_b = p * c2 + c1 emit_cube_step(t_b, 1) # =================== Cube epilogue: drain last QK_PRELOAD PVs =================== - # Tile_id range: NUM_TILES-QK_PRELOAD .. NUM_TILES-1. - # NUM_TILES is even and QK_PRELOAD is even, so the first epilogue - # tile has parity 0. v_mat[0] holds V[NUM_TILES-QK_PRELOAD] thanks - # to the last steady-state preload (it loaded V[t_b+1] = V[NUM_TILES-QK_PRELOAD] - # into v_mat[1-1]=v_mat[0]). + # Tile ids NUM_TILES-QK_PRELOAD .. NUM_TILES-1; v_mat parity matches + # the last steady-state V preload pattern (same as QK_PRELOAD==2 case). for k in range(QK_PRELOAD): b = k % 2 p_raw = pto.tpop_from_aiv(p_recv_ty, SPLIT_UP_DOWN, id=ID_P) @@ -461,7 +467,6 @@ def vector_kernel( ) -> None: c0 = const(0) c1 = const(1) - c2 = const(2) cS0 = const(S0) cS0_HALF = const(S0_HALF) cHEAD = const(HEAD) @@ -547,13 +552,7 @@ def vector_kernel( red_ty, addr=const(VEC_NEW_GLOBAL_SUM_OFF, s.int64) ) local_sum = pto.alloc_tile(red_ty, addr=const(VEC_LOCAL_SUM_OFF, s.int64)) - # Ring of QK_PRELOAD exp_max tiles. With QK_PRELOAD=2 we use a/b - # ping-pong: even-parity tiles use exp_max_a, odd-parity tiles use - # exp_max_b. softmax(t) writes the exp_max for tile t into the - # corresponding slot; gu(t) reads it from the same slot. Because - # softmax(t+QK_PRELOAD) and gu(t) hit the SAME slot (parity matches), - # the steady-state loop must do gu(t) BEFORE softmax(t+QK_PRELOAD) - # to avoid clobbering. + # Two exp_max tiles (a/b). Interleave gu and softmax to avoid clobber. assert QK_PRELOAD == 2, "exp_max ring is hard-coded to 2 tiles" exp_max_a = pto.alloc_tile(red_ty, addr=const(VEC_EXP_MAX_A_OFF, s.int64)) exp_max_b = pto.alloc_tile(red_ty, addr=const(VEC_EXP_MAX_B_OFF, s.int64)) @@ -577,7 +576,7 @@ def emit_softmax_step(exp_max_slot, is_init): addr=const(VEC_RECV_OFF, s.int64), ) tile.muls(qk_recv, scale, qk_recv) - tile.row_max(qk_recv, tmp_tile, local_max) + tile.row_max(qk_recv, p_fp32, local_max) local_max_r = tile.reshape(red_row_ty, local_max) new_global_max_r = tile.reshape(red_row_ty, new_global_max) @@ -625,29 +624,15 @@ def emit_gu_step(exp_max_slot, is_init): o_row_off = qb * cS0 # =================== Vec prologue: softmax(0..QK_PRELOAD-1) =================== - # softmax(0): is_init=True (writes exp_max_a, but exp_max_a for tile 0 - # is unused by gu(0) — gu(0) takes the init branch and just movs PV. - # Still we must compute it correctly; the init branch doesn't touch exp_max. emit_softmax_step(exp_max_a, is_init=True) - # softmax(1): is_init=False (writes exp_max_b) emit_softmax_step(exp_max_b, is_init=False) # =================== Vec steady state =================== - # Pair-unrolled: each `p` iteration handles tiles t_a = 2p, t_b = 2p+1. - # gu(t_a) reads exp_max_a (set by softmax(t_a) earlier); - # softmax(t_a+2) writes exp_max_a (matches parity). - # gu(t_b) reads exp_max_b; softmax(t_b+2) writes exp_max_b. - # CRITICAL: gu BEFORE softmax in same step to avoid clobbering. - # - # First pair (p=0, t_a=0, t_b=1) is Python-unrolled so we can - # take the `is_init=True` branch on gu(0) (which initializes - # o_tile via mov rather than rescale+add). - emit_gu_step(exp_max_a, is_init=True) # tile 0 - emit_softmax_step(exp_max_a, is_init=False) # tile 2 → exp_max_a - emit_gu_step(exp_max_b, is_init=False) # tile 1 - emit_softmax_step(exp_max_b, is_init=False) # tile 3 → exp_max_b - - # Remaining pairs (p=1..STEADY_PAIRS-1) inside a runtime loop. + emit_gu_step(exp_max_a, is_init=True) + emit_softmax_step(exp_max_a, is_init=False) + emit_gu_step(exp_max_b, is_init=False) + emit_softmax_step(exp_max_b, is_init=False) + for p in pto.range(c1, const(STEADY_PAIRS), c1): emit_gu_step(exp_max_a, is_init=False) emit_softmax_step(exp_max_a, is_init=False) diff --git a/examples/aot/flash_attention/known_gap.md b/examples/aot/flash_attention/known_gap.md new file mode 100644 index 00000000..4e60b885 --- /dev/null +++ b/examples/aot/flash_attention/known_gap.md @@ -0,0 +1,115 @@ +# Known gap: PTO Python DSL flash attention vs reference C++ + +This document compares the AOT flash-attention builders (`fa_builder.py`, `experimental/fa_builder.py`) and their `ptoas` output to the hand-written reference in `cpp_ref/naive_tpush/fa_kernel.cpp` (`runTFA` and helpers). It updates earlier notes on **macro parity** and **`--enable-insert-sync`**. + +## Revised summary + +### What the DSL already mirrors + +- High-level **software pipeline**: QK preload, steady-state interleaving of softmax (lookahead) with GU (current tile), and an **`exp_max` ping-pong ring** when `QK_PRELOAD == 2` (experimental) to match the reference’s “softmax ahead of GU” hazard story. +- **Multi-pipe** QK (cube→vec), P (vec→cube), PV (cube→vec) with GM-backed slots, analogous to the reference’s FIFO staging (different mechanism, same role). + +### Primary performance gaps (largest expected impact) + +1. **Cube tiling and S1 sub-tiling (QK vs PV)** + Reference: `CUBE_S0 = 128`, `CUBE_S1 = 128`, `TILE_S1 = 256`, **`kTileFactor = 2`** on **both** **`compute_qk`** and **`compute_pv`** with **`AccMode`** across subtiles. + DSL (experimental): **`S0 = 128`** (env `FA_S0`), **`FA_CUBE_S1`** default **128** — **QK** now runs **`kTileFactor`** matmuls into **`qk_acc`** column **`tile.subview`s** (matches ref **`compute_qk`** geometry). **PV** is still **one** full **`S1_TILE`**-wide matmul per step (no ref-style **`compute_pv`** K-striping yet). Remaining throughput gap vs fused ref is mostly **PV / buffer depth / preload**, not “missing QK K-split”. + +2. **Real L1 double-buffering** + Reference: `kMatTNBuffers = 2`, `pMatTNBuffers = 2`, `vMatTNBuffers = 2`, plus vec-side ping-pong (`srcVecTNBuffers`, `xexpVecTNBuffers`, `outOTileNBuffers`). + DSL: cube **single-buffered** tiles with **aliased** `[k_mat_s, k_mat_s]` / same for P/V; **`QK_LOCAL_SLOT_NUM = 1`** on the QK pipe because deeper local slots overflow vec UB. That limits overlap of **TLOAD** with **TMATMUL** compared to the reference. + +3. **Preload depth** + Reference launch uses **`QK_PRELOAD = 4`** (`fa_kernel.cpp`). Experimental DSL uses **`QK_PRELOAD = 2`**. Shallower preload reduces cube/vec overlap on long S1. A DSL port to 4 needs a **4-deep `exp_max` ring** and more VEC space; an attempt faulted until UB layout (recv scratch, `MAT_P_FIFO` tail) is redesigned. + +### Macro parity (item 5) — port to Python DSL, not “optional tuning” + +The reference’s hot path is not arbitrary `tile.*` soup; it goes through shared headers that should be **replicated in Python DSL** so lowering and scheduling stay aligned with the tuned C++ path: + +| Reference include | Role | +|-------------------|------| +| [`pto_macro_matmul.hpp`](../../../../pto-isa-master/kernels/manual/common/flash_atten/pto_macro_matmul.hpp) | Cube matmul with **`AccMode`** (e.g. `InitFinalSum` under `UF_ENABLE`), K tiling, and L0-oriented constraints. | +| [`pto_macro_fa_softmax.hpp`](../../../../pto-isa-master/kernels/manual/common/flash_atten/pto_macro_fa_softmax.hpp) | Streaming softmax: `softmax_opt_fa_init_impl` / `softmax_opt_fa_not_init_impl`, scale = `1/sqrt(HEAD)`, **TROWMAX**, **TROWEXPANDSUB**, **TEXP**, **TROWSUM**, and (inside the macro) the sequence that feeds **P**—including **TCVT** where the **macro** emits half for the V2C pipe. **QK** in the reference kernel is still **fp32 in GM** via **`TSTORE`** and **`TLOAD`** in `compute_qk` / `compute_p`; do not conflate macro-internal P conversion with a separate invented “fp16 QK wire”. | +| [`pto_macro_fa_gu.hpp`](../../../../pto-isa-master/kernels/manual/common/flash_atten/pto_macro_fa_gu.hpp) | **pto_macro_fa_gu** (`TROWEXPANDMUL` + `TADD`), **pto_macro_fa_gu_last** (+ `TROWEXPANDDIV` by `new_global_sum`), **pto_macro_fa_gu_single_and_last_tile**. | + +**Goal:** Express the same ordered primitive sequence (and the same init vs non-init / last-tile branching) in `ptodsl` `tile.*` / `pto.*` APIs—or add thin DSL helpers that document a 1:1 mapping to those macros—so the compiler stack can match the reference kernel’s numerics and fusion expectations. Today’s DSL uses composable `tile.row_max`, `tile.exp`, etc.; they must be **audited and aligned** macro-step by macro-step, not assumed equivalent. + +### `ptoas --enable-insert-sync` (item 6) + +**`--enable-insert-sync` is intentional:** it simplifies generated C++ by having the toolchain insert synchronization. Impact on performance is treated as **minor** relative to **tiling, real double-buffering, and preload depth**. + +Closing the gap should **not** rely on turning sync insertion off; it should rely on **geometry + buffering + macro-faithful lowering** (and any future DSL-level sync refinement if needed). + +### Secondary / structural differences + +- **GM / FFTS / CV:** Reference uses FFTS base, `TSync_Custom`, optional CV comm for many blocks; DSL uses `l2g2l_pipe` / `aic_initialize_pipe` / `aiv_initialize_pipe` and GM slot buffers. Functionally similar staging; details may diverge under high block counts. +- **Benchmark shape parity:** e.g. `experimental/run.py` uses `Q_ROWS` from the builder vs `naive_tpush/run.py` using `s0 = 128*24`; compare throughput with **matched** `Q_ROWS`, `HEAD`, `S1`, and tile counts when isolating kernel quality. + +--- + +## Progress log (experimental `fa_builder.py`, Apr 2026) + +Measured on NPU via `experimental/run.py` (Q=2048, H=128, S1_TILE=256): with QK **`kTileFactor=2`**, **~25 TFLOP/s** at S1=8192 (Apr 2026; run-to-run variance ±~1 TFLOP/s) vs **~61 TFLOP/s** for `torch_npu` fused ref on the same script; correctness (`assert_close`) remains the gate. + +| Change attempted | Result | +|------------------|--------| +| `QK_PRELOAD=4` + four `exp_max` slots + quad-unrolled vec/cube + true dual MAT banks for K/P/V | AICore CCU address fault (`mte`/`ccu`); likely VEC `tpop` scratch / expanded red region + dual `RIGHT` typing; **reverted**. | +| True L1 ping-pong (separate `MAT_K0`/`MAT_K1`, …) without preload-4 | Overlapped `MAT_P_FIFO` with tail tiles until `MAT_P_FIFO_OFF` was recomputed; still faulted with dual `RIGHT` `alloc_tile` until fully reverted to aliased single-buffer layout. | +| QK K-split: two `CUBE_S1=128` matmuls per tile via `tile.subview` on `qk_acc` | **Landed** for ref **`compute_qk`** parity; `tile.subview` **`sizes`** must be Python **`int`**, not `s.const` (MLIR `I64ArrayAttr`). **PV** ref-style **`kTileFactor`** is **blocked**: **`tmatmul` lhs must be LEFT**, but LEFT P cannot column-**`subview`**; MAT P can **`subview`** but cannot be **`tmatmul` lhs** (see `ptoas_request.md`). | +| Reorder steady cube step (PV before K load) | **Slight regression** vs original order on sampled runs — **reverted**. | + +**Takeaway:** With **`S0=128`** and **QK `kTileFactor`** aligned to the reference, the largest remaining structural gaps versus the reference are **`compute_pv` K-split / `AccMode`**, **preload / ring depth (`QK_PRELOAD`, CV FIFO)**, **true L1 ping-pong**, and **vec working-tile geometry (`Vec_S0`)** relative to what `l2g2l_pipe` + `TILE_UP_DOWN` imply for UB. Raising `QK_PRELOAD` still needs more `exp_max` slots and recv/GM budget. + +| `S0=128` + QK `kTileFactor` (Apr 2026) | Default `S0` **128** (`FA_S0`), **`FA_CUBE_S1`** default **128**, QK via **`tile.subview`** acc columns. Cube **`pto.tpush`** uses **`ACC`** into `l2g2l_pipe`; ref uses **`TSTORE`** to fp32 GM—**`PIPE_UNASSIGNED` for MAT/LEFT `tpush` is expected**. **`compile.sh`** + **`run.py`** pass on NPU; **~25 TFLOP/s** at 8k vs fused ref **~61 TFLOP/s**. | + +--- + +## PTOAS / PTO dialect / Python binding — reasonable asks (see `ptoas_request.md`) + +Feature requests must **mirror what `fa_kernel.cpp` already does**, not invent paths the reference does not use (e.g. **no** MAT/LEFT→GM for QK, **no** extra **`tile.cvt`** pipeline on QK beyond what the **macros** imply for **P**). FP32 QK in the ref is **`TSTORE`/`TLOAD`** to/from **fp32 GM**; **`ptoas --enable-insert-sync`** remains the **intentional** allowed divergence from hand-placed `TSync_Custom`. + +A maintained list of **documentation-first** and **parity** asks (GM packing vs `l2g2l_pipe`, `kTileFactor`/`Vec_S0`, sync equivalence, ergonomics) lives in **`examples/aot/flash_attention/ptoas_request.md`**. + +--- + +## TODO: close the gap (Python DSL ↔ reference C++) + +Use this as a work backlog; order roughly reflects suggested priority (tiling/buffers first, then macro fidelity, then integration). + +### Tiling and cube schedule + +- [x] **QK: match reference `kTileFactor` loop** — two **`Cube_S1`** K slices per logical tile into **`qk_acc`** (`experimental/fa_builder.py`, env **`FA_CUBE_S1`**). +- [ ] **PV: match reference `compute_pv`** — **`kTileFactor`** partial matmuls / **`AccMode`**. Blocked in Python today: **`tmatmul` requires LEFT lhs**; LEFT P rejects column **`tile.subview`**; MAT P allows **`subview`** but is not a legal **`tmatmul` lhs** (see **`ptoas_request.md`** §2). +- [x] **Match reference `CUBE_S0` (128) in experimental builder** — default `S0=128` via `FA_S0` (Apr 2026); UB layout was tightened (shared `tpop` recv sizing, `row_max` scratch reuse, smaller `VEC_RED_STRIDE`). Smaller blocks remain available with `FA_S0=32` etc. if needed. +- [ ] **Align `QK_PRELOAD`** with the reference launch (**4**) and extend the **`exp_max` / GU ring** logic (or equivalent hazard avoidance) for that depth; assert fifo and UB sizing. + +### Double-buffering and overlap + +- [ ] **Implement true L1 ping-pong** for **K**, **P**, and **V** cube tiles (separate physical buffers, not aliased `[x, x]`). +- [ ] **Vec UB layout:** budget space for **`QK_LOCAL_SLOT_NUM > 1`** if required to mirror reference QK pipe depth, without exceeding UB limits; coordinate with `SLOT_NUM` / GM stride math. +- [ ] **Vec tile banks:** mirror **`srcVecTNBuffers=2`**, **`xexpVecTNBuffers=2`**, **`outOTileNBuffers=2`** (or document why a smaller depth is equivalent). + +### Macro parity in Python (lowering contract) + +- [ ] **Matmul:** Port or wrap **`pto_macro_matmul`** semantics in DSL—especially **`AccMode`** (`Init` / `InitFinalSum` / partial vs final slices) and the **K-subslice** interaction with double-buffered K tiles. +- [ ] **Softmax:** Port **`softmax_opt_fa_init_impl`** and **`softmax_opt_fa_not_init_impl`** (and causal paths if needed) as an explicit sequence of DSL ops matching the **macro** ordering in `pto_macro_fa_softmax.hpp` (including whatever the macro uses to produce **half** for **P**—e.g. **`TCVT`** inside the macro—not a separate invented QK cast). +- [ ] **GU:** Port **`pto_macro_fa_gu`**, **`pto_macro_fa_gu_last`**, and **`pto_macro_fa_gu_single_and_last_tile`** as DSL sequences matching **TROWEXPANDMUL / TADD / TROWEXPANDDIV** usage in `pto_macro_fa_gu.hpp`. +- [ ] **Numerics test:** Keep **`torch.testing.assert_close`** (same `rtol`/`atol` as `run.py` / `experimental/run.py`) as the gate after each macro block port. + +### Toolchain and integration + +- [ ] **Keep `--enable-insert-sync`** in `compile.sh` / `experimental/compile.sh`; optimize kernel structure first; only revisit sync policy if profiling shows it dominates after tiling/buffer parity. +- [ ] **Optional:** Parity for **FFTS / CV / `TSync_*`** paths if multi-block or multi-core scaling diverges from reference after cube/vec parity work. +- [ ] **Docs:** When a builder variant reaches parity, record fixed constants (`CUBE_S0`, `CUBE_S1`, `TILE_S1`, `QK_PRELOAD`, buffer counts) next to `jit_util_flash.py` / `fa_kernel.cpp` launch parameters so drift is obvious in review. + +--- + +## File map + +| Artifact | Path | +|----------|------| +| Reference kernel | `cpp_ref/naive_tpush/fa_kernel.cpp` | +| JIT constants | `cpp_ref/naive_tpush/jit_util_flash.py` | +| Non-experimental builder | `fa_builder.py` | +| Experimental builder | `experimental/fa_builder.py` | +| PTO FA macros (ISA tree) | `pto-isa-master/kernels/manual/common/flash_atten/pto_macro_*.hpp` | diff --git a/examples/aot/flash_attention/ptoas_request.md b/examples/aot/flash_attention/ptoas_request.md new file mode 100644 index 00000000..d19dbc52 --- /dev/null +++ b/examples/aot/flash_attention/ptoas_request.md @@ -0,0 +1,407 @@ +# PTOAS feature requests (PTO MLIR dialect + Python bindings) + +This document lists **reasonable** asks for PTOAS / the PTO dialect so Python FA (`experimental/fa_builder.py` via `ptodsl`) can **mirror** `examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp` (`runTFA`, `compute_qk`, `compute_p`, `compute_pv`, `compute_gu`). + +**Ground rules (read first)** + +1. **No invented algorithms.** PTOAS should support a **1:1** mapping to the reference’s data path and control flow, not new “more clever” patterns the C++ kernel does not use. +2. **QK path in the reference is not `TCvt` on tiles then `TPUSH` from MAT/LEFT.** Cube writes QK to GM with **`TSTORE`** from the **accumulator** (`compute_qk`); vec reads with **`TLOAD`** into vec UB (`compute_p`). There is **no** MAT/LEFT→GM push for QK. Where fp16 appears for **P**, the reference uses the **V2C `TPipe`** / macro path (e.g. `sizeof(half)` slot size in `runTFA`), not a fabricated “fp16 QK wire”. +3. **`TPushOp` accepts ACC (and VEC for the other direction) by design** (`include/PTO/IR/PTOOps.td`). **`PIPE_UNASSIGNED` for MAT/LEFT as `tpush` sources is expected**: there is no direct MAT/LEFT→global path like `TSTORE`. That is **not** a bug to “fix” by widening `tpush` to MAT/LEFT for FA—doing so would **diverge** from the reference. +4. **Allowed toolchain divergence** from hand-written C++: **`ptoas --enable-insert-sync`** (and similar passes) that insert synchronization; everything else should aim at **reference parity**, not new semantics. + +Upstream **PTOAS** paths below are **relative to the `PTOAS/` repository root**. + +--- + +## 1. Documentation: reference **QK** path ↔ **`l2g2l_pipe` + ACC `tpush` / `tpop`** + +**What the reference does.** Cube **`TSTORE`** of **`float`** tiles shaped **`Cube_S0 × Cube_S1`** into `qk_tile_fifo` with the `base_elems` packing (`fa_kernel.cpp`, `compute_qk`). Vec **`TLOAD`**s **`Vec_S0 × Cube_S1`** slices and assembles **`Vec_S0 × Tile_S1`** in UB (`compute_p`, loop over `sub_col`). + +**What the Python builder does today.** Uses **`pto.tpush(qk_acc, qk_pipe)`** with **`slot_size = S0 * S1_TILE * sizeof(fp32)`** and GM backing—**analogous** to getting QK to GM + vec visibility, but **not** identical to the reference’s per-`sub_tile_id` **`Cube_S0×Cube_S1`** store layout when `kTileFactor > 1`. + +**Ask (PTOAS / docs only).** + +- In `docs/designs/ptoas-tpush-tpop-design.md` (or a small FA note), add a **side-by-side**: reference **`TSTORE`/`TLOAD` + `base_elems`** vs recommended **`slot_size` / `slot_num` / GM pointer math** for `initialize_l2g2l_pipe` so Python authors can **match the reference packing** without guessing. +- Clarify explicitly: **cube `tpush` from ACC** is the supported producer; **do not** document MAT/LEFT `tpush` to GM as a FA workaround—that path **does not exist** in the reference. + +--- + +## 2. **`kTileFactor` / `Cube_S1` K-split and vec `Vec_S0` (reference-only geometry)** + +**What the reference does.** `kTileFactor = Tile_S1 / Cube_S1`; two matmul passes per logical tile; vec softmax tile **`Vec_S0 × Tile_S1`** with **`Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor`** (`runTFA`); **`compute_p`** loads **`kTileFactor`** column strips from GM. + +**What the Python builder does today (experimental).** **`compute_qk`**: **`kTileFactor`** passes (**`HEAD × Cube_S1`** K slices, **`tile.subview`** on **`qk_acc`** for **`S0 × Cube_S1`** acc columns), then one full **`S0 × S1_TILE`** `tpush`. **`compute_pv`**: still a **single** **`p_left × v_right → pv_acc`** matmul per tile (no K-split / **`AccMode`** striping like the reference yet). Vec **`tpop`** **`S0_HALF × S1_TILE`** with **`TILE_UP_DOWN`**. + +**Ask.** + +- **Documented** recipes (and, if useful, **ptodsl** helpers only—no new PTO ops required) for: cube **`AccMode`/`InitPartialSum`/`AccPartialSum`**-style sequences matching **`compute_pv`** (and any edge cases in **`compute_qk`**), and vec **`pto.load`** / **`slice_view`** patterns matching **`TLOAD`** + **`TASSIGN`** column offsets in `compute_p`. +- **`tile.subview`** Python API: **`sizes`** must be plain **`int`** (they are forwarded to MLIR `I64ArrayAttr`); **`const()`** wrappers currently fail at build time—document or unwrap in **`tile.subview`**. +- **`compute_pv` K-split in Python:** ref uses **`kTileFactor`** partial matmuls with **P** column strips. **`pto.tmatmul`** requires **`lhs` in LEFT**, but **LEFT** `tile.subview` on default boxed RowMajor **P** rejects column strips (`boxed RowMajor subview must keep full cols`). **MAT** **`p_recv`** allows the same column **`subview`** pattern as **`qk_acc`**, but **MAT cannot be `tmatmul` lhs** today. Reasonable ask: either **document** a supported recipe (e.g. **`TMOV`** staging + layout) or **relax verifier** / **extend `tmatmul`** only where it matches the reference matmul macro contract—not an invented new algorithm. +- If something in **MLIR verification** blocks a **literal** ref-shaped **`pto.store`**/`load` schedule that is otherwise valid, file that as a **narrow bugfix** with a ref citation—not a new feature. + +--- + +## 3. **Software row / subblock indexing (`row_slice`, `get_subblockid`)** + +**What the reference does.** `row_offset = subblock_base_rows + row_slice * Vec_S0` and reduce-tile **`TASSIGN`** byte offsets (`compute_p`)—**software** decomposition, not a request for new hardware **`split`** enum values unless the ISA already exposes them for the same pattern. + +**Ask.** + +- **Examples in docs** showing how to express the same indexing with **existing** Python/PTO constructs (`get_subblock_idx`, scalar offsets, multiple `load`s), so authors do not conflate **`TILE_UP_DOWN`** alone with the reference’s **`row_slice × kTileFactor`** schedule. + +--- + +## 4. **Sync: `TSync_Custom`, CV FIFO depth, `QK_PRELOAD` (exists in reference)** + +**What the reference does.** `TSync_Custom<…>` around QK produce / vec consume (`qk2smSync`); `should_wait_consumption` / `should_notify_consumption`; template **`QK_PRELOAD`**, **`qkp_tile_fifo_size`**, **`CV_FIFO_CONS_SYNC_PERIOD`** (`runTFA`, `compute_p`, `compute_qk`). + +**What the Python stack does today.** **`--enable-insert-sync`** plus pipe **`tfree`** / implicit ordering—**intentionally** different from hand-placed `TSync_Custom`. + +**Ask (reasonable).** + +- **Optional** dialect or pass hooks that lower to the **same** sync primitives the reference uses, **or** a documented **equivalence table**: “ref `qk2smSync.wait()` ↔ inserted barrier X after `ptoas` version Y”. +- This is **not** a license to invent new memory paths; it is **parity** for **control** the reference already has. + +--- + +## 5. **Python / `ptodsl` ergonomics (optional; ref-shaped constants)** + +**Problem.** Layout math (`GM_*_OFF_F32`, `MAT_*_OFF`, vec FIFO bytes) is hand-rolled and easy to break when `Cube_S0` / `Tile_S1` / `HEAD` change—**the reference avoids some of this** with template parameters and allocator helpers. + +**Ask.** + +- **Optional** helpers or tables driven only by **`runTFA`-style constants** (`CUBE_S0`, `CUBE_S1`, `TILE_S1`, `QK_PRELOAD`, FIFO sizes) to generate **GM strides matching `base_elems`**, **MAT/VEC base offsets**, and **static overlap checks**—without introducing new runtime algorithms. + +--- + +## 6. **Documentation cross-link (macros)** + +**Ask.** In PTOAS or `ptodsl` docs, link **`pto_macro_matmul` / `pto_macro_fa_softmax` / `pto_macro_fa_gu`** sequences to the **minimal** `tile.*` / `pto.*` sequences needed for lowering parity—**macro-internal** ops (including whatever the macro uses for P) are **in scope**; **invented** pre-pipeline **`tile.cvt`** on QK to fake a dtype the reference does not use on that path is **out of scope**. + +--- + +## Concrete examples (reference ↔ Python today) + +These snippets are for **PTOAS / `ptodsl` authors** mapping hand-written C++ (`fa_kernel.cpp`) to Python builders. Line citations use this repo’s paths. + +### A. QK: **`TSTORE`/`TLOAD`** (ref) vs **`tpush`/`tpop`** + ACC (Python) + +**Reference — cube writes one `Cube_S0 × Cube_S1` fp32 strip per `sub_tile_id` to GM** (`compute_qk`): + +```364:376:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + using GlobalDataQK = + GlobalTensor, pto::Stride<1, 1, 1, Cube_S1, 1>>; + const uint32_t buf_idx = static_cast(tile_id % QKP_CV_FIFO); + const size_t base_elems = + static_cast(buf_idx) * static_cast(kTileFactor) * static_cast(Cube_S0) * + static_cast(Cube_S1) + + static_cast(sub_tile_id) * static_cast(Cube_S0) * static_cast(Cube_S1); + GlobalDataQK qkGlobalTile(qk_tile_fifo + base_elems); + TSTORE(qkGlobalTile, qkAccTile); +``` + +**Reference — vec reads `kTileFactor` GM strips and packs them into one wide `qkVecTile`** (`compute_p`): + +```505:517:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + for (int sub_col = 0; sub_col < static_cast(kTileFactor); ++sub_col) { + __gm__ float *qk_ptr_sub = + qk_ptr + static_cast(sub_col) * static_cast(Cube_S0) * static_cast(Cube_S1); + GlobalDataQK_Sub qkGlobalSub(qk_ptr_sub); + + TileDataF_Sub qkVecSub; + const uint64_t col_byte_offset = static_cast(sub_col * Cube_S1 * sizeof(float)); + TASSIGN(qkVecSub, (uint64_t)qkVecTile.data() + col_byte_offset); + TLOAD(qkVecSub, qkGlobalSub); + } +``` + +**Python builder today — cube:** `kTileFactor` **inner** matmuls match the ref’s **K** tiling, but the **GM path** is still one **`ACC`** tile **`tpush`** per logical tile (full `S0 × S1_TILE` fp32), not **`kTileFactor`** separate **`TSTORE`** slots. + +`accumulate_qk_for_tile` (K strips → **`qk_acc`** column subviews): + +```359:376:examples/aot/flash_attention/experimental/fa_builder.py + def accumulate_qk_for_tile(k_s1_offset, qk_buf, k_mat_buf_idx): + for sc in range(K_TILE_FACTOR): + k_col = k_s1_offset + const(sc * CUBE_S1) + kt_view = pto.slice_view( + kt_sub_slice_ty, + source=tv_k, + offsets=[c0, k_col], + sizes=[cHEAD, cCUBE_S1], + ) + pto.load(kt_view, k_mat[k_mat_buf_idx]) + tile.mov(k_mat[k_mat_buf_idx], k_right[k_mat_buf_idx]) + qk_part = tile.subview( + qk_acc[qk_buf], + [c0, const(sc * CUBE_S1)], + [S0, CUBE_S1], + ) + tile.matmul(q_left, k_right[k_mat_buf_idx], qk_part) +``` + +Prologue: one **`tpush`** per logical tile after **`accumulate_qk_for_tile`**: + +```387:391:examples/aot/flash_attention/experimental/fa_builder.py + for k in range(QK_PRELOAD): + k_s1_off = const(k * S1_TILE) + accumulate_qk_for_tile(k_s1_off, k, k % 2) + pto.tpush(qk_acc[k], qk_pipe, SPLIT_UP_DOWN) +``` + +**Python builder today — vec:** one **`tpop`** of a **half-tile** `qk_vec_ty` (`S0_HALF × S1_TILE` fp32) per softmax step, not **`Vec_S0 × Cube_S1`** repeated **`TLOAD`**s: + +```571:605:examples/aot/flash_attention/experimental/fa_builder.py + def emit_softmax_step(exp_max_slot, is_init): + qk_recv = pto.tpop( + qk_vec_ty, + qk_pipe, + SPLIT_UP_DOWN, + addr=const(VEC_RECV_OFF, s.int64), + ) + tile.muls(qk_recv, scale, qk_recv) + tile.row_max(qk_recv, p_fp32, local_max) + # ... softmax body ... + tile.cvt(p_fp32, p_fp16) + pto.tpush_to_aic(p_fp16, SPLIT_UP_DOWN, id=ID_P) + pto.tfree(qk_pipe, SPLIT_UP_DOWN) +``` + +**Still missing vs ref (PTOAS / docs ask).** Either **document** how **`slot_size` / GM offsets** should reproduce **`base_elems` + `sub_tile_id`** packing, or show **`pto.store`/`pto.load`** that mirror **`TSTORE`/`TLOAD`** exactly. **Do not** document **`tile.cvt` on QK + `tpush` from MAT/LEFT`**; that is not the reference QK path. + +### A2. Schedule: **`runTFA`’s `sub_tile` loop** (C++) vs **fused Python steps** (same semantics, different shape) + +The reference **interleaves** cube **`compute_qk`** and vec **`compute_p`** at **`kTileFactor`** granularity inside the steady loop: + +```848:920:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + for (int preload_tile = 0; preload_tile < static_cast(qkPreloadNum) && preload_tile < num_tiles_s1; + ++preload_tile) { + if constexpr (DAV_CUBE) { + for (int sub_tile = 0; sub_tile < static_cast(kTileFactor); ++sub_tile) { + compute_qk<...>(preload_tile, sub_tile, ...); + } + } + if constexpr (DAV_VEC) { + for (int row_slice = 0; row_slice < static_cast(kTileFactor); ++row_slice) { + compute_p<...>(preload_tile, row_slice, ...); + } + } + } + + for (int tile_id = 0; tile_id < num_tiles_s1; ++tile_id) { + // ... + for (int sub_tile = 0; sub_tile < static_cast(kTileFactor); ++sub_tile) { + if constexpr (DAV_CUBE) { + if (next_qk_tile != -1) { + compute_qk<...>(next_qk_tile, sub_tile, ...); + } + } + if constexpr (DAV_VEC) { + if (next_qk_tile != -1) { + compute_p<...>(next_qk_tile, sub_tile, ...); + } + } + if constexpr (DAV_CUBE) { + compute_pv<...>(tile_id, sub_tile, ...); + } + } + if constexpr (DAV_VEC) { + compute_gu<...>(tile_id, ...); + } + } +``` + +**Python** (`experimental/fa_builder.py`) **fuses** all **`kTileFactor`** QK matmuls **inside** `accumulate_qk_for_tile` before **`tpush`**, and vec **`emit_softmax_step`** does **not** take explicit **`row_slice`** or **`sub_col`** arguments—**`TILE_UP_DOWN`** replaces part of the ref’s **`row_slice × Vec_S0`** story. A **PTOAS-facing doc** should spell out: “**`row_slice` loop** ↔ **`get_subblock_idx()` + fixed `S0_HALF`**” and “**`sub_col` GM loads** ↔ **either** replicated **`slice_view`** **or** one wide **`tpop`**,” so reviewers do not assume bit-identical control flow. + +### B. P: **V2C `TPipe` with `sizeof(half)`** (ref) vs vec **`tile.cvt`** + **`tpush_to_aic`** (Python) + +**Reference — P FIFO slot is fp16-sized, cube pops into MAT** (`fa_kernel.cpp`): + +```820:823:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + using PPipe = + TPipe; +``` + +**Python — same logical edge (vec → cube), different primitive names:** + +```603:604:examples/aot/flash_attention/experimental/fa_builder.py + tile.cvt(p_fp32, p_fp16) + pto.tpush_to_aic(p_fp16, SPLIT_UP_DOWN, id=ID_P) +``` + +**Cube consumer** still **`tpop`**s into **`p_recv_ty`** (MAT), then **`mov`** to **`p_left`** for **`matmul`**—analogous to **`TPOP`** into **`pMatTile`**. The **ask for PTOAS** is macro-order parity (**`pto_macro_fa_softmax`**) and **slot byte size** documentation, not new QK **`TCvt`** ideas. + +### C. **`PIPE_UNASSIGNED` on non-ACC `tpush`** + +**What fails today (do not suggest as FA “fix”):** + +```python +# Hypothetical / INVALID for QK in current PTO lowering: +pto.tpush(some_left_tile, qk_pipe, split=1) # LEFT → pipe: not a supported QK producer +pto.tpush(some_mat_tile, qk_pipe, split=1) # MAT → pipe: verifier / PIPE_UNASSIGNED +``` + +**What the Python FA builder uses instead (supported):** + +```python +pto.tpush(qk_acc_tile, qk_pipe, SPLIT_UP_DOWN) # ACC → l2g2l pipe (supported producer) +``` + +That matches the **intent** of ref **`TSTORE`** from **accumulator** data to a **visibility** buffer, even though the **packing** still differs from per-**`sub_tile_id`** **`TSTORE`** (§A). + +### D. PV: **`AccMode` + `kTileFactor`** in C++ vs **single `matmul`** in Python (verifier gap) + +**Reference — outer caller invokes `compute_pv` once per `sub_tile`** (`runTFA`): + +```912:918:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + if constexpr (DAV_CUBE) { + compute_pv( + tile_id, sub_tile, v, pv_tile_fifo_block, pMatTile[pv_src_pingpong_id % pMatTNBuffers], + vMatTile[pv_src_pingpong_id % vMatTNBuffers], pvAccTile, + pv_src_pingpong_id % vMatTNBuffers + PV_EVENT_ID0, pvAccTileEvtID, pPipe, pv2guSync); + pv_src_pingpong_id++; + } +``` + +**Reference — inner `AccMode` chooses init vs accumulate across K strips** (`compute_pv`): + +```400:433:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + const int s1_index = tile_id * static_cast(Tile_S1) + sub_tile_id * static_cast(Cube_S1); + // ... + GlobalVT vLoad((__gm__ half *)(v + s1_index * HEAD_SIZE)); + TLOAD(vMatTile, vLoad); + + TPOP(pPipe, pMatTile); + + const AccMode accMode = (sub_tile_id == 0) ? + (is_last_subtile ? AccMode::InitFinalSum : AccMode::InitPartialSum) : + (is_last_subtile ? AccMode::AccFinalSum : AccMode::AccPartialSum); + pto_macro_matmul(pMatTile, vMatTile, pvAccTile, accMode); +``` + +**What a literal Python port would look like** (conceptual; **not** all verifiable today): + +```python +# Desired shape: P is MAT, V is MAT, pv is ACC — ref uses pMatTile strips × v strips. +for sc in range(K_TILE_FACTOR): + p_sub = tile.subview(p_mat, [0, sc * CUBE_S1], [S0, CUBE_S1]) # column strip of P on MAT + v_sub = tile.subview(v_right, [sc * CUBE_S1, 0], [CUBE_S1, HEAD]) # row strip of V on RIGHT + if sc == 0: + tile.matmul(p_sub, v_sub, pv_acc) + else: + tile.matmul_acc(pv_acc, p_sub, v_sub, pv_acc) +``` + +**What actually blocks this in `ptodsl` today:** + +- **`tile.subview(p_left, …, [S0, CUBE_S1])`** (column strip on default boxed **LEFT** RowMajor **P**) → MLIR: **`boxed RowMajor subview must keep full cols`**. +- **`tile.matmul(p_sub_mat, v_sub, …)`** with **`p_sub_mat`** on **MAT** → MLIR: **`tmatmul` expects lhs in LEFT`**. + +So the **missing PTOAS / dialect / doc** piece is a **supported** way to express **ref `pto_macro_matmul` + `AccMode`** for **PV**—either **documented staging** (**`TMOV`** MAT→LEFT strips with a legal layout, or **`tmatmul`** generalization **only** where it matches the macro contract), not a new FA algorithm. + +**Python steady-state PV today** (one full matmul, ref-equivalent **math**, different **micro-schedule**): + +```408:423:examples/aot/flash_attention/experimental/fa_builder.py + p_raw = pto.tpop_from_aiv(p_recv_ty, SPLIT_UP_DOWN, id=ID_P) + tile.mov(p_raw, p_left[b]) + pto.tfree_from_aiv(SPLIT_UP_DOWN, id=ID_P) + tile.mov(v_mat[b], v_right[b]) + # ... prefetch next V into v_mat[1 - b] ... + tile.matmul(p_left[b], v_right[b], pv_acc[b]) + pto.tpush(pv_acc[b], pv_pipe, SPLIT_UP_DOWN) +``` + +### E. **`tile.subview` sizes: Python `int` vs `s.const` (MLIR `I64ArrayAttr`)** + +**Fails at Python build time** (`TypeError` from `IntegerAttr.get`): + +```python +cS0 = const(S0) +cCUBE_S1 = const(CUBE_S1) +qk_part = tile.subview(qk_acc, [c0, const(sc * CUBE_S1)], [cS0, cCUBE_S1]) # BAD: sizes are Value wrappers +``` + +**Works** (sizes are plain integers; offsets can stay dynamic via **`const`** / scalars as today): + +```python +qk_part = tile.subview(qk_acc, [c0, const(sc * CUBE_S1)], [S0, CUBE_S1]) +``` + +**Ask:** either **unwrap** **`const`** in **`ptodsl.api.tile.subview`** for **`sizes`**, or **document** this in PTOAS / `ptodsl` API reference so FA-style builders do not rediscover it via traceback. + +### F. Sync: **`TSync_Custom` + `wait`/`allocate`/`record`/`free`** (C++) vs **pipes + `tfree`** (Python) + +**Reference — explicit producer/consumer pairing on the QK path** (`compute_qk` / `compute_p`): + +```361:381:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + if (sub_tile_id == 0 && should_wait_consume) + qk2smSync.allocate(); // wait for SM consume data + // ... + TSTORE(qkGlobalTile, qkAccTile); + // ... + if (sub_tile_id == static_cast(kTileFactor) - 1) + qk2smSync.record(); // notify for QK produce data +``` + +```496:520:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + wait_flag(PIPE_V, PIPE_MTE2, pTileEventId); + if (row_slice == 0) + qk2smSync.wait(); // wait for QK produce data + // ... TLOAD kTileFactor strips ... + if (row_slice == static_cast(kTileFactor) - 1 && should_notify_consume) + qk2smSync.free(); // notify for SM consume data +``` + +**Sync object type** (template parameter baked into `runTFA`): + +```817:817:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + constexpr TSync_Custom qk2smSync = {BUF0_QK_READY}; +``` + +**Python — no named `TSync_Custom`; ordering relies on pipe ops + toolchain-inserted sync** (`--enable-insert-sync` in `experimental/compile.sh`): + +```279:281:examples/aot/flash_attention/experimental/fa_builder.py + qk_pipe = pto.initialize_l2g2l_pipe( + dir_mask=1, + slot_size=SLOT_SIZE_QK, +``` + +**Ask for PTOAS / docs:** a small **equivalence table** row per ref hook, e.g. **`qk2smSync.record()` after last `TSTORE` of a tile** ↔ **which `tpush` / GM completion barrier** in generated C++ after `ptoas` version *X*, so performance regressions can be bisected without reading all inserted barriers. + +### G. Row / column indexing: **`row_slice` + `Vec_S0`** (C++) vs **`get_subblock_idx` + `S0_HALF`** (Python) + +**Reference — software row origin per vec core and `row_slice`** (`compute_p`): + +```488:492:examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp + const size_t subblock_base_rows = + static_cast(Cube_S0 / VEC_CORES) * static_cast(get_subblockid()); + const size_t row_offset = subblock_base_rows + static_cast(row_slice * Vec_S0); + const int s0_index = blk_idx * Cube_S0 + row_offset; +``` + +**Python — subblock row origin** (half of **`Cube_S0`** per AIV sub-block for **`TILE_UP_DOWN`**): + +```540:541:examples/aot/flash_attention/experimental/fa_builder.py + sb_idx = s.index_cast(pto.get_subblock_idx()) + row_off_sb = sb_idx * cS0_HALF +``` + +There is **no** Python **`for row_slice in range(kTileFactor):`** around softmax; **`row_slice × Vec_S0`** from the ref is only partially reflected via **`S0_HALF`** recv tiles + pipe **`split`**. The **missing documentation** is a **direct mapping table** (`row_slice`, `Vec_S0`, `kTileFactor`) ↔ (`TILE_UP_DOWN`, `S0_HALF`, `pto.tpop` tile types), so authors do not treat **`split=1`** as a complete substitute for **ref softmax tile decomposition** without proof. + +--- + +## Priority (suggested) + +| Priority | Item | Rationale | +|----------|------|-----------| +| P0 | **§1–2** (docs + ref-shaped K/`Vec_S0` lowering recipes) | Largest **semantic** gap vs ref **without** inventing ops. | +| P1 | **§4** (sync parity / equivalence vs `TSync_Custom`) | Exists in ref; **`--enable-insert-sync`** is the only deliberate deviation today. | +| P2 | **§3, §5–6** (indexing examples, ergonomics, macro cross-links) | Reduces author error; no new hardware paths. | + +--- + +## Related artifacts in this repo + +- Experimental FA builder: `examples/aot/flash_attention/experimental/fa_builder.py` +- Reference C++ kernel: `examples/aot/flash_attention/cpp_ref/naive_tpush/fa_kernel.cpp` +- Broader gap narrative: `examples/aot/flash_attention/known_gap.md`