Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 87 additions & 102 deletions examples/aot/flash_attention/experimental/fa_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,22 @@
#
# This file mirrors the reference C++ scheduler:
#
# constexpr int qkPreloadNum = 2; // warmup depth
# constexpr int qkPreloadNum = 2; // warmup depth (reference uses 4; UB-limited here)
#
# /* Prologue: cube emits QK[0..QK_PRELOAD-1]; vec consumes them and
# pushes P[0..QK_PRELOAD-1]. No PV / gu yet. */
# Cube K tiling matches ref `Cube_S1` / `kTileFactor`: two matmuls per logical
# `S1_TILE` tile into `qk_acc` column subviews; K MAT uses ping-pong buffers
# (`kMatTNBuffers = 2` in ref).
#
# PV stays one `p_left × v_right` matmul: `tmatmul` requires **LEFT** lhs, but
# boxed RowMajor **LEFT** forbids P column `tile.subview` (`keep full cols`);
# MAT P allows column subviews yet cannot be `tmatmul` lhs — ref `compute_pv`
# K-strip schedule needs a future layout/op story in PTO DSL.
#
# /* Prologue: cube emits QK[0..QK_PRELOAD-1]; vec softmaxes them. */
#
# /* Steady state, tile_id 0..N-1:
# cube: if (t+QK_PRELOAD < N) compute_qk(t+QK_PRELOAD);
# compute_pv(tile_id);
# vec: if (t+QK_PRELOAD < N) compute_p(t+QK_PRELOAD);
# compute_gu(tile_id);
# so vec's softmax for the LOOK-AHEAD tile fills the QK consumption
# slot WHILE the cube is computing the current PV[t]. The cube
# stops being blocked on a freshly-pushed P (softmax of t+2 has
# already pushed P[t+2] into the FIFO by the time cube needs it). */
# cube: compute QK[t+QK_PRELOAD] while draining PV[t];
# vec: interleave gu(t) with softmax(t+QK_PRELOAD) using exp_max a/b. */
#
# /* Epilogue: drain the last QK_PRELOAD tiles' PV / gu. */
#
Expand All @@ -27,6 +29,8 @@
# need a 2-deep ring of `exp_max` tiles (`exp_max_a`, `exp_max_b`). We
# implement the ring by unrolling the steady-state loop in pairs of 2
# iterations: even iters use `exp_max_a`, odd iters use `exp_max_b`.
# (Raising preload to 4 like the C++ launch requires a 4-slot ring and more
# VEC UB headroom; a prototype hit CCU address faults until layout is reworked.)
#
# Other state in the softmax (`new_global_max`, `new_global_sum`) does
# NOT need a ring: it is monotonic accumulator state across all tiles
Expand All @@ -49,10 +53,17 @@
# ---------------------------------------------------------------------------
# Static shapes (must match run.py constants)
# ---------------------------------------------------------------------------
S0 = 32 # Q rows per block
S0_HALF = S0 // 2 # rows per AIV sub-block
# Match reference `CUBE_S0 = 128` (`fa_kernel.cpp`). Vec UB grows with
# `S0` because the QK/PV pipe mirrors full tiles in UB; override with FA_S0
# if a smaller block is needed while tooling catches up (see `known_gap.md`).
S0 = int(os.environ.get("FA_S0", "128"))
S0_HALF = S0 // 2 # rows per AIV sub-block (TILE_UP_DOWN split)
HEAD = 128 # attention head dimension
S1_TILE = 256 # K/V columns per tile
S1_TILE = 256 # K/V columns per tile (logical `Tile_S1`; ref `fa_kernel.cpp`)
# Reference: `Cube_S1` K micro-tile and `kTileFactor = Tile_S1 / Cube_S1` matmul passes.
CUBE_S1 = int(os.environ.get("FA_CUBE_S1", "128"))
assert S1_TILE % CUBE_S1 == 0, "S1_TILE must be divisible by FA_CUBE_S1"
K_TILE_FACTOR = S1_TILE // CUBE_S1
# NUM_TILES is overridable via the FA_NUM_TILES env var so the same builder
# can produce kernels for different sequence lengths
# (S1_TOTAL = S1_TILE * NUM_TILES).
Expand All @@ -62,21 +73,20 @@
S1_TOTAL = S1_TILE * NUM_TILES

Q_ROWS = 2048
NUM_Q_BLOCKS = Q_ROWS // S0 # 64 row-blocks
NUM_Q_BLOCKS = Q_ROWS // S0 # e.g. 16 when Q_ROWS=2048 and S0=128

# QK preload depth — must be >= 1; reference uses 2. The vec pre-softmaxes
# tiles 0..QK_PRELOAD-1, then the steady-state loop interleaves softmax(t+QK_PRELOAD)
# with gu(t), and the epilogue drains the last QK_PRELOAD gu's.
# (NUM_TILES - QK_PRELOAD) must be even — steady state is pair-unrolled to
# ping-pong the exp_max ring (see below).
# QK preload depth — must be >= 1; reference launch uses 4, this builder
# keeps 2 for a smaller VEC exp_max ring (see header comment).
# (NUM_TILES - QK_PRELOAD) must be even — steady state is pair-unrolled.
QK_PRELOAD = 2
assert (
NUM_TILES - QK_PRELOAD
) % 2 == 0, "Steady-state pair unrolling requires (NUM_TILES - QK_PRELOAD) % 2 == 0"
STEADY_PAIRS = (NUM_TILES - QK_PRELOAD) // 2

# Per-pipe slot sizes (bytes).
SLOT_SIZE_QK = S0 * S1_TILE * 4 # fp32 QK accumulator
# QK: fp32 on the wire (cube `tpush` only accepts ACC tiles today — see known_gap).
SLOT_SIZE_QK = S0 * S1_TILE * 4
SLOT_SIZE_PV = S0 * HEAD * 4 # fp32 PV accumulator
SLOT_SIZE_P = S0 * S1_TILE * 2 # fp16 P matrix sent vec → cube

Expand All @@ -103,10 +113,15 @@
# Explicit local-memory layout used when compiling with --pto-level=level3.
# Offsets are byte offsets within each independent local address space.
MAT_Q_OFF = 0
MAT_K_OFF = MAT_Q_OFF + S0 * HEAD * 2
MAT_P_RECV_OFF = MAT_K_OFF + HEAD * S1_TILE * 2
# Two physical K tiles (ref `kMatTNBuffers = 2`) for ping-pong loads.
MAT_K0_OFF = MAT_Q_OFF + S0 * HEAD * 2
MAT_K1_OFF = MAT_K0_OFF + HEAD * CUBE_S1 * 2
MAT_P_RECV_OFF = MAT_K1_OFF + HEAD * CUBE_S1 * 2
MAT_V_OFF = MAT_P_RECV_OFF + S0 * S1_TILE * 2
MAT_P_FIFO_OFF = 262144
MAT_P_FIFO_OFF = MAT_V_OFF + S1_TILE * HEAD * 2
# Pad past the last MAT-resident tile; bisheng is sensitive to overlap here.
if MAT_P_FIFO_OFF < 393216:
MAT_P_FIFO_OFF = 393216

LEFT_Q_OFF = 0
LEFT_P_OFF = LEFT_Q_OFF + S0 * HEAD * 2
Expand All @@ -123,13 +138,16 @@
VEC_P_FP16_OFF = VEC_P_FP32_OFF + S0_HALF * S1_TILE * 4
VEC_O_OFF = VEC_P_FP16_OFF + S0_HALF * S1_TILE * 2
VEC_RED_BASE_OFF = VEC_O_OFF + S0_HALF * HEAD * 4
VEC_RED_STRIDE = 512
# Tight packing for reduce / exp_max ring scalars (one column per logical row).
VEC_RED_STRIDE = ((S0_HALF * 4 + 127) // 128) * 128
VEC_NEW_GLOBAL_MAX_OFF = VEC_RED_BASE_OFF + 0 * VEC_RED_STRIDE
VEC_LOCAL_MAX_OFF = VEC_RED_BASE_OFF + 1 * VEC_RED_STRIDE
VEC_NEW_GLOBAL_SUM_OFF = VEC_RED_BASE_OFF + 2 * VEC_RED_STRIDE
VEC_LOCAL_SUM_OFF = VEC_RED_BASE_OFF + 3 * VEC_RED_STRIDE
VEC_EXP_MAX_A_OFF = VEC_RED_BASE_OFF + 4 * VEC_RED_STRIDE
VEC_EXP_MAX_B_OFF = VEC_RED_BASE_OFF + 5 * VEC_RED_STRIDE
# Shared recv scratch: max(fp32 QK half-tile, fp32 PV half-tile) for tpop addr=.
_VEC_RECV_BYTES = max(S0_HALF * S1_TILE * 4, S0_HALF * HEAD * 4)
VEC_RECV_OFF = VEC_RED_BASE_OFF + 6 * VEC_RED_STRIDE

ID_QK = 10 # Cube → Vec, dir_mask = 1 (uses lower-level l2g2l)
Expand All @@ -155,6 +173,7 @@ def meta_data():

q_sub_ty = pto.SubTensorType(shape=[S0, HEAD], dtype=fp16)
kt_sub_ty = pto.SubTensorType(shape=[HEAD, S1_TILE], dtype=fp16)
kt_sub_slice_ty = pto.SubTensorType(shape=[HEAD, CUBE_S1], dtype=fp16)
v_sub_ty = pto.SubTensorType(shape=[S1_TILE, HEAD], dtype=fp16)
o_sub_ty = pto.SubTensorType(shape=[S0, HEAD], dtype=fp32)
o_sub_half_ty = pto.SubTensorType(shape=[S0_HALF, HEAD], dtype=fp32)
Expand All @@ -163,13 +182,13 @@ def meta_data():
q_mat_ty = pto.TileBufType(shape=[S0, HEAD], dtype=fp16, memory_space="MAT")
q_left_ty = pto.TileBufType(shape=[S0, HEAD], dtype=fp16, memory_space="LEFT")
k_mat_ty = pto.TileBufType(
shape=[HEAD, S1_TILE],
shape=[HEAD, CUBE_S1],
dtype=fp16,
memory_space="MAT",
config=pto.TileBufConfig(blayout="RowMajor", slayout="ColMajor"),
)
k_right_ty = pto.TileBufType(
shape=[HEAD, S1_TILE], dtype=fp16, memory_space="RIGHT"
shape=[HEAD, CUBE_S1], dtype=fp16, memory_space="RIGHT"
)
qk_acc_ty = pto.TileBufType(shape=[S0, S1_TILE], dtype=fp32, memory_space="ACC")
p_recv_ty = pto.TileBufType(
Expand Down Expand Up @@ -229,11 +248,11 @@ def cube_kernel(
) -> None:
c0 = const(0)
c1 = const(1)
c2 = const(2)
cS0 = const(S0)
cHEAD = const(HEAD)
cS1_TILE = const(S1_TILE)
cS1_TOTAL = const(S1_TOTAL)
cCUBE_S1 = const(CUBE_S1)
cNUM_TILES = const(NUM_TILES)
cNUM_Q_BLOCKS = const(NUM_Q_BLOCKS)

Expand Down Expand Up @@ -297,25 +316,21 @@ def cube_kernel(
nosplit=False,
)

# All cube tile-buffers are single-buffered. K and V share RIGHT
# storage: for HEAD=128, S1_TILE=256 each RIGHT tile is exactly
# 64 KB, and the schedule uses V for PV before moving K for QK.
# This mirrors the hand-written reference's explicit local-memory
# assignment style and avoids asking RIGHT for two full tiles.
# K and V share one RIGHT bank (sequential use); `[buf]` lists alias
# to the same physical tile for each role (schedule is still safe).
right_base = const(RIGHT_KV_OFF, s.int64)
q_mat = pto.alloc_tile(q_mat_ty, addr=const(MAT_Q_OFF, s.int64))
q_left = pto.alloc_tile(q_left_ty, addr=const(LEFT_Q_OFF, s.int64))
k_mat_s = pto.alloc_tile(k_mat_ty, addr=const(MAT_K_OFF, s.int64))
k_mat_0 = pto.alloc_tile(k_mat_ty, addr=const(MAT_K0_OFF, s.int64))
k_mat_1 = pto.alloc_tile(k_mat_ty, addr=const(MAT_K1_OFF, s.int64))
k_right_s = pto.alloc_tile(k_right_ty, addr=right_base)
qk_acc_s = pto.alloc_tile(qk_acc_ty, addr=const(ACC_QK_OFF, s.int64))
p_recv_s = pto.alloc_tile(p_recv_ty, addr=const(MAT_P_RECV_OFF, s.int64))
p_left_s = pto.alloc_tile(p_left_ty, addr=const(LEFT_P_OFF, s.int64))
v_mat_s = pto.alloc_tile(v_mat_ty, addr=const(MAT_V_OFF, s.int64))
v_right_s = pto.alloc_tile(v_right_ty, addr=right_base)
pv_acc_s = pto.alloc_tile(pv_acc_ty, addr=const(ACC_PV_OFF, s.int64))
# Aliasing wrappers: keep the per-iteration `[buf]` indexing pattern
# in the body even though all slots currently point at one alloc.
k_mat = [k_mat_s, k_mat_s]
k_mat = [k_mat_0, k_mat_1]
k_right = [k_right_s, k_right_s]
qk_acc = [qk_acc_s, qk_acc_s]
p_recv = [p_recv_s, p_recv_s]
Expand All @@ -341,6 +356,25 @@ def cube_kernel(
strides=[cHEAD, c1],
)

# Two `Cube_S1`-wide matmuls per logical tile (ref `compute_qk` / `kTileFactor`).
def accumulate_qk_for_tile(k_s1_offset, qk_buf, k_mat_buf_idx):
for sc in range(K_TILE_FACTOR):
k_col = k_s1_offset + const(sc * CUBE_S1)
kt_view = pto.slice_view(
kt_sub_slice_ty,
source=tv_k,
offsets=[c0, k_col],
sizes=[cHEAD, cCUBE_S1],
)
pto.load(kt_view, k_mat[k_mat_buf_idx])
tile.mov(k_mat[k_mat_buf_idx], k_right[k_mat_buf_idx])
qk_part = tile.subview(
qk_acc[qk_buf],
[c0, const(sc * CUBE_S1)],
[S0, CUBE_S1],
)
tile.matmul(q_left, k_right[k_mat_buf_idx], qk_part)

for qb in pto.range(qb_start, qb_end, c1):
q_row_off = qb * cS0

Expand All @@ -351,19 +385,9 @@ def cube_kernel(
tile.mov(q_mat, q_left)

# =================== Cube prologue: emit QK[0..QK_PRELOAD-1] ===================
# Each prologue QK uses its own k_mat / k_right / qk_acc slot
# so MTE2 load of K[1] overlaps the M of QK[0].
for k in range(QK_PRELOAD):
k_off = const(k * S1_TILE)
kt_view_k = pto.slice_view(
kt_sub_ty,
source=tv_k,
offsets=[c0, k_off],
sizes=[cHEAD, cS1_TILE],
)
pto.load(kt_view_k, k_mat[k])
tile.mov(k_mat[k], k_right[k])
tile.matmul(q_left, k_right[k], qk_acc[k])
k_s1_off = const(k * S1_TILE)
accumulate_qk_for_tile(k_s1_off, k, k % 2)
pto.tpush(qk_acc[k], qk_pipe, SPLIT_UP_DOWN)

# Preload V[0] for the very first PV.
Expand All @@ -376,25 +400,10 @@ def cube_kernel(
pto.load(v_view_0, v_mat[0])

# =================== Cube steady state ===================
# Pair-unrolled. Iter t (parity = t%2 → buffer index `b`):
# * load K[next_qk = t+QK_PRELOAD] into k_mat[b]
# (next_qk parity equals t parity since QK_PRELOAD == 2)
# * pop / mov P[t] into p_left[b]; mov V[t] (in v_mat[b]) → v_right[b]
# * preload V[t+1] into v_mat[1-b]
# * matmul PV[t] into pv_acc[b]; push
# * matmul QK[next_qk] into qk_acc[b]; push
# Pair handler:
# Pair-unrolled; buffer index b = t % 2 (logical ping-pong).
def emit_cube_step(t_idx, b):
# next_qk = t_idx + QK_PRELOAD (only used when in main range)
next_qk = t_idx + const(QK_PRELOAD)
kt_off = next_qk * cS1_TILE
kt_view = pto.slice_view(
kt_sub_ty,
source=tv_k,
offsets=[c0, kt_off],
sizes=[cHEAD, cS1_TILE],
)
pto.load(kt_view, k_mat[b])
k_s1_off = next_qk * cS1_TILE

p_raw = pto.tpop_from_aiv(p_recv_ty, SPLIT_UP_DOWN, id=ID_P)
tile.mov(p_raw, p_left[b])
Expand All @@ -413,23 +422,20 @@ def emit_cube_step(t_idx, b):
tile.matmul(p_left[b], v_right[b], pv_acc[b])
pto.tpush(pv_acc[b], pv_pipe, SPLIT_UP_DOWN)

tile.mov(k_mat[b], k_right[b])
tile.matmul(q_left, k_right[b], qk_acc[b])
accumulate_qk_for_tile(k_s1_off, b, b)
pto.tpush(qk_acc[b], qk_pipe, SPLIT_UP_DOWN)

assert (NUM_TILES - QK_PRELOAD) % 2 == 0
for p in pto.range(c0, const((NUM_TILES - QK_PRELOAD) // 2), c1):
c2 = const(2)
for p in pto.range(c0, const(STEADY_PAIRS), c1):
t_a = p * c2
emit_cube_step(t_a, 0)
t_b = p * c2 + c1
emit_cube_step(t_b, 1)

# =================== Cube epilogue: drain last QK_PRELOAD PVs ===================
# Tile_id range: NUM_TILES-QK_PRELOAD .. NUM_TILES-1.
# NUM_TILES is even and QK_PRELOAD is even, so the first epilogue
# tile has parity 0. v_mat[0] holds V[NUM_TILES-QK_PRELOAD] thanks
# to the last steady-state preload (it loaded V[t_b+1] = V[NUM_TILES-QK_PRELOAD]
# into v_mat[1-1]=v_mat[0]).
# Tile ids NUM_TILES-QK_PRELOAD .. NUM_TILES-1; v_mat parity matches
# the last steady-state V preload pattern (same as QK_PRELOAD==2 case).
for k in range(QK_PRELOAD):
b = k % 2
p_raw = pto.tpop_from_aiv(p_recv_ty, SPLIT_UP_DOWN, id=ID_P)
Expand Down Expand Up @@ -461,7 +467,6 @@ def vector_kernel(
) -> None:
c0 = const(0)
c1 = const(1)
c2 = const(2)
cS0 = const(S0)
cS0_HALF = const(S0_HALF)
cHEAD = const(HEAD)
Expand Down Expand Up @@ -547,13 +552,7 @@ def vector_kernel(
red_ty, addr=const(VEC_NEW_GLOBAL_SUM_OFF, s.int64)
)
local_sum = pto.alloc_tile(red_ty, addr=const(VEC_LOCAL_SUM_OFF, s.int64))
# Ring of QK_PRELOAD exp_max tiles. With QK_PRELOAD=2 we use a/b
# ping-pong: even-parity tiles use exp_max_a, odd-parity tiles use
# exp_max_b. softmax(t) writes the exp_max for tile t into the
# corresponding slot; gu(t) reads it from the same slot. Because
# softmax(t+QK_PRELOAD) and gu(t) hit the SAME slot (parity matches),
# the steady-state loop must do gu(t) BEFORE softmax(t+QK_PRELOAD)
# to avoid clobbering.
# Two exp_max tiles (a/b). Interleave gu and softmax to avoid clobber.
assert QK_PRELOAD == 2, "exp_max ring is hard-coded to 2 tiles"
exp_max_a = pto.alloc_tile(red_ty, addr=const(VEC_EXP_MAX_A_OFF, s.int64))
exp_max_b = pto.alloc_tile(red_ty, addr=const(VEC_EXP_MAX_B_OFF, s.int64))
Expand All @@ -577,7 +576,7 @@ def emit_softmax_step(exp_max_slot, is_init):
addr=const(VEC_RECV_OFF, s.int64),
)
tile.muls(qk_recv, scale, qk_recv)
tile.row_max(qk_recv, tmp_tile, local_max)
tile.row_max(qk_recv, p_fp32, local_max)

local_max_r = tile.reshape(red_row_ty, local_max)
new_global_max_r = tile.reshape(red_row_ty, new_global_max)
Expand Down Expand Up @@ -625,29 +624,15 @@ def emit_gu_step(exp_max_slot, is_init):
o_row_off = qb * cS0

# =================== Vec prologue: softmax(0..QK_PRELOAD-1) ===================
# softmax(0): is_init=True (writes exp_max_a, but exp_max_a for tile 0
# is unused by gu(0) — gu(0) takes the init branch and just movs PV.
# Still we must compute it correctly; the init branch doesn't touch exp_max.
emit_softmax_step(exp_max_a, is_init=True)
# softmax(1): is_init=False (writes exp_max_b)
emit_softmax_step(exp_max_b, is_init=False)

# =================== Vec steady state ===================
# Pair-unrolled: each `p` iteration handles tiles t_a = 2p, t_b = 2p+1.
# gu(t_a) reads exp_max_a (set by softmax(t_a) earlier);
# softmax(t_a+2) writes exp_max_a (matches parity).
# gu(t_b) reads exp_max_b; softmax(t_b+2) writes exp_max_b.
# CRITICAL: gu BEFORE softmax in same step to avoid clobbering.
#
# First pair (p=0, t_a=0, t_b=1) is Python-unrolled so we can
# take the `is_init=True` branch on gu(0) (which initializes
# o_tile via mov rather than rescale+add).
emit_gu_step(exp_max_a, is_init=True) # tile 0
emit_softmax_step(exp_max_a, is_init=False) # tile 2 → exp_max_a
emit_gu_step(exp_max_b, is_init=False) # tile 1
emit_softmax_step(exp_max_b, is_init=False) # tile 3 → exp_max_b

# Remaining pairs (p=1..STEADY_PAIRS-1) inside a runtime loop.
emit_gu_step(exp_max_a, is_init=True)
emit_softmax_step(exp_max_a, is_init=False)
emit_gu_step(exp_max_b, is_init=False)
emit_softmax_step(exp_max_b, is_init=False)

for p in pto.range(c1, const(STEADY_PAIRS), c1):
emit_gu_step(exp_max_a, is_init=False)
emit_softmax_step(exp_max_a, is_init=False)
Expand Down
Loading
Loading