Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 92 additions & 63 deletions primus_turbo/flydsl/gemm/gemm_fp8_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
ceildiv,
compute_global_swizzle,
compute_global_swizzle_nn,
make_fp8_buffer_tensor,
make_fp8_buffer_tensor_rebased,
make_value_attrs,
mask_a_tail,
wait_barrier,
Expand All @@ -38,6 +38,7 @@
from flydsl._mlir.dialects import llvm as _llvm
from flydsl.expr import arith
from flydsl.expr import range_constexpr, rocdl
from flydsl.expr.typing import T

# isort: on

Expand Down Expand Up @@ -143,13 +144,23 @@ def kernel_dense_nt(
block_m = first_pid_m + (pid_in_group % group_size_m)
block_n = pid_in_group // group_size_m

A0_gl_offset = (block_m * BLOCK_M) * K
A1_gl_offset = (block_m * BLOCK_M + LDS_BLOCK_M) * K
B0_gl_offset = (block_n * BLOCK_N) * K
B1_gl_offset = (block_n * BLOCK_N + LDS_BLOCK_N) * K

gA = make_fp8_buffer_tensor(A, F8_IR_t)
gB = make_fp8_buffer_tensor(B_T, F8_IR_t)
# i64 input re-base: fold the per-tile row base (m_row*K, n_row*K) into the
# SRD base; A/B_T K-contiguous (foldable), k*BLOCK_K small int32 -> no cap.
a_base = arith.index_cast(T.index, block_m * BLOCK_M) * arith.index(K)
b_base = arith.index_cast(T.index, block_n * BLOCK_N) * arith.index(K)
a_nrec = (
arith.index_cast(T.index, c_m) - arith.index_cast(T.index, block_m * BLOCK_M)
) * arith.index(K)
b_nrec = (
arith.index_cast(T.index, c_n) - arith.index_cast(T.index, block_n * BLOCK_N)
) * arith.index(K)
A0_gl_offset = 0
A1_gl_offset = LDS_BLOCK_M * K
B0_gl_offset = 0
B1_gl_offset = LDS_BLOCK_N * K

gA = make_fp8_buffer_tensor_rebased(A, F8_IR_t, a_base, a_nrec)
gB = make_fp8_buffer_tensor_rebased(B_T, F8_IR_t, b_base, b_nrec)
a_div = fx.logical_divide(gA, fx.make_layout(1, 1))
b_div = fx.logical_divide(gB, fx.make_layout(1, 1))

Expand Down Expand Up @@ -458,18 +469,21 @@ def kernel_dense_nn(
block_m = first_pid_m + (pid_in_group % group_size_m)
block_n = pid_in_group // group_size_m

# A: same as NT.
A0_gl_offset = (block_m * BLOCK_M) * K
A1_gl_offset = (block_m * BLOCK_M + LDS_BLOCK_M) * K

# B: NN-specific. B is [K, N] row-major; per WG we load BLOCK_K K-rows
# × BLOCK_N N-cols, split into 2 N-halves of LDS_BLOCK_N each. K-iter
# step advances K-rows by BLOCK_K, which in element units is BLOCK_K * c_n.
B0_gl_offset = block_n * BLOCK_N + 0
B1_gl_offset = block_n * BLOCK_N + LDS_BLOCK_N

gA = make_fp8_buffer_tensor(A, F8_IR_t)
gB = make_fp8_buffer_tensor(B, F8_IR_t)
# i64 input re-base. A[M,K]: fold row base (m_row*K) into SRD. B[K,N]: the
# k*BLOCK_K*c_n contraction is i64 per load (cn_i), capped at 4GB by num_records.
m_row = block_m * BLOCK_M
cn_i = arith.index_cast(T.index, c_n)
a_base = arith.index_cast(T.index, m_row) * arith.index(K)
a_nrec = (arith.index_cast(T.index, c_m) - arith.index_cast(T.index, m_row)) * arith.index(K)
b_base = arith.index_cast(T.index, block_n * BLOCK_N)
b_nrec = arith.index(K) * cn_i - b_base
A0_gl_offset = 0
A1_gl_offset = LDS_BLOCK_M * K
B0_gl_offset = 0
B1_gl_offset = LDS_BLOCK_N

gA = make_fp8_buffer_tensor_rebased(A, F8_IR_t, a_base, a_nrec)
gB = make_fp8_buffer_tensor_rebased(B, F8_IR_t, b_base, b_nrec)
a_div = fx.logical_divide(gA, fx.make_layout(1, 1))
b_div = fx.logical_divide(gB, fx.make_layout(1, 1))

Expand Down Expand Up @@ -498,19 +512,19 @@ def kernel_dense_nn(
c11_frag = [mfma.zero_value] * N_ACCUMS

# Prelude.
b_g2s.load(b_cur0, B0_gl_offset + 0 * BLOCK_K * c_n)
b_g2s.load(b_cur0, B0_gl_offset + arith.index(0 * BLOCK_K) * cn_i)
a_g2s.load(a_cur0, A0_gl_offset + 0 * BLOCK_K)
b_g2s.load(b_cur1, B1_gl_offset + 0 * BLOCK_K * c_n)
b_g2s.load(b_cur1, B1_gl_offset + arith.index(0 * BLOCK_K) * cn_i)
a_g2s.load(a_cur1, A1_gl_offset + 0 * BLOCK_K)

if wave_m == 1:
rocdl.s_barrier()

wait_barrier(N_LDS_STEPS_A + N_LDS_STEPS_B)

b_g2s.load(b_next0, B0_gl_offset + 1 * BLOCK_K * c_n)
b_g2s.load(b_next0, B0_gl_offset + arith.index(1 * BLOCK_K) * cn_i)
a_g2s.load(a_next0, A0_gl_offset + 1 * BLOCK_K)
b_g2s.load(b_next1, B1_gl_offset + 1 * BLOCK_K * c_n)
b_g2s.load(b_next1, B1_gl_offset + arith.index(1 * BLOCK_K) * cn_i)

wait_barrier(N_LDS_STEPS_A + 2 * N_LDS_STEPS_B)

Expand All @@ -528,7 +542,7 @@ def kernel_dense_nn(
rocdl.s_barrier()

b1_frag = b_s2r.load(b_cur1)
b_g2s.load(b_cur0, B0_gl_offset + (k + 2) * BLOCK_K * c_n)
b_g2s.load(b_cur0, B0_gl_offset + arith.index((k + 2) * BLOCK_K) * cn_i)
rocdl.s_barrier()

rocdl.s_setprio(1)
Expand All @@ -545,7 +559,7 @@ def kernel_dense_nn(
rocdl.s_setprio(0)
rocdl.s_barrier()

b_g2s.load(b_cur1, B1_gl_offset + (k + 2) * BLOCK_K * c_n)
b_g2s.load(b_cur1, B1_gl_offset + arith.index((k + 2) * BLOCK_K) * cn_i)
wait_barrier(2 * N_LDS_STEPS_A + N_LDS_STEPS_B)

rocdl.s_setprio(1)
Expand Down Expand Up @@ -791,16 +805,21 @@ def kernel_dense_tn(
# for one Python-selected path (1D GROUP_M or 2D band).
block_m, block_n = _tn_block_mn(pid, num_pid_m, n_blocks, GROUP_M, group_n)

# TN A stored [K, M] row-major: stride M per K-row.
A0_gl_offset = block_m * BLOCK_M + 0
A1_gl_offset = block_m * BLOCK_M + LDS_BLOCK_M

# B same as NN: stored [K, N] row-major.
B0_gl_offset = block_n * BLOCK_N + 0
B1_gl_offset = block_n * BLOCK_N + LDS_BLOCK_N

gA = make_fp8_buffer_tensor(A, F8_IR_t)
gB = make_fp8_buffer_tensor(B, F8_IR_t)
# i64 input re-base. A[K,M]/B[K,N] K-row-major: fold column base into SRD; the
# k*BLOCK_K*c_{m,n} traversal is i64 per load (int32 wraps > 2^31), capped at 4GB.
cm_i = arith.index_cast(T.index, c_m)
cn_i = arith.index_cast(T.index, c_n)
a_base = arith.index_cast(T.index, block_m) * arith.index(BLOCK_M)
b_base = arith.index_cast(T.index, block_n) * arith.index(BLOCK_N)
a_nrec = arith.index(K) * cm_i - a_base
b_nrec = arith.index(K) * cn_i - b_base
A0_gl_offset = 0
A1_gl_offset = LDS_BLOCK_M
B0_gl_offset = 0
B1_gl_offset = LDS_BLOCK_N

gA = make_fp8_buffer_tensor_rebased(A, F8_IR_t, a_base, a_nrec)
gB = make_fp8_buffer_tensor_rebased(B, F8_IR_t, b_base, b_nrec)
a_div = fx.logical_divide(gA, fx.make_layout(1, 1))
b_div = fx.logical_divide(gB, fx.make_layout(1, 1))

Expand Down Expand Up @@ -835,19 +854,19 @@ def kernel_dense_tn(
c11_frag = [mfma.zero_value] * N_ACCUMS

# Prelude.
b_g2s.load(b_cur0, B0_gl_offset + 0 * BLOCK_K * c_n)
a_g2s.load(a_cur0, A0_gl_offset + 0 * BLOCK_K * c_m)
b_g2s.load(b_cur1, B1_gl_offset + 0 * BLOCK_K * c_n)
a_g2s.load(a_cur1, A1_gl_offset + 0 * BLOCK_K * c_m)
b_g2s.load(b_cur0, B0_gl_offset + arith.index(0 * BLOCK_K) * cn_i)
a_g2s.load(a_cur0, A0_gl_offset + arith.index(0 * BLOCK_K) * cm_i)
b_g2s.load(b_cur1, B1_gl_offset + arith.index(0 * BLOCK_K) * cn_i)
a_g2s.load(a_cur1, A1_gl_offset + arith.index(0 * BLOCK_K) * cm_i)

if wave_m == 1:
rocdl.s_barrier()

wait_barrier(N_LDS_STEPS_A + N_LDS_STEPS_B)

b_g2s.load(b_next0, B0_gl_offset + 1 * BLOCK_K * c_n)
a_g2s.load(a_next0, A0_gl_offset + 1 * BLOCK_K * c_m)
b_g2s.load(b_next1, B1_gl_offset + 1 * BLOCK_K * c_n)
b_g2s.load(b_next0, B0_gl_offset + arith.index(1 * BLOCK_K) * cn_i)
a_g2s.load(a_next0, A0_gl_offset + arith.index(1 * BLOCK_K) * cm_i)
b_g2s.load(b_next1, B1_gl_offset + arith.index(1 * BLOCK_K) * cn_i)

wait_barrier(N_LDS_STEPS_A + 2 * N_LDS_STEPS_B)

Expand All @@ -862,27 +881,27 @@ def kernel_dense_tn(
# drain — c01 consumes b1 with no covering drain between.)
b0_frag = b_s2r.load(b_cur0, drain=False)
a0_frag = a_s2r.load(a_cur0)
a_g2s.load(a_next1, A1_gl_offset + (k + 1) * BLOCK_K * c_m)
a_g2s.load(a_next1, A1_gl_offset + arith.index((k + 1) * BLOCK_K) * cm_i)
rocdl.s_barrier()
rocdl.s_setprio(1)
c00_frag = mfma.call(a0_frag, b0_frag, c00_frag)
rocdl.s_setprio(0)
rocdl.s_barrier()
b1_frag = b_s2r.load(b_cur1)
b_g2s.load(b_cur0, B0_gl_offset + (k + 2) * BLOCK_K * c_n)
b_g2s.load(b_cur0, B0_gl_offset + arith.index((k + 2) * BLOCK_K) * cn_i)
rocdl.s_barrier()
rocdl.s_setprio(1)
c01_frag = mfma.call(a0_frag, b1_frag, c01_frag)
rocdl.s_setprio(0)
rocdl.s_barrier()
a1_frag = a_s2r.load(a_cur1)
a_g2s.load(a_cur0, A0_gl_offset + (k + 2) * BLOCK_K * c_m)
a_g2s.load(a_cur0, A0_gl_offset + arith.index((k + 2) * BLOCK_K) * cm_i)
rocdl.s_barrier()
rocdl.s_setprio(1)
c10_frag = mfma.call(a1_frag, b0_frag, c10_frag)
rocdl.s_setprio(0)
rocdl.s_barrier()
b_g2s.load(b_cur1, B1_gl_offset + (k + 2) * BLOCK_K * c_n)
b_g2s.load(b_cur1, B1_gl_offset + arith.index((k + 2) * BLOCK_K) * cn_i)
wait_barrier(2 * N_LDS_STEPS_A + N_LDS_STEPS_B)
rocdl.s_setprio(1)
c11_frag = mfma.call(a1_frag, b1_frag, c11_frag)
Expand Down Expand Up @@ -915,7 +934,7 @@ def kernel_dense_tn(
rocdl.s_setprio(0)
rocdl.s_barrier()
b0_frag = b_s2r.load(b_next0)
a_g2s.load(a_next1, A1_gl_offset + (k + 1) * BLOCK_K * c_m)
a_g2s.load(a_next1, A1_gl_offset + arith.index((k + 1) * BLOCK_K) * cm_i)
rocdl.s_barrier()
rocdl.s_setprio(1)
c11_frag = mfma.call(a1_frag, b1_frag, c11_frag)
Expand Down Expand Up @@ -1008,14 +1027,27 @@ def _get_compiled_dense(launch, args):
return cached


def _run_dense(entry, args):
"""Mode-split steady-state launch. entry = [raw @flyc.jit launch, cfg, compiled].
Eager: run the one-time flyc.compile'd object (skips @flyc.jit's per-call drift-
check + arg-hash, and the per-call arg-key rebuild). Capture: run the raw closure
(a flyc.compile'd object regresses under CUDA-graph capture)."""
if torch.cuda.is_current_stream_capturing():
entry[0](*args)
else:
if entry[2] is None:
entry[2] = flyc.compile(entry[0], *args)
entry[2](*args)


def _as_i8_flat(t: torch.Tensor) -> torch.Tensor:
# Zero-copy flat byte view. Recomputed every call (no id()-keyed cache: a
# Zero-copy i8 view. Recomputed every call (no id()-keyed cache: a
# freed tensor's id + data_ptr can both be reused, and a recycled pair with a
# different numel would alias the wrong length). The view ops are ~1us and
# allocate nothing.
Comment thread
kyle-256 marked this conversation as resolved.
if t.element_size() == 1 and t.dtype != torch.int8: # fp8
return t.contiguous().view(torch.int8).view(-1)
return t.contiguous().view(-1)
return t.contiguous().view(torch.int8)
return t.contiguous()
Comment thread
kyle-256 marked this conversation as resolved.


def _scalar_scale(scale: torch.Tensor, device: torch.device) -> torch.Tensor:
Expand Down Expand Up @@ -1092,7 +1124,7 @@ def _autotune_nn_dispatch(args, M, N, K, cbsz=0, blgp=0, out_fp16=False):
us = e0.elapsed_time(e1) * 1000.0 / 20
if us < best_us:
best_us = us
best = (launch, (bm, gm, xcd, ag))
best = [launch, (bm, gm, xcd, ag), c] # c: compiled winner (reused eager)
except Exception:
continue
if best is None:
Expand Down Expand Up @@ -1164,7 +1196,7 @@ def _autotune_nt_dispatch(args, M, N, K, cbsz=0, blgp=0, out_fp16=False):
us = e0.elapsed_time(e1) * 1000.0 / 20
if us < best_us:
best_us = us
best = (launch, (bm, gm, xcd, ag))
best = [launch, (bm, gm, xcd, ag), c] # c: compiled winner (reused eager)
except Exception:
continue
if best is None:
Expand Down Expand Up @@ -1237,7 +1269,7 @@ def _autotune_tn_dispatch(args, M, N, K, cbsz=0, blgp=0, out_fp16=False):
us = e0.elapsed_time(e1) * 1000.0 / 20
if us < best_us:
best_us = us
best = (launch, (bm, 4, 0, xcd))
best = [launch, (bm, 4, 0, xcd), c] # c: compiled winner (reused eager)
except Exception:
continue
if best is None:
Expand Down Expand Up @@ -1282,15 +1314,14 @@ def gemm_fp8_tensorwise_flydsl_kernel(
args = (
_as_i8_flat(a),
_as_i8_flat(b),
out.contiguous().view(-1),
out.contiguous(),
a_scale_v,
b_scale_v,
M,
N,
torch.cuda.current_stream(),
)
launch, _cfg = _autotune_tn_dispatch(args, M, N, K, cbsz, blgp, out_fp16)
_get_compiled_dense(launch, args)(*args)
_run_dense(_autotune_tn_dispatch(args, M, N, K, cbsz, blgp, out_fp16), args)
if trans_c:
return out.t().contiguous()
return out
Expand All @@ -1310,15 +1341,14 @@ def gemm_fp8_tensorwise_flydsl_kernel(
args = (
_as_i8_flat(a),
_as_i8_flat(b),
out.contiguous().view(-1),
out.contiguous(),
a_scale_v,
b_scale_v,
M,
N,
torch.cuda.current_stream(),
)
launch, _cfg = _autotune_nn_dispatch(args, M, N, K, cbsz, blgp, out_fp16)
_get_compiled_dense(launch, args)(*args)
_run_dense(_autotune_nn_dispatch(args, M, N, K, cbsz, blgp, out_fp16), args)
elif (not trans_a) and trans_b:
# NT native: A [M, K], B [N, K] (B^T storage of [K, N]).
M, K_a = a.shape
Expand All @@ -1333,15 +1363,14 @@ def gemm_fp8_tensorwise_flydsl_kernel(
args = (
_as_i8_flat(a),
_as_i8_flat(b),
out.contiguous().view(-1),
out.contiguous(),
a_scale_v,
b_scale_v,
M,
N,
torch.cuda.current_stream(),
)
launch, _cfg = _autotune_nt_dispatch(args, M, N, K, cbsz, blgp, out_fp16)
_get_compiled_dense(launch, args)(*args)
_run_dense(_autotune_nt_dispatch(args, M, N, K, cbsz, blgp, out_fp16), args)
else:
raise NotImplementedError(
f"FlyDSL fp8 GEMM does not support the TT layout " f"(trans_a={trans_a}, trans_b={trans_b})."
Expand Down
Empty file.
Loading
Loading