From 55eb955531e42df5fa81ba73caf6d7073cf8afde Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 6 May 2026 18:38:55 +0000 Subject: [PATCH 01/17] Messy iter commit --- benchmarks/profile_indexer.py | 202 ++++ benchmarks/profile_indexer_topk.py | 249 +++++ benchmarks/run_indexer_kernel.py | 117 +++ transformer_engine/jax/indexer.py | 319 +++++++ .../jax/pallas_kernels/__init__.py | 16 + .../jax/pallas_kernels/indexer.py | 271 ++++++ .../jax/triton_extensions/__init__.py | 5 + .../jax/triton_extensions/indexer.py | 859 ++++++++++++++++++ .../jax/triton_extensions/utils.py | 245 +++-- 9 files changed, 2230 insertions(+), 53 deletions(-) create mode 100644 benchmarks/profile_indexer.py create mode 100644 benchmarks/profile_indexer_topk.py create mode 100644 benchmarks/run_indexer_kernel.py create mode 100644 transformer_engine/jax/indexer.py create mode 100644 transformer_engine/jax/pallas_kernels/__init__.py create mode 100644 transformer_engine/jax/pallas_kernels/indexer.py create mode 100644 transformer_engine/jax/triton_extensions/indexer.py diff --git a/benchmarks/profile_indexer.py b/benchmarks/profile_indexer.py new file mode 100644 index 000000000..f0e019305 --- /dev/null +++ b/benchmarks/profile_indexer.py @@ -0,0 +1,202 @@ +"""Profile the low-rank lightning-indexer at realistic shapes. + +Measures wall time and effective TFLOPS for the einsum baseline vs the +fused Triton kernel. + +Run inside the container: + docker exec zain-w2 sh -c 'cd /workspace && python benchmarks/profile_indexer.py' +""" + +import time + +import jax +import jax.numpy as jnp + +from transformer_engine.jax.indexer import indexer, quantize_to_fp8 + +# Triton hybrid backend: einsum projections + Triton score-relu-reduce. 
+try: + from transformer_engine.jax.triton_extensions.indexer import score_reduce_triton # noqa: F401 + _HAVE_HYBRID = True +except Exception as _e: # noqa: BLE001 + _HAVE_HYBRID = False + _HYBRID_IMPORT_ERROR = _e + + +# --- Inputs / FLOP accounting ---------------------------------------------------- + +def make_inputs(B, oH, T, S, d, d_c, H, d_i, dtype, seed=0): + keys = jax.random.split(jax.random.PRNGKey(seed), 6) + Q = jax.random.normal(keys[0], (B, oH, T, d), dtype=dtype) + K = jax.random.normal(keys[1], (B, oH, S, d), dtype=dtype) + W_uq = jax.random.normal(keys[2], (H, d_c, d_i), dtype=dtype) + W_dq = jax.random.normal(keys[3], (d, d_c), dtype=dtype) + W_k = jax.random.normal(keys[4], (d, d_i), dtype=dtype) + # Learnable per-(token, indexer-head) weight projection: W_o = Q @ W_w. + W_w = jax.random.normal(keys[5], (d, H), dtype=dtype) + return Q, K, W_uq, W_dq, W_k, W_w + + +def make_fp8_inputs(B, oH, T, S, d, d_c, H, d_i, *, + fp8_dtype=jnp.float8_e4m3fn, weights_dtype=jnp.bfloat16, + seed=0): + """Sample bf16 tensors then quantize Q/K/W_uq/W_dq/W_k to FP8. + + W_w stays in ``weights_dtype`` (bf16) — the reference impl does not + dequantize it. + + Returns (Q, K, W_uq, W_dq, W_k, W_w, scales_dict). + """ + Q, K, W_uq, W_dq, W_k, W_w = make_inputs( + B, oH, T, S, d, d_c, H, d_i, jnp.bfloat16, seed=seed + ) + Q_q, sq = quantize_to_fp8(Q, dtype=fp8_dtype) + K_q, sk = quantize_to_fp8(K, dtype=fp8_dtype) + Wuq_q, swq = quantize_to_fp8(W_uq, dtype=fp8_dtype) + Wdq_q, swd = quantize_to_fp8(W_dq, dtype=fp8_dtype) + Wk_q, swk = quantize_to_fp8(W_k, dtype=fp8_dtype) + W_w = W_w.astype(weights_dtype) + scales = dict(scale_q=sq, scale_k=sk, + scale_wq=swq, scale_wd=swd, scale_wk=swk) + return Q_q, K_q, Wuq_q, Wdq_q, Wk_q, W_w, scales + + +def theoretical_flops(B, oH, T, S, d, d_c, H, d_i): + # 2 flops per multiply-add. 
Counts the contractions in the low-rank + # indexer with learnable output-weight projection: + # C_q = Q @ W_dq : 2 * B*oH * T * d_c * d + # H_q = einsum(C_q, W_uq) : 2 * B*oH * T * H * d_i * d_c + # H_k = K @ W_k : 2 * B*oH * S * d_i * d + # scores = relu(H_q @ H_k^T) : 2 * B*oH * T * H * S * d_i + # W_o = Q @ W_w : 2 * B*oH * T * d * H + # O = sum_h scores * W_o : 2 * B*oH * T * S * H + n = B * oH + return 2 * ( + n * T * d_c * d + + n * T * H * d_i * d_c + + n * S * d_i * d + + n * T * H * S * d_i + + n * T * d * H + + n * T * S * H + ) + + +def time_fn(fn, args, n_warmup=15, n_iter=50): + for _ in range(n_warmup): + out = fn(*args) + jax.block_until_ready(out) + t0 = time.perf_counter() + for _ in range(n_iter): + out = fn(*args) + jax.block_until_ready(out) + return (time.perf_counter() - t0) / n_iter + + +# --- Driver --------------------------------------------------------------------- + +CONFIGS = [ + #(B, oH, T, S, d, d_c, H, d_i, dtype) + ( 2, 64, 1024, 1024, 512, 1024, 64, 128, jnp.bfloat16), +] + + +def _is_fp8(dt): + return jnp.dtype(dt) in ( + jnp.dtype("float8_e4m3fn"), jnp.dtype("float8_e5m2"), + jnp.dtype("float8_e4m3fnuz"), jnp.dtype("float8_e5m2fnuz"), + ) + + +def _bind_scales(fn, scales, *, backend=None): + """Return a 6-arg jit-able function that internally adds scale kwargs. + + If ``backend`` is given, it is forwarded as a kwarg to ``fn`` (used to + select between einsum / hybrid / pure-triton via the same ``indexer`` + entry point). 
+ """ + extra = {} + if backend is not None: + extra["backend"] = backend + if scales is None and not extra: + return jax.jit(fn) + @jax.jit + def wrapped(Q, K, W_uq, W_dq, W_k, W_w): + kwargs = dict(extra) + if scales is not None: + kwargs.update(scales) + return fn(Q, K, W_uq, W_dq, W_k, W_w, **kwargs) + return wrapped + + +def _build_impls(scales): + impls = [ + ("baseline", _bind_scales(indexer, scales, backend="reference")), + ] + if _HAVE_HYBRID: + impls.append(("hybrid", _bind_scales(indexer, scales, backend="hybrid"))) + return impls + + +if not _HAVE_HYBRID: + print(f"[profile_indexer] Hybrid backend unavailable: {_HYBRID_IMPORT_ERROR}") + + +def _dump_autotuner_winner(): + """Print the autotuner-selected config(s) for _score_reduce_kernel.""" + if not _HAVE_HYBRID: + return + try: + from transformer_engine.jax.triton_extensions.indexer import ( + _score_reduce_kernel, + ) + except ImportError: + return + cache = getattr(_score_reduce_kernel, "cache", None) + if not cache: + print(" [autotune] no cache entries") + return + for key, cfg in cache.items(): + print(f" [autotune] key={key} -> {cfg}") + + +def main(): + print(f"jax devices: {jax.devices()}\n") + for cfg in CONFIGS: + B, oH, T, S, d, d_c, H, d_i, dtype = cfg + is_fp8 = _is_fp8(dtype) + if is_fp8: + Q, K, W_uq, W_dq, W_k, W_w, scales = make_fp8_inputs( + B, oH, T, S, d, d_c, H, d_i, fp8_dtype=dtype + ) + else: + Q, K, W_uq, W_dq, W_k, W_w = make_inputs( + B, oH, T, S, d, d_c, H, d_i, dtype + ) + scales = None + args = (Q, K, W_uq, W_dq, W_k, W_w) + impls = _build_impls(scales) + flops = theoretical_flops(B, oH, T, S, d, d_c, H, d_i) + + print(f"--- B={B} oH={oH} T={T} S={S} d={d} d_c={d_c} H={H} d_i={d_i} " + f"{dtype.dtype.name} ---") + print(f" theoretical work = {flops/1e9:.2f} GFLOPs/call") + baseline_ms = None + for name, fn in impls: + try: + sec = time_fn(fn, args) + tflops = flops / sec / 1e12 + ms = sec * 1e3 + if name == "baseline": + baseline_ms = ms + speed = "" + else: + speed = f" 
({baseline_ms/ms:.2f}x baseline)" + print(f" {name:<10} {ms:8.3f} ms {tflops:6.2f} TFLOP/s{speed}") + except Exception as e: # noqa: BLE001 + print(f" {name:<10} FAILED: {type(e).__name__}: {str(e).splitlines()[0]}") + _dump_autotuner_winner() + print() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/profile_indexer_topk.py b/benchmarks/profile_indexer_topk.py new file mode 100644 index 000000000..bc1e27d8d --- /dev/null +++ b/benchmarks/profile_indexer_topk.py @@ -0,0 +1,249 @@ +"""Benchmark fused indexer+topk vs reference (full score then jax.lax.top_k). + +Production config: B=4, H=16, T_t=T_s=4096, d=128, I=4, d_i=64, k=64, bf16. + +Sweeps (block_t, block_s, num_warps, num_stages) for the triton kernel and +reports TFLOP/s, ms, and vs-reference speedup. FLOPs counted as the underlying +indexer compute (top-k itself is comparison-only, treated as 0 FLOP). + +Usage: + docker exec zain-w2 sh -c 'cd /workspace && python benchmarks/profile_indexer_topk.py' +""" + +import time +import functools + +import jax +import jax.numpy as jnp + +from transformer_engine.jax.indexer import _indexer_impl_reference, quantize_to_fp8 +from transformer_engine.jax.triton_extensions.indexer import ( + indexer_fused_topk_triton, + indexer_fused_triton, +) +try: + from transformer_engine.jax.pallas_kernels.indexer import indexer_fused as _pallas_indexer + _HAVE_PALLAS = True +except Exception: + _pallas_indexer = None + _HAVE_PALLAS = False + + +_FP8_DTYPES = ( + jnp.dtype("float8_e4m3fn"), jnp.dtype("float8_e5m2"), + jnp.dtype("float8_e4m3fnuz"), jnp.dtype("float8_e5m2fnuz"), +) + + +def _is_fp8(dt): + return jnp.dtype(dt) in _FP8_DTYPES + + +def make_inputs(B, H, T_t, T_s, d, I, d_i, dtype, seed=0): + keys = jax.random.split(jax.random.PRNGKey(seed), 5) + Q = jax.random.normal(keys[0], (B, H, T_t, d), dtype=dtype) + K = jax.random.normal(keys[1], (B, H, T_s, d), dtype=dtype) + W_q = jax.random.normal(keys[2], (I, d, d_i), dtype=dtype) + W_k = jax.random.normal(keys[3], 
(d, d_i), dtype=dtype) + weights = jax.random.normal(keys[4], (B, H, T_t, I), dtype=dtype) + return Q, K, W_q, W_k, weights + + +def make_fp8_inputs(B, H, T_t, T_s, d, I, d_i, *, fp8_dtype, seed=0): + Q, K, W_q, W_k, weights = make_inputs( + B, H, T_t, T_s, d, I, d_i, jnp.bfloat16, seed=seed + ) + Q_q, sq = quantize_to_fp8(Q, dtype=fp8_dtype) + K_q, sk = quantize_to_fp8(K, dtype=fp8_dtype) + Wq_q, swq = quantize_to_fp8(W_q, dtype=fp8_dtype) + Wk_q, swk = quantize_to_fp8(W_k, dtype=fp8_dtype) + return Q_q, K_q, Wq_q, Wk_q, weights, dict( + scale_q=sq, scale_k=sk, scale_wq=swq, scale_wk=swk, + ) + + +def theoretical_flops(B, H, T_t, T_s, d, I, d_i): + n = B * H + return 2 * ( + n * T_t * I * d_i * d + + n * T_s * d_i * d + + n * T_t * I * T_s * d_i + + n * T_t * T_s * I + ) + + +def time_fn(fn, args, n_warmup=5, n_iter=50): + for _ in range(n_warmup): + out = fn(*args) + jax.tree_util.tree_map(lambda x: x.block_until_ready(), out) + t0 = time.perf_counter() + for _ in range(n_iter): + out = fn(*args) + jax.tree_util.tree_map(lambda x: x.block_until_ready(), out) + return (time.perf_counter() - t0) / n_iter + + +# Reference, pallas+topk, triton+topk: each accepts an optional `scales` dict +# (None for high-precision). Built fresh per-config since the scales are baked +# into the closure. 
+def _make_reference_topk(scales): + if scales is None: + @jax.jit + def fn(Q, K, W_q, W_k, weights): + scores = _indexer_impl_reference(Q, K, W_q, W_k, weights) + return jax.lax.top_k(scores, K_TOPK_GLOBAL) + else: + @jax.jit + def fn(Q, K, W_q, W_k, weights): + scores = _indexer_impl_reference(Q, K, W_q, W_k, weights, **scales) + return jax.lax.top_k(scores, K_TOPK_GLOBAL) + return fn + + +def _make_pallas_then_topk(scales): + if not _HAVE_PALLAS: + return None + if scales is None: + @jax.jit + def fn(Q, K, W_q, W_k, weights): + scores = _pallas_indexer(Q, K, W_q, W_k, weights) + return jax.lax.top_k(scores, K_TOPK_GLOBAL) + else: + @jax.jit + def fn(Q, K, W_q, W_k, weights): + scores = _pallas_indexer(Q, K, W_q, W_k, weights, **scales) + return jax.lax.top_k(scores, K_TOPK_GLOBAL) + return fn + + +def _make_triton_then_topk(scales): + if scales is None: + @jax.jit + def fn(Q, K, W_q, W_k, weights): + scores = indexer_fused_triton(Q, K, W_q, W_k, weights) + return jax.lax.top_k(scores, K_TOPK_GLOBAL) + else: + @jax.jit + def fn(Q, K, W_q, W_k, weights): + scores = indexer_fused_triton(Q, K, W_q, W_k, weights, **scales) + return jax.lax.top_k(scores, K_TOPK_GLOBAL) + return fn + + +# Standalone: just time jax.lax.top_k on a precomputed score matrix. +@jax.jit +def topk_only(scores): + return jax.lax.top_k(scores, K_TOPK_GLOBAL) + + +def _make_triton(k, bt, bs, nw, ns): + fn = jax.jit(functools.partial( + indexer_fused_topk_triton, + k=k, block_t=bt, block_s=bs, num_warps=nw, num_stages=ns, + )) + return fn + + +CONFIGS = [ + # (B, H, T_t, T_s, d, I, d_i, dtype) + ( 4, 16, 2048, 2048, 128, 4, 64, jnp.bfloat16), + ( 4, 16, 4096, 4096, 128, 4, 64, jnp.bfloat16), + # FP8 e4m3 — fused-topk Triton kernel doesn't accept FP8 yet; the row will + # report "(skipped: fp8 not supported)" for that impl. The other three + # paths (reference, pallas+topk, triton+topk) all run end-to-end in FP8. 
+ ( 4, 16, 2048, 2048, 128, 4, 64, jnp.float8_e4m3fn), + ( 4, 16, 4096, 4096, 128, 4, 64, jnp.float8_e4m3fn), +] + +K_TOPK_GLOBAL = 64 + +SWEEP = [ + # (block_t, block_s, num_warps, num_stages) + ( 64, 64, 4, 1), + ( 64, 64, 8, 1), + (128, 64, 4, 1), + (128, 64, 8, 1), + ( 32, 32, 4, 1), + ( 32, 64, 4, 1), + ( 32, 128, 4, 1), # k=64+128=192 not pow2; will be skipped + ( 64, 32, 4, 1), + (256, 64, 4, 1), + (256, 64, 8, 1), +] + + +def main(): + print(f"jax devices: {jax.devices()}\nk = {K_TOPK_GLOBAL}\n") + for cfg in CONFIGS: + B, H, T_t, T_s, d, I, d_i, dtype = cfg + is_fp8 = _is_fp8(dtype) + if is_fp8: + Q, K, W_q, W_k, weights, scales = make_fp8_inputs( + B, H, T_t, T_s, d, I, d_i, fp8_dtype=dtype + ) + args = (Q, K, W_q, W_k, weights) + else: + args = make_inputs(B, H, T_t, T_s, d, I, d_i, dtype) + scales = None + flops = theoretical_flops(B, H, T_t, T_s, d, I, d_i) + print(f"--- B={B} H={H} T_t={T_t} T_s={T_s} d={d} I={I} d_i={d_i} {dtype.dtype.name} ---") + print(f" theoretical work = {flops/1e9:.2f} GFLOPs/call") + + impls = [ + ("ref(einsum+topk)", _make_reference_topk(scales)), + ("pallas+topk", _make_pallas_then_topk(scales)), + ("triton+topk", _make_triton_then_topk(scales)), + ] + + ref_ms = None + for name, fn in impls: + if fn is None: + continue + try: + sec = time_fn(fn, args) + ms = sec * 1e3 + tflops = flops / sec / 1e12 + if name == "pallas+topk": + ref_ms = ms + print(f" {name:<22} {ms:8.3f} ms {tflops:6.2f} TFLOP/s") + except Exception as e: # noqa: BLE001 + print(f" {name:<22} FAILED: {type(e).__name__}: {str(e).splitlines()[0]}") + + # Time top_k alone (on pre-materialized scores). For FP8 inputs the + # reference dequantizes internally and returns a high-precision matrix. 
+ try: + if scales is None: + scores_mat = _indexer_impl_reference(*args) + else: + scores_mat = _indexer_impl_reference(*args, **scales) + sec = time_fn(topk_only, (scores_mat,)) + print(f" {'(top_k alone)':<22} {sec*1e3:8.3f} ms") + except Exception as e: # noqa: BLE001 + print(f" (top_k alone) FAILED: {type(e).__name__}") + + # Fused-topk Triton kernel does not accept FP8 yet — skip the sweep + # for FP8 configs. + if is_fp8: + print(f" {'fused-topk triton':<22} (skipped: fp8 not supported by topk kernel)") + print() + continue + + # Triton fused-topk sweep (high-precision only) + for bt, bs, nw, ns in SWEEP: + if (K_TOPK_GLOBAL + bs) & (K_TOPK_GLOBAL + bs - 1) != 0: + continue # k+block_s must be pow2 + label = f"triton bt={bt} bs={bs} W={nw} S={ns}" + try: + fn = _make_triton(K_TOPK_GLOBAL, bt, bs, nw, ns) + sec = time_fn(fn, args) + ms = sec * 1e3 + tflops = flops / sec / 1e12 + speed = f" ({ref_ms/ms:.2f}x ref)" if ref_ms else "" + print(f" {label:<22} {ms:8.3f} ms {tflops:6.2f} TFLOP/s{speed}") + except Exception as e: # noqa: BLE001 + print(f" {label:<22} FAILED: {type(e).__name__}: {str(e).splitlines()[0]}") + print() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/run_indexer_kernel.py b/benchmarks/run_indexer_kernel.py new file mode 100644 index 000000000..d61a1738c --- /dev/null +++ b/benchmarks/run_indexer_kernel.py @@ -0,0 +1,117 @@ +"""Minimal direct invocation of the low-rank indexer kernel for profiling. + +No baselines, no comparisons. Just: build inputs once, jit the kernel, +warm it up, then run a fixed number of iterations under whatever +profiler is wrapping this process. 
+ +Run inside the container: + docker exec zain-w2 sh -c 'cd /workspace && python benchmarks/run_indexer_kernel.py' +""" + +import argparse +import time + +import jax +import jax.numpy as jnp + +from transformer_engine.jax.indexer import quantize_to_fp8 +from transformer_engine.jax.triton_extensions.indexer import indexer_fused_triton as _triton_indexer + +_BACKENDS = { + "triton": _triton_indexer, +} + +_DTYPE_MAP = { + "bf16": jnp.bfloat16, + "fp32": jnp.float32, + "fp8": jnp.float8_e4m3fn, +} + + +def make_inputs(B, oH, T, S, d, d_c, H, d_i, dtype, seed=0): + keys = jax.random.split(jax.random.PRNGKey(seed), 6) + Q = jax.random.normal(keys[0], (B, oH, T, d), dtype=dtype) + K = jax.random.normal(keys[1], (B, oH, S, d), dtype=dtype) + W_uq = jax.random.normal(keys[2], (H, d_c, d_i), dtype=dtype) + W_dq = jax.random.normal(keys[3], (d, d_c), dtype=dtype) + W_k = jax.random.normal(keys[4], (d, d_i), dtype=dtype) + weights = jax.random.normal(keys[5], (B, oH, H, T), dtype=dtype) + return Q, K, W_uq, W_dq, W_k, weights + + +def make_fp8_inputs(B, oH, T, S, d, d_c, H, d_i, *, fp8_dtype, seed=0): + """Quantize all five matrices to FP8; weights stay bf16.""" + Q, K, W_uq, W_dq, W_k, weights = make_inputs( + B, oH, T, S, d, d_c, H, d_i, jnp.bfloat16, seed=seed + ) + Q_q, sq = quantize_to_fp8(Q, dtype=fp8_dtype) + K_q, sk = quantize_to_fp8(K, dtype=fp8_dtype) + Wuq_q, swq = quantize_to_fp8(W_uq, dtype=fp8_dtype) + Wdq_q, swd = quantize_to_fp8(W_dq, dtype=fp8_dtype) + Wk_q, swk = quantize_to_fp8(W_k, dtype=fp8_dtype) + return Q_q, K_q, Wuq_q, Wdq_q, Wk_q, weights, dict( + scale_q=sq, scale_k=sk, scale_wq=swq, scale_wd=swd, scale_wk=swk, + ) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--B", type=int, default=4) + p.add_argument("--oH", type=int, default=16, help="outer (multi-attn) heads") + p.add_argument("--T", type=int, default=2048) + p.add_argument("--S", type=int, default=2048) + p.add_argument("--d", type=int, default=512, help="hidden dim") + 
p.add_argument("--d_c", type=int, default=128, help="down-projection rank") + p.add_argument("--H", type=int, default=64, help="indexer-head count") + p.add_argument("--d_i", type=int, default=128, help="per-indexer-head dim") + p.add_argument("--dtype", choices=list(_DTYPE_MAP), default="bf16") + p.add_argument("--warmup", type=int, default=5) + p.add_argument("--iters", type=int, default=50) + p.add_argument("--backend", choices=list(_BACKENDS), default="triton") + args = p.parse_args() + + dtype = _DTYPE_MAP[args.dtype] + is_fp8 = args.dtype == "fp8" + print(f"jax devices: {jax.devices()}") + print(f"shape: B={args.B} oH={args.oH} T={args.T} S={args.S} " + f"d={args.d} d_c={args.d_c} H={args.H} d_i={args.d_i} " + f"dtype={args.dtype} backend={args.backend}") + + if is_fp8: + Q, K, W_uq, W_dq, W_k, weights, scales = make_fp8_inputs( + args.B, args.oH, args.T, args.S, + args.d, args.d_c, args.H, args.d_i, fp8_dtype=dtype, + ) + inputs = (Q, K, W_uq, W_dq, W_k, weights) + else: + inputs = make_inputs(args.B, args.oH, args.T, args.S, + args.d, args.d_c, args.H, args.d_i, dtype) + scales = None + + raw_fn = _BACKENDS[args.backend] + + if scales is None: + @jax.jit + def fn(Q, K, W_uq, W_dq, W_k, weights): + return raw_fn(Q, K, W_uq, W_dq, W_k, weights) + else: + @jax.jit + def fn(Q, K, W_uq, W_dq, W_k, weights): + return raw_fn(Q, K, W_uq, W_dq, W_k, weights, **scales) + + # Warmup: triggers JIT compile + first-launch overhead. + for _ in range(args.warmup): + out = fn(*inputs) + jax.block_until_ready(out) + + # Timed region: this is what the profiler should focus on. 
+ t0 = time.perf_counter() + for _ in range(args.iters): + out = fn(*inputs) + jax.block_until_ready(out) + sec = (time.perf_counter() - t0) / args.iters + print(f"avg per call: {sec*1e3:.3f} ms ({args.iters} iters)") + + +if __name__ == "__main__": + main() diff --git a/transformer_engine/jax/indexer.py b/transformer_engine/jax/indexer.py new file mode 100644 index 000000000..141e14275 --- /dev/null +++ b/transformer_engine/jax/indexer.py @@ -0,0 +1,319 @@ +"""Indexer op (forward only). + +Two backends: + * "reference" - jnp/einsum, accepts arbitrary leading dims (..., T, d). + * "fused" - Pallas kernel, strict BHSD (B, H, T, d). Lives in + transformer_engine/jax/pallas_kernels/indexer.py. + +Top-level entry point: ``indexer(Q, K, W_uq, W_dq, W_k, weights, *, backend=...)``. + +Math (low-rank form: Q is hidden state; query heads are produced by a +down-projection (d -> d_c) followed by an up-projection (d_c -> H * d_i)): + + C_q = Q @ W_dq # (..., T, d_c) + H_q = einsum("...tc,hci->...thi", C_q, W_uq) # (..., T, H, d_i) + H_k = K @ W_k # (..., S, d_i) + H = relu(einsum("...thi,...si->...ths", H_q, H_k)) # (..., T, H, S) + O = einsum("...ths,...ht->...ts", H, weights) # (..., T, S) + +``weights`` is the precomputed per-(indexer-head, token) weight (DeepSeek's +``weights_proj(x)`` term, transposed for kernel-friendly layout). Its leading +dims must broadcast against Q's. + +FP8 mode: any of Q / K / W_uq / W_dq / W_k may be FP8 (e4m3) tensors. Each +FP8 operand needs a per-tensor fp32 scale (scale_q, scale_k, scale_wq, +scale_wd, scale_wk). ReLU commutes with positive scaling so the active +scales fold into a single fp32 scalar applied once at the end. Letting W_dq +go FP8 unlocks a native FP8 MFMA on the Q @ W_dq down-projection (and saves +half the bytes for that weight) at the cost of additional quantization noise +in the bottleneck of the low-rank decomposition. 
+""" + +import functools +import math + +import jax +import jax.numpy as jnp + + +_FP8_DTYPES = frozenset([ + jnp.dtype("float8_e4m3fn"), + jnp.dtype("float8_e5m2"), + jnp.dtype("float8_e4m3fnuz"), + jnp.dtype("float8_e5m2fnuz"), +]) + + +def _is_fp8(x): + return jnp.dtype(x.dtype) in _FP8_DTYPES + + +def quantize_to_fp8(x, *, dtype=None, axis=None): + """Per-tensor amax-based quantization helper (for tests/profiling). + + Returns (x_fp8, scale_fp32) where the dequantization is ``x_fp8 * scale``. + """ + if dtype is None: + dtype = jnp.float8_e4m3fn + fp8_max = jnp.finfo(dtype).max.astype(jnp.float32) + amax = jnp.max(jnp.abs(x.astype(jnp.float32))) if axis is None else \ + jnp.max(jnp.abs(x.astype(jnp.float32)), axis=axis, keepdims=True) + scale = (amax / fp8_max).astype(jnp.float32) + # avoid divide-by-zero on all-zero tensors + scale = jnp.where(scale == 0, jnp.float32(1.0), scale) + x_fp8 = (x.astype(jnp.float32) / scale).astype(dtype) + return x_fp8, scale + + +# --- Reference implementation --------------------------------------------------- + +def _indexer_impl_reference(Q, K, W_uq, W_dq, W_k, W_w, + scale_q=None, scale_k=None, + scale_wq=None, scale_wd=None, scale_wk=None, + out_dtype=None): + """ + Q [..., T, d] + K [..., S, d] + W_dq [d, d_c] + W_uq [H, d_c, d_i] + W_k [d, d_i] + W_w [..., d, H] # leading dims must match Q's + + FP8 path: each fp8 operand is dequantized via cast-to-bf16-then-multiply + immediately before the matmul that consumes it. This is the pattern XLA's + GEMM rewriter recognizes and lowers to ``__cublas$lt$matmul$f8`` (native + fp8 hardware GEMM) for matmuls where both operands are originally fp8. + Upcasting to fp32 first would lose the fp8 type info and fall back to + plain fp32 GEMM — strictly worse. + """ + if _is_fp8(Q): + if any(s is None for s in (scale_q, scale_k, scale_wq, scale_wk)): + raise ValueError( + "FP8 reference requires scale_q, scale_k, scale_wq, scale_wk." 
+ ) + if _is_fp8(W_dq) and scale_wd is None: + raise ValueError("FP8 W_dq requires scale_wd.") + + wp = jnp.bfloat16 # working precision for non-fp8 intermediates + + def _dq(x, s): + # cast-then-scale pattern (in working precision, NOT fp32). XLA's + # GEMM rewriter pulls (cast, multiply, dot) into a fused fp8 GEMM + # when both operands of the dot follow this pattern. + if _is_fp8(x): + return x.astype(wp) * jnp.float32(s).astype(wp) + return x.astype(wp) + + Q_d = _dq(Q, scale_q) + K_d = _dq(K, scale_k) + W_uq_d = _dq(W_uq, scale_wq) + W_dq_d = _dq(W_dq, scale_wd) + W_k_d = _dq(W_k, scale_wk) + + C_q = jnp.einsum("...td,dc->...tc", Q_d, W_dq_d) # (..., T, d_c) + H_q = jnp.einsum("...tc,hci->...thi", C_q, W_uq_d) # (..., T, H, d_i) + H_k = jnp.einsum("...sd,di->...si", K_d, W_k_d) # (..., S, d_i) + H = jax.nn.relu(jnp.einsum("...thi,...si->...ths", H_q, H_k)) # (..., T, H, S) + W_o = jnp.einsum("...td,dh->...th", Q_d, W_w) + O = jnp.einsum("...ths,...th->...ts", H, W_o) # (..., T, S) + if out_dtype is not None: + O = O.astype(out_dtype) + return O + + +# --- Fused implementation (Pallas) ---------------------------------------------- +# Imported lazily so callers without Triton/Pallas can still use the reference. + +def _indexer_impl_fused(Q, K, W_uq, W_dq, W_k, weights, **fp8_kwargs): + raise NotImplementedError( + "Pallas backend has not yet been updated for the low-rank indexer form " + "(W_uq + W_dq). Use backend='triton' or backend='reference'." + ) + + +def _indexer_impl_triton(Q, K, W_uq, W_dq, W_k, weights, **fp8_kwargs): + from transformer_engine.jax.triton_extensions.indexer import indexer_fused_triton + return indexer_fused_triton(Q, K, W_uq, W_dq, W_k, weights, **fp8_kwargs) + + +def _indexer_impl_hybrid(Q, K, W_uq, W_dq, W_k, W_w, + scale_q=None, scale_k=None, + scale_wq=None, scale_wd=None, scale_wk=None, + out_dtype=None): + """Einsum projections + Triton score-relu-reduce. 
+ + Mirrors ``_indexer_impl_reference`` for the four projections (which + lower to hipBLASLt GEMMs), then hands Hq / Hk / W_o to a fused Triton + kernel that does score+relu+H-reduction in registers — eliminating the + 16+ GB pre-relu-score HBM round-trip the pure-einsum path pays. + + bf16 only for now. FP8 inputs are dequantized to bf16 just like the + reference; native FP8 GEMM is not available on ROCm anyway. + """ + from transformer_engine.jax.triton_extensions.indexer import score_reduce_triton + + if _is_fp8(Q): + if any(s is None for s in (scale_q, scale_k, scale_wq, scale_wk)): + raise ValueError( + "FP8 hybrid requires scale_q, scale_k, scale_wq, scale_wk." + ) + if _is_fp8(W_dq) and scale_wd is None: + raise ValueError("FP8 W_dq requires scale_wd.") + + wp = jnp.bfloat16 + + def _dq(x, s): + if _is_fp8(x): + return x.astype(wp) * jnp.float32(s).astype(wp) + return x.astype(wp) + + Q_d = _dq(Q, scale_q) + K_d = _dq(K, scale_k) + W_uq_d = _dq(W_uq, scale_wq) + W_dq_d = _dq(W_dq, scale_wd) + W_k_d = _dq(W_k, scale_wk) + + C_q = jnp.einsum("...td,dc->...tc", Q_d, W_dq_d) # (..., T, d_c) + H_q = jnp.einsum("...tc,hci->...thi", C_q, W_uq_d) # (..., T, H, d_i) + H_k = jnp.einsum("...sd,di->...si", K_d, W_k_d) # (..., S, d_i) + W_o = jnp.einsum("...td,dh->...th", Q_d, W_w.astype(wp)) # (..., T, H) + + O = score_reduce_triton(H_q, H_k, W_o, + out_dtype=out_dtype if out_dtype else wp) + return O + + +def _indexer_topk_impl_reference(Q, K, W_uq, W_dq, W_k, weights, k): + scores = _indexer_impl_reference(Q, K, W_uq, W_dq, W_k, weights) + return jax.lax.top_k(scores, k) + + +def _indexer_topk_impl_triton(Q, K, W_uq, W_dq, W_k, weights, k): + from transformer_engine.jax.triton_extensions.indexer import indexer_fused_topk_triton + return indexer_fused_topk_triton(Q, K, W_uq, W_dq, W_k, weights, k=k) + + +@functools.partial(jax.jit, static_argnames=("k", "backend")) +def indexer_topk(Q, K, W_uq, W_dq, W_k, weights, *, k, backend="triton"): + """Indexer fused with 
per-row top-k along T_s. + + Returns (vals, idxs): + vals: (..., T, k) Q.dtype + idxs: (..., T, k) int32 + + backend: "reference" (full score then jax.lax.top_k) or "triton" (fused). + """ + if backend == "reference": + return _indexer_topk_impl_reference(Q, K, W_uq, W_dq, W_k, weights, k) + if backend == "triton": + return _indexer_topk_impl_triton(Q, K, W_uq, W_dq, W_k, weights, k) + raise ValueError(f"unknown backend {backend!r}; expected 'reference' or 'triton'") + + +# --- Top-level dispatch --------------------------------------------------------- + +@functools.partial(jax.jit, static_argnames=("backend", "out_dtype")) +def indexer(Q, K, W_uq, W_dq, W_k, weights, *, + scale_q=None, scale_k=None, + scale_wq=None, scale_wd=None, scale_wk=None, + out_dtype=None, backend="reference"): + """Low-rank lightning-indexer. + + Args: + Q: (..., T, d) hidden state (per token) + K: (..., S, d) key hidden state + W_uq: (H, d_c, d_i) up-projection: d_c -> d_i (per head) + W_dq: (d, d_c) down-projection: d -> d_c + W_k: (d, d_i) key projection + weights: (..., H, T) per-(indexer-head, token) weight + scale_q, scale_k, scale_wq, scale_wk: + per-tensor fp32 dequant scales. Required when Q is FP8. + scale_wd: + per-tensor fp32 dequant scale for W_dq. Required only when + W_dq itself is FP8. + out_dtype: output dtype override (defaults to Q.dtype, or weights.dtype + in FP8 mode). + backend: "reference", "fused" (Pallas), or "triton". + + Returns: + O of shape (..., T, S). 
+ """ + fp8_kwargs = dict( + scale_q=scale_q, scale_k=scale_k, + scale_wq=scale_wq, scale_wd=scale_wd, scale_wk=scale_wk, + out_dtype=out_dtype, + ) + if backend == "reference": + return _indexer_impl_reference(Q, K, W_uq, W_dq, W_k, weights, **fp8_kwargs) + if backend == "fused": + return _indexer_impl_fused(Q, K, W_uq, W_dq, W_k, weights, **fp8_kwargs) + if backend == "triton": + return _indexer_impl_triton(Q, K, W_uq, W_dq, W_k, weights, **fp8_kwargs) + if backend == "hybrid": + return _indexer_impl_hybrid(Q, K, W_uq, W_dq, W_k, weights, **fp8_kwargs) + raise ValueError( + f"unknown backend {backend!r}; expected 'reference', 'fused', 'triton', " + f"or 'hybrid'" + ) + + +# --- Tests ---------------------------------------------------------------------- + +def _reference_nobatch(Q, K, W_uq, W_dq, W_k, weights): + """Rank-2 reference (no leading dims) used as the cross-check.""" + C_q = Q @ W_dq + H_q = jnp.einsum("tc,hci->thi", C_q, W_uq) + H_k = K @ W_k + H = jax.nn.relu(jnp.einsum("thi,si->ths", H_q, H_k)) + return jnp.einsum("ths,ht->ts", H, weights) + + +def _run_test(leading_shape, seed, backend): + # Power-of-2 shapes, all matmul dims >= 16 so the Pallas backend accepts. 
+ T_t, T_s, d, d_c, H, d_i = 16, 16, 16, 16, 4, 16 + keys = jax.random.split(jax.random.PRNGKey(seed), 6) + Q = jax.random.normal(keys[0], (*leading_shape, T_t, d)) + K = jax.random.normal(keys[1], (*leading_shape, T_s, d)) + W_uq = jax.random.normal(keys[2], (H, d_c, d_i)) + W_dq = jax.random.normal(keys[3], (d, d_c)) + W_k = jax.random.normal(keys[4], (d, d_i)) + weights = jax.random.normal(keys[5], (*leading_shape, H, T_t)) + + try: + O = indexer(Q, K, W_uq, W_dq, W_k, weights, backend=backend) + except Exception as e: # noqa: BLE001 + print(f" backend={backend:<10s} leading={str(leading_shape):10s} " + f"SKIP: {type(e).__name__}: {str(e).splitlines()[0]}") + return + + flat = math.prod(leading_shape) if leading_shape else 1 + Q_f = Q.reshape(flat, T_t, d) + K_f = K.reshape(flat, T_s, d) + weights_f = weights.reshape(flat, T_t, H) + O_ref = jax.vmap(lambda q, k, w: _reference_nobatch(q, k, W_uq, W_dq, W_k, w))( + Q_f, K_f, weights_f + ) + O_ref = O_ref.reshape(*leading_shape, T_t, T_s) + + expected_shape = (*leading_shape, T_t, T_s) + shape_ok = O.shape == expected_shape + max_err = float(jnp.max(jnp.abs(O - O_ref))) + tag = "OK" if shape_ok and max_err < 1e-4 else "FAIL" + print(f" backend={backend:<10s} leading={str(leading_shape):10s} " + f"O.shape={O.shape} max abs err={max_err:.2e} [{tag}]") + + +if __name__ == "__main__": + print("=== reference backend ===") + for i, leading in enumerate([(), (2,), (2, 3)]): + _run_test(leading, seed=i, backend="reference") + + # Fused: strictly BHSD (rank-4 Q/K), so only the (2, 3) case applies. + print("\n=== fused backend ===") + for i, leading in enumerate([(), (2,), (2, 3)]): + _run_test(leading, seed=100 + i, backend="fused") + + # Triton: same BHSD-only contract as fused. 
+ print("\n=== triton backend ===") + for i, leading in enumerate([(), (2,), (2, 3)]): + _run_test(leading, seed=200 + i, backend="triton") diff --git a/transformer_engine/jax/pallas_kernels/__init__.py b/transformer_engine/jax/pallas_kernels/__init__.py new file mode 100644 index 000000000..7cdb7834b --- /dev/null +++ b/transformer_engine/jax/pallas_kernels/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Pallas-backed kernels for Transformer Engine JAX. + +Pallas is the JAX-native kernel-authoring API. Its Triton lowering uses the +new ``__gpu$xla.gpu.triton`` custom call, which works on both NVIDIA and +AMD/ROCm. The kernel body looks essentially like a Triton kernel - same +program_id / block-pointer model - but uses ``pl.*`` primitives instead of +``tl.*`` so it composes cleanly with JAX (jit, sharding, dtype rules). +""" + +from .indexer import indexer_fused + +__all__ = ["indexer_fused"] diff --git a/transformer_engine/jax/pallas_kernels/indexer.py b/transformer_engine/jax/pallas_kernels/indexer.py new file mode 100644 index 000000000..7ee3508cc --- /dev/null +++ b/transformer_engine/jax/pallas_kernels/indexer.py @@ -0,0 +1,271 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Pallas kernel for the fused indexer (BHSD layout). 
+ +Reference math (see transformer_engine.jax.indexer._indexer_impl_reference): + + H_q = einsum("bhtd,idj->bhtij", Q, W_q) # (B, H, T_t, I, d_i) + H_k = K @ W_k # (B, H, T_s, d_i) + H = relu(einsum("bhtij,bhsj->bhtis", H_q, H_k)) # (B, H, T_t, I, T_s) + O = einsum("bhtis,bhti->bhts", H, weights) # (B, H, T_t, T_s) + +``weights`` is a precomputed per-(token, indexer-head) weighting with shape +matching Q's leading dims (B, H, T_t, I). It plays the role DeepSeek's +``weights_proj(x)`` plays in their lightning-indexer: a learned, data- +dependent per-head weight, not a static parameter. + +The launcher (``indexer_fused`` at the bottom of this file) wires up grid +size, BlockSpecs, and shape inference. The kernel body itself +(``_indexer_pallas_kernel_body``) is built by ``_make_kernel_body``. +""" + +import jax +import jax.numpy as jnp +from jax.experimental import pallas as pl + + +# Preferred tile sizes (used when the LDS budget allows). Pallas-Triton stages +# Q and K tiles to LDS via BlockSpec, so the dominant LDS cost is +# (BLOCK_T + BLOCK_S) * d * dtype_bytes. Auto-shrink keeps that under +# _LDS_BUDGET_BYTES per workgroup. +_PREFERRED_BLOCK_T = 128 +_PREFERRED_BLOCK_S = 64 + +# gfx950 reports 160 KB LDS; leave ~10% headroom for compiler-staged +# intermediates (Hk, Hq_i, accumulators). +_LDS_BUDGET_BYTES = 144 * 1024 +_MIN_BLOCK = 16 # Triton requires matmul dims >= 16 + +_FP8_DTYPES = frozenset([ + jnp.dtype("float8_e4m3fn"), + jnp.dtype("float8_e5m2"), + jnp.dtype("float8_e4m3fnuz"), + jnp.dtype("float8_e5m2fnuz"), +]) + + +def _is_fp8_dtype(dt): + return jnp.dtype(dt) in _FP8_DTYPES + + +def _pick_tiles(T_t, T_s, d, dtype): + """Return (BLOCK_T, BLOCK_S) that fit the LDS budget for Q+K staging. + + Halve BLOCK_T first (less critical for inner-loop reuse), then BLOCK_S. 
+ """ + elem_bytes = jnp.dtype(dtype).itemsize + bt = min(_PREFERRED_BLOCK_T, T_t) + bs = min(_PREFERRED_BLOCK_S, T_s) + + def cost(bt, bs): + return (bt + bs) * d * elem_bytes + + while cost(bt, bs) > _LDS_BUDGET_BYTES and bt > _MIN_BLOCK: + bt //= 2 + while cost(bt, bs) > _LDS_BUDGET_BYTES and bs > _MIN_BLOCK: + bs //= 2 + return bt, bs + + +def _estimate_lds_bytes(BLOCK_T, BLOCK_S, d, d_i, dtype): + """Worst-case LDS estimate. The dominant cost is usually K_tile + W_q[i] + slice (Pallas-Triton stages the per-iteration W_q[i] of shape (d, d_i)). + """ + elem_bytes = jnp.dtype(dtype).itemsize + k_tile = BLOCK_S * d * elem_bytes + q_tile = BLOCK_T * d * elem_bytes + w_q_slice = d * d_i * elem_bytes + # The two pairs that have actually been observed empirically: + return max(k_tile + w_q_slice, q_tile + w_q_slice) + + +class PallasIndexerInfeasible(RuntimeError): + """Raised when no valid (BLOCK_T, BLOCK_S) fits the LDS budget for the + given (d, d_i, dtype). The W_q[i] slice (size d*d_i*dtype_bytes) is the + typical culprit; mitigation requires d-tiling the inner matmul.""" + + +def _dot_fp32(a, b): + """`jnp.dot` with the fp32 accumulator made explicit. + + Without `preferred_element_type`, JAX promotion picks the input dtype as + the dot output dtype. For FP8 inputs that means the accumulated dot is + clamped to fp8 max (~448) BEFORE the fp32 cast — so any real workload + silently saturates. Force fp32 accumulation everywhere. + """ + return jax.lax.dot_general( + a, b, (((a.ndim - 1,), (0,)), ((), ())), + preferred_element_type=jnp.float32, + ) + + +def _make_kernel_body(BLOCK_T, BLOCK_S, d, I, d_i, is_fp8): + """Closure that bakes the static shape constants into the kernel body. + + Pallas kernel bodies trace under jit, so values referenced by Python-level + control flow (``range(I)`` etc.) must be static. The simplest way to make + them static is to capture them in a closure here. 
+ + For FP8 inputs, the outer two dots (K@W_k, Q@W_q[i]) consume FP8 directly + via _dot_fp32; their fp32 outputs are downcast to bf16 for the inner + (Hq_i @ Hk^T) matmul. The combined per-tensor dequant scale is applied + to the fp32 accumulator at the very end; ReLU commutes with positive + scaling so this is exact. + """ + inter_dtype = jnp.bfloat16 if is_fp8 else None # None = preserve dtype + + def _indexer_pallas_kernel_body( + Q_ref, # (1, 1, BLOCK_T, d) - one (b, h) slice, T_t-tile + K_ref, # (1, 1, BLOCK_S, d) - one (b, h) slice, T_s-tile + W_q_ref, # (I, d, d_i) - whole tensor, replicated + W_k_ref, # (d, d_i) - whole tensor, replicated + weights_ref, # (1, 1, BLOCK_T, I) - one (b, h) slice, T_t-tile + scale_ref, # (1,) - combined fp32 scale + O_ref, # (1, 1, BLOCK_T, BLOCK_S) - one tile of the output + ): + """ + Compute one (BLOCK_T, BLOCK_S) tile of O for one (b, h). + """ + Q = Q_ref[0, 0] # (BLOCK_T, d) + K = K_ref[0, 0] # (BLOCK_S, d) + Wk = W_k_ref[...] # (d, d_i) + Hk = _dot_fp32(K, Wk) # (BLOCK_S, d_i) + if inter_dtype is not None: + Hk = Hk.astype(inter_dtype) + + acc = jnp.zeros((BLOCK_T, BLOCK_S), dtype=jnp.float32) + for i in range(I): + Wq_i = W_q_ref[i] # (d, d_i) + Hq_i = _dot_fp32(Q, Wq_i) # (BLOCK_T, d_i) + if inter_dtype is not None: + Hq_i = Hq_i.astype(inter_dtype) + Hi = jax.nn.relu(_dot_fp32(Hq_i, Hk.T)) # (BLOCK_T, BLOCK_S) + w_i = weights_ref[0, 0, :, i] # (BLOCK_T,) + acc = acc + Hi * w_i[:, None] + + acc = acc * scale_ref[0] + O_ref[0, 0] = acc.astype(O_ref.dtype) + + return _indexer_pallas_kernel_body + + +def indexer_fused( + Q, K, W_q, W_k, weights, + *, + scale_q=None, scale_k=None, scale_wq=None, scale_wk=None, + out_dtype=None, +): + """Pallas-backed fused indexer. Strict BHSD. 
+ + Args: + Q: (B, H, T_t, d) high-precision (bf16/fp32) or FP8 e4m3 + K: (B, H, T_s, d) must match Q's dtype + W_q: (I, d, d_i) must match Q's dtype + W_k: (d, d_i) must match Q's dtype + weights: (B, H, T_t, I) high-precision regardless of Q dtype + scale_q, scale_k, scale_wq, scale_wk: + per-tensor fp32 dequant scales. Required when Q is FP8. + out_dtype: defaults to Q.dtype for non-FP8, weights.dtype for FP8. + + Returns: + O: (B, H, T_t, T_s) + """ + if Q.ndim != 4 or K.ndim != 4 or weights.ndim != 4: + raise ValueError( + f"indexer_fused (pallas) expects rank-4 BHSD Q, K and weights. Got " + f"Q.shape={Q.shape}, K.shape={K.shape}, weights.shape={weights.shape}. " + "Reshape (or add singleton head/batch axes) before calling the fused path." + ) + + B, H, T_t, d = Q.shape + Bk, Hk, T_s, dk = K.shape + I, d2, d_i = W_q.shape + d3, d_i_k = W_k.shape + Bw, Hw, T_t_w, I_w = weights.shape + + if (Bk, Hk) != (B, H): + raise ValueError(f"(B,H) mismatch: Q has {(B, H)}, K has {(Bk, Hk)}") + if not (d == dk == d2 == d3): + raise ValueError(f"d mismatch across Q/K/W_q/W_k: {d}, {dk}, {d2}, {d3}") + if d_i != d_i_k: + raise ValueError(f"d_i mismatch: W_q has {d_i}, W_k has {d_i_k}") + if (Bw, Hw, T_t_w, I_w) != (B, H, T_t, I): + raise ValueError( + f"weights shape {weights.shape} does not match expected " + f"(B={B}, H={H}, T_t={T_t}, I={I})" + ) + + is_fp8 = _is_fp8_dtype(Q.dtype) + if is_fp8: + for nm, t in (("K", K), ("W_q", W_q), ("W_k", W_k)): + if t.dtype != Q.dtype: + raise ValueError( + f"FP8 mode requires Q/K/W_q/W_k all match dtype; " + f"Q is {Q.dtype} but {nm} is {t.dtype}." + ) + if any(s is None for s in (scale_q, scale_k, scale_wq, scale_wk)): + raise ValueError( + "FP8 mode requires scale_q, scale_k, scale_wq, scale_wk." 
+ ) + scale_combined = jnp.asarray( + jnp.float32(scale_q) * jnp.float32(scale_k) + * jnp.float32(scale_wq) * jnp.float32(scale_wk), + dtype=jnp.float32, + ).reshape((1,)) + if out_dtype is None: + out_dtype = weights.dtype + else: + scale_combined = jnp.asarray(1.0, dtype=jnp.float32).reshape((1,)) + if out_dtype is None: + out_dtype = Q.dtype + + BLOCK_T, BLOCK_S = _pick_tiles(T_t, T_s, d, Q.dtype) + lds = _estimate_lds_bytes(BLOCK_T, BLOCK_S, d, d_i, Q.dtype) + if lds > _LDS_BUDGET_BYTES: + raise PallasIndexerInfeasible( + f"Pallas indexer infeasible for this config: estimated LDS " + f"{lds // 1024} KB > budget {_LDS_BUDGET_BYTES // 1024} KB. " + f"Dominant cost is W_q[i] slice = d*d_i*dtype = " + f"{d * d_i * jnp.dtype(Q.dtype).itemsize // 1024} KB. " + f"Mitigation: d-tile the inner matmul (not implemented). " + f"For this config use the Triton backend instead." + ) + + grid = (B * H, pl.cdiv(T_t, BLOCK_T), pl.cdiv(T_s, BLOCK_S)) + + # BlockSpecs: each input/output is sliced based on (program_id_0, + # program_id_1, program_id_2). index_map returns the *block index* per + # axis (Pallas multiplies by block_shape internally). 
+ def q_idx(bh, tt, ts): return (bh // H, bh % H, tt, 0) + def k_idx(bh, tt, ts): return (bh // H, bh % H, ts, 0) + def wq_idx(bh, tt, ts): return (0, 0, 0) + def wk_idx(bh, tt, ts): return (0, 0) + def weights_idx(bh, tt, ts): return (bh // H, bh % H, tt, 0) + def scale_idx(bh, tt, ts): return (0,) + def o_idx(bh, tt, ts): return (bh // H, bh % H, tt, ts) + + in_specs = [ + pl.BlockSpec(block_shape=(1, 1, BLOCK_T, d), index_map=q_idx), + pl.BlockSpec(block_shape=(1, 1, BLOCK_S, d), index_map=k_idx), + pl.BlockSpec(block_shape=(I, d, d_i), index_map=wq_idx), + pl.BlockSpec(block_shape=(d, d_i), index_map=wk_idx), + pl.BlockSpec(block_shape=(1, 1, BLOCK_T, I), index_map=weights_idx), + pl.BlockSpec(block_shape=(1,), index_map=scale_idx), + ] + out_spec = pl.BlockSpec( + block_shape=(1, 1, BLOCK_T, BLOCK_S), + index_map=o_idx, + ) + out_shape = jax.ShapeDtypeStruct((B, H, T_t, T_s), out_dtype) + + kernel_body = _make_kernel_body(BLOCK_T, BLOCK_S, d, I, d_i, is_fp8) + + return pl.pallas_call( + kernel_body, + grid=grid, + in_specs=in_specs, + out_specs=out_spec, + out_shape=out_shape, + )(Q, K, W_q, W_k, weights, scale_combined) diff --git a/transformer_engine/jax/triton_extensions/__init__.py b/transformer_engine/jax/triton_extensions/__init__.py index d9708fde9..79a4dd733 100644 --- a/transformer_engine/jax/triton_extensions/__init__.py +++ b/transformer_engine/jax/triton_extensions/__init__.py @@ -58,3 +58,8 @@ def lowering(ctx, x, **kwargs): from .utils import * from .permutation import * +from .indexer import ( + indexer_fused_triton, + indexer_fused_topk_triton, + score_reduce_triton, +) diff --git a/transformer_engine/jax/triton_extensions/indexer.py b/transformer_engine/jax/triton_extensions/indexer.py new file mode 100644 index 000000000..22f8de2ec --- /dev/null +++ b/transformer_engine/jax/triton_extensions/indexer.py @@ -0,0 +1,859 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. 
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Raw-Triton low-rank indexer kernel + JAX primitive. + +Math (matches the reference in transformer_engine.jax.indexer): + + C_q = Q @ W_dq # (..., T, d_c) + H_q = einsum("...tc,hci->...thi", C_q, W_uq) # (..., T, H, d_i) + H_k = K @ W_k # (..., S, d_i) + H = relu(einsum("...thi,...si->...ths", H_q, H_k)) # (..., T, H, S) + O = einsum("...ths,...ht->...ts", H, weights) # (..., T, S) + +Q is the hidden state (rank-4 BHSD: B × outer-H × T × d). W_dq is a low-rank +down-projection (d → d_c) and W_uq is the per-(indexer-head) up-projection. +The kernel loops over indexer heads internally; the outer (B, outer-H) dims +are flattened into the grid's first axis. + +FP8 mode: Q / K / W_uq / W_dq / W_k are all FP8 e4m3 (same dtype). The +five per-tensor scales (scale_q, scale_k, scale_wq, scale_wd, scale_wk) +fold into a single fp32 scalar applied at the end (ReLU is scale-invariant +under positive scaling). Three intermediate amax-based re-quantizations +(Cq, Hk, Hq per-head) keep the inner matmuls in fp8 too. +""" + +import functools + +import jax +import jax.numpy as jnp +import triton +import triton.language as tl + +from jax import core +from jax.extend import core as extend_core +from jax.interpreters import mlir, xla + +from .utils import triton_call_lowering + + +# Autotune sweep: BLOCK_T × BLOCK_S × num_warps × num_stages. Profiling +# showed num_warps=4 with the prior default (BLOCK_T=128) saturated VGPR +# (256/thread), forcing 1 wave/SIMD; smaller tiles or num_warps=8 cut VGPR +# in half and gave a 6× speedup at the d=512 fp8 config. Each config below +# launches at its own grid (cdiv(T_t, BLOCK_T) × cdiv(T_s, BLOCK_S)) — the +# triton_call_lowering helper supports per-config grids via a callable +# `grid` argument. 
+def _autotune_configs(): + configs = [] + for block_t in (16, 32): + for block_s in (16, 32): + for block_d in (16, 32): + for num_warps in (4, 8): + for num_stages in (1, 2): + configs.append(triton.Config( + {"BLOCK_T": block_t, "BLOCK_S": block_s, + "BLOCK_D": block_d}, + num_warps=num_warps, num_stages=num_stages, + )) + return configs + +_AUTOTUNE_CONFIGS = _autotune_configs() +# Re-run the benchmark when any of these constexprs change. T_t/T_s only +# affect grid size; the optimal config is dominated by per-CTA shape and +# the precision (IS_FP8). +_AUTOTUNE_KEY = ["IS_FP8", "d", "d_c", "H", "d_i"] + + +# Max representable value of FP8 e4m3fn (used for per-tile inter-quantization; NOTE(review): the other accepted FP8 dtypes have different maxima). +# Triton requires module-level constants referenced inside @jit kernels to be +# wrapped in tl.constexpr explicitly. +_FP8_E4M3_MAX = tl.constexpr(448.0) +# Floor on per-tile amax to avoid divide-by-zero when a tile is all-zero. +_FP8_AMAX_EPS = tl.constexpr(1e-30) + + +@triton.autotune(configs=_AUTOTUNE_CONFIGS, key=_AUTOTUNE_KEY) +@triton.jit +def _indexer_kernel( + Q_ptr, + K_ptr, + W_uq_ptr, # (H, d_c, d_i) - replicated across (B, oH) + W_dq_ptr, # (d, d_c) - replicated; same dtype as Q + W_k_ptr, # (d, d_i) - replicated + weights_ptr, + scale_ptr, # 0-D fp32 tensor: combined scale sq*sk*swq*swd*swk (1.0 if non-FP8) + O_ptr, + B: tl.constexpr, + oH: tl.constexpr, + T_t: tl.constexpr, + T_s: tl.constexpr, + d: tl.constexpr, + d_c: tl.constexpr, + H: tl.constexpr, + d_i: tl.constexpr, + BLOCK_T: tl.constexpr, + BLOCK_S: tl.constexpr, + BLOCK_D: tl.constexpr, + IS_FP8: tl.constexpr, +): + """Compute one (BLOCK_T, BLOCK_S) tile of O for one (b, h_outer) slice. 
+ + Grid: (B * oH, cdiv(T_t, BLOCK_T), cdiv(T_s, BLOCK_S)) + + Pipeline: + C_q = Q @ W_dq (down-projection, d-tiled) + Hk = K @ W_k (key projection, d-tiled) + for h in range(H): (loop over indexer heads) + Hq = C_q @ W_uq[h, :, :] (per-head up-projection) + Hi = relu(Hq @ Hk^T) (per-head score) + acc += Hi * weights[h, :, None] (weighted accumulate) + + The two d-contracting GEMMs (Q@W_dq and K@W_k) are tiled along d in + chunks of BLOCK_D. This keeps the W_dq / W_k tiles loaded into LDS at + BLOCK_D × {d_c, d_i} instead of d × {d_c, d_i}, freeing registers / + LDS for the inner per-head loop. + + FP8 mode (IS_FP8=True): all five matrices share the fp8 dtype. Every + MFMA is native fp8: the d-tiled Q@W_dq and K@W_k dots, then the inner + C_q@W_uq[h] and Hq@Hk^T dots after per-tile amax re-quantization of + Cq/Hk/Hq. The per-tile amax scales fold into the accumulator (Hq inside + the loop, Cq/Hk after) along with the user's combined per-tensor scale. + """ + pid_bh = tl.program_id(0) + pid_t = tl.program_id(1) + pid_s = tl.program_id(2) + + b = pid_bh // oH + h_outer = pid_bh % oH + + rt = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) + rs = pid_s * BLOCK_S + tl.arange(0, BLOCK_S) + rdc = tl.arange(0, d_c) + rdi = tl.arange(0, d_i) + + rt_mask = rt < T_t + rs_mask = rs < T_s + + in_dtype = Q_ptr.dtype.element_ty + q_base = b * (oH * T_t * d) + h_outer * (T_t * d) + k_base = b * (oH * T_s * d) + h_outer * (T_s * d) + + # d-tiled accumulators for Q @ W_dq → (BLOCK_T, d_c) and K @ W_k → + # (BLOCK_S, d_i). fp32 accumulators; quantization happens after the loop. + # Requires d % BLOCK_D == 0 (the chunk loads below are not masked along d). 
+ Cq_dot = tl.zeros((BLOCK_T, d_c), dtype=tl.float32) + Hk_dot = tl.zeros((BLOCK_S, d_i), dtype=tl.float32) + for d_off in range(0, d, BLOCK_D): + rd = d_off + tl.arange(0, BLOCK_D) + + q_ptrs = Q_ptr + q_base + rt[:, None] * d + rd[None, :] + Q_chunk = tl.load(q_ptrs, mask=rt_mask[:, None], other=0.0) + k_ptrs = K_ptr + k_base + rs[:, None] * d + rd[None, :] + K_chunk = tl.load(k_ptrs, mask=rs_mask[:, None], other=0.0) + wdq_ptrs = W_dq_ptr + rd[:, None] * d_c + rdc[None, :] + Wdq_chunk = tl.load(wdq_ptrs) + wk_ptrs = W_k_ptr + rd[:, None] * d_i + rdi[None, :] + Wk_chunk = tl.load(wk_ptrs) + + Cq_dot = tl.dot(Q_chunk, Wdq_chunk, acc=Cq_dot) + Hk_dot = tl.dot(K_chunk, Wk_chunk, acc=Hk_dot) + + # Quantize Cq and Hk for the inner up-projection. + if IS_FP8: + Cq_amax = tl.maximum(tl.max(tl.abs(Cq_dot)), _FP8_AMAX_EPS) + Cq_inter = Cq_amax / _FP8_E4M3_MAX + C_q = (Cq_dot / Cq_inter).to(in_dtype) + Hk_amax = tl.maximum(tl.max(tl.abs(Hk_dot)), _FP8_AMAX_EPS) + Hk_inter = Hk_amax / _FP8_E4M3_MAX + Hk_T = tl.trans((Hk_dot / Hk_inter).to(in_dtype)) + else: + C_q = Cq_dot.to(in_dtype) + Hk_T = tl.trans(Hk_dot.to(in_dtype)) + Cq_inter = 1.0 + Hk_inter = 1.0 + + acc = tl.zeros((BLOCK_T, BLOCK_S), dtype=tl.float32) + + w_base = b * (oH * H * T_t) + h_outer * (H * T_t) + for h_idx in range(H): + # W_uq[h_idx, :, :] is a contiguous (d_c, d_i) block of W_uq (H, d_c, d_i). + wuq_ptrs = W_uq_ptr + h_idx * (d_c * d_i) + rdc[:, None] * d_i + rdi[None, :] + Wuq_h = tl.load(wuq_ptrs) + + # Hq = C_q @ W_uq[h_idx]: (BLOCK_T, d_i) + Hq_dot = tl.dot(C_q, Wuq_h) + if IS_FP8: + Hq_amax = tl.maximum(tl.max(tl.abs(Hq_dot)), _FP8_AMAX_EPS) + Hq_inter = Hq_amax / _FP8_E4M3_MAX + Hq_h = (Hq_dot / Hq_inter).to(in_dtype) + else: + Hq_h = Hq_dot.to(in_dtype) + Hq_inter = 1.0 + + # Hi = relu(Hq @ Hk^T): (BLOCK_T, BLOCK_S). FP8 MFMA in FP8 mode. + Hi_raw = tl.dot(Hq_h, Hk_T) + Hi = tl.maximum(Hi_raw, 0.0) + + # weights[b, h_outer, h_idx, t]: contiguous BLOCK_T-vector. 
+ w_ptrs = weights_ptr + w_base + h_idx * T_t + rt + w_i = tl.load(w_ptrs, mask=rt_mask, other=0.0) + + if IS_FP8: + acc += Hi * (Hq_inter * w_i)[:, None] + else: + acc += Hi * w_i[:, None] + + # Apply combined per-tensor scale + carried-out intermediate scales. + scale = tl.load(scale_ptr) + if IS_FP8: + acc = acc * (scale * Cq_inter * Hk_inter) + else: + acc = acc * scale + + # Store O tile: (BLOCK_T, BLOCK_S). O has shape (B, oH, T, S). + o_base = b * (oH * T_t * T_s) + h_outer * (T_t * T_s) + o_ptrs = O_ptr + o_base + rt[:, None] * T_s + rs[None, :] + tl.store(o_ptrs, acc.to(O_ptr.dtype.element_ty), + mask=rt_mask[:, None] & rs_mask[None, :]) + + +# --- JAX primitive --------------------------------------------------------------- + +_indexer_p = extend_core.Primitive("te_indexer_triton") +_indexer_p.multiple_results = True + + +_FP8_DTYPES = frozenset([ + jnp.dtype("float8_e4m3fn"), + jnp.dtype("float8_e5m2"), + jnp.dtype("float8_e4m3fnuz"), + jnp.dtype("float8_e5m2fnuz"), +]) + + +def _is_fp8_dtype(dt): + return jnp.dtype(dt) in _FP8_DTYPES + + +@_indexer_p.def_abstract_eval +def _indexer_abstract(Q, K, W_uq, W_dq, W_k, weights, scale, *, out_dtype): + del W_uq, W_dq, W_k, weights, scale + B, oH, T_t, _ = Q.shape + _, _, T_s, _ = K.shape + return [core.ShapedArray((B, oH, T_t, T_s), out_dtype)] + + +_indexer_p.def_impl(functools.partial(xla.apply_primitive, _indexer_p)) + + +def _indexer_lowering(ctx, Q, K, W_uq, W_dq, W_k, weights, scale, *, out_dtype): + del out_dtype # baked into the output aval + Q_aval = ctx.avals_in[0] + K_aval = ctx.avals_in[1] + W_uq_aval = ctx.avals_in[2] + B, oH, T_t, d = Q_aval.shape + T_s = K_aval.shape[2] + H, d_c, d_i = W_uq_aval.shape + + is_fp8 = _is_fp8_dtype(Q_aval.dtype) + + # Per-config grid: BLOCK_T/BLOCK_S come from the autotuned config kwargs + # (or fall back to a sensible default if autotune is not active). 
+ def grid_fn(merged_kwargs): + bt = merged_kwargs.get("BLOCK_T", 128) + bs = merged_kwargs.get("BLOCK_S", 64) + return (B * oH, triton.cdiv(T_t, bt), triton.cdiv(T_s, bs)) + + return triton_call_lowering( + ctx, + _indexer_kernel, + Q, + K, + W_uq, + W_dq, + W_k, + weights, + scale, + grid=grid_fn, + num_warps=4, + num_stages=1, + constexprs={ + "B": B, + "oH": oH, + "T_t": T_t, + "T_s": T_s, + "d": d, + "d_c": d_c, + "H": H, + "d_i": d_i, + "IS_FP8": is_fp8, + }, + ) + + +mlir.register_lowering(_indexer_p, _indexer_lowering, platform="rocm") +mlir.register_lowering(_indexer_p, _indexer_lowering, platform="cuda") + + +def indexer_fused_triton( + Q, + K, + W_uq, + W_dq, + W_k, + weights, + *, + scale_q=None, + scale_k=None, + scale_wq=None, + scale_wd=None, + scale_wk=None, + out_dtype=None, +): + """Raw-Triton low-rank indexer (BHSD). + + Args: + Q: (B, oH, T, d) high-precision (bf16/fp32) or FP8 e4m3 + K: (B, oH, S, d) must match Q's dtype + W_uq: (H, d_c, d_i) up-projection; must match Q's dtype + W_dq: (d, d_c) down-projection; must match Q's dtype + W_k: (d, d_i) key projection; must match Q's dtype + weights: (B, oH, H, T) high-precision regardless of Q dtype + scale_q, scale_k, scale_wq, scale_wd, scale_wk: + per-tensor fp32 dequant scales. All five required when Q is FP8. + out_dtype: dtype of the output O. Defaults to Q.dtype for non-FP8 and + weights.dtype (typically bf16) for FP8. + + BLOCK_T / BLOCK_S / BLOCK_D / num_warps / num_stages are autotuned at + first invocation per (IS_FP8, d, d_c, H, d_i) key. + + Returns: + O of shape (B, oH, T, S) + """ + if Q.ndim != 4 or K.ndim != 4 or weights.ndim != 4: + raise ValueError( + "indexer_fused_triton expects rank-4 BHSD Q, K, weights. Got " + f"Q.shape={Q.shape}, K.shape={K.shape}, weights.shape={weights.shape}." 
+ ) + B, oH, T_t, d = Q.shape + Bk, oHk, T_s, dk = K.shape + H, d_c_uq, d_i = W_uq.shape + d_dq, d_c_dq = W_dq.shape + d_wk, d_i_wk = W_k.shape + Bw, oHw, Hw, T_w = weights.shape + if (Bk, oHk) != (B, oH): + raise ValueError(f"(B,oH) mismatch: Q has {(B, oH)}, K has {(Bk, oHk)}") + if not (d == dk == d_dq == d_wk): + raise ValueError(f"d mismatch across Q/K/W_dq/W_k: {d}, {dk}, {d_dq}, {d_wk}") + if d_c_uq != d_c_dq: + raise ValueError(f"d_c mismatch: W_uq has {d_c_uq}, W_dq has {d_c_dq}") + if d_i != d_i_wk: + raise ValueError(f"d_i mismatch: W_uq has {d_i}, W_k has {d_i_wk}") + if (Bw, oHw, Hw, T_w) != (B, oH, H, T_t): + raise ValueError( + f"weights shape {weights.shape} does not match expected " + f"(B={B}, oH={oH}, H={H}, T={T_t})" + ) + + is_fp8 = _is_fp8_dtype(Q.dtype) + if is_fp8: + for nm, t in (("K", K), ("W_uq", W_uq), ("W_dq", W_dq), ("W_k", W_k)): + if t.dtype != Q.dtype: + raise ValueError( + f"FP8 mode requires Q/K/W_uq/W_dq/W_k all match dtype; " + f"Q is {Q.dtype} but {nm} is {t.dtype}." + ) + scales = (scale_q, scale_k, scale_wq, scale_wd, scale_wk) + if any(s is None for s in scales): + raise ValueError( + "FP8 mode requires scale_q, scale_k, scale_wq, scale_wd, scale_wk." + ) + scale_combined = jnp.asarray( + jnp.float32(scale_q) * jnp.float32(scale_k) + * jnp.float32(scale_wq) * jnp.float32(scale_wd) + * jnp.float32(scale_wk), + dtype=jnp.float32, + ) + if out_dtype is None: + out_dtype = weights.dtype + else: + scale_combined = jnp.asarray(1.0, dtype=jnp.float32) + if out_dtype is None: + out_dtype = Q.dtype + + return _indexer_p.bind( + Q, + K, + W_uq, + W_dq, + W_k, + weights, + scale_combined, + out_dtype=jnp.dtype(out_dtype), + )[0] + + +# --- Score+ReLU+H-reduce fused kernel (hybrid backend) ------------------------- +# +# Inputs are *already projected*: Hq, Hk, W_o all come from upstream einsum +# calls (hipBLASLt). 
This kernel does only the score matmul, the relu, and the +# per-token-per-head weighted sum over H — the pieces that have no efficient +# einsum/HLO equivalent because they'd require materializing the (B, oH, T, H, S) +# pre-relu score tensor in HBM. By fusing them in registers, we eliminate that +# round-trip entirely. + +def _score_reduce_autotune_configs(): + # The kernel is dominated by Hq reads (one (BLOCK_T, d_i) load per H + # iteration). Bigger BLOCK_T ⇒ fewer T tiles ⇒ less total Hq traffic. + # Bigger BLOCK_S ⇒ more Hk reuse but bigger per-CTA footprint. + # + # BLOCK_T=512 was tried and consistently failed to launch on MI355X + # (resource exhaustion — VGPR/LDS budget for 64-iter H-loop with that + # large an accumulator). Capped at 256. + cfgs = [] + for bt in (64, 128, 256): + for bs in (32, 64, 128): + for num_warps in (4, 8): + for num_stages in (1, 2): + cfgs.append(triton.Config( + {"BLOCK_T": bt, "BLOCK_S": bs}, + num_warps=num_warps, num_stages=num_stages, + )) + # A few skinny / fat shapes the regular grid above won't hit. + cfgs += [ + triton.Config({"BLOCK_T": 32, "BLOCK_S": 128}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_T": 32, "BLOCK_S": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_T": 256, "BLOCK_S": 32}, num_warps=8, num_stages=2), + ] + return cfgs + + +@triton.autotune(configs=_score_reduce_autotune_configs(), key=["H", "d_i"]) +@triton.jit +def _score_reduce_kernel( + Hq_ptr, # (B, oH, T_t, H, d_i) — produced by einsum("...tc,hci->...thi") + Hk_ptr, # (B, oH, T_s, d_i) + W_o_ptr, # (B, oH, T_t, H) + O_ptr, # (B, oH, T_t, T_s) + B: tl.constexpr, + oH: tl.constexpr, + T_t: tl.constexpr, + T_s: tl.constexpr, + H: tl.constexpr, + d_i: tl.constexpr, + BLOCK_T: tl.constexpr, + BLOCK_S: tl.constexpr, +): + """Compute one (BLOCK_T, BLOCK_S) tile of O for one (b, h_outer) slice. + + Grid order: (cdiv(T_s, BLOCK_S), cdiv(T_t, BLOCK_T), B * oH). 
+ + S is the fastest-dispatching axis so consecutive CTAs share (B*oH, T) + and vary only in S — they all read the same per-head Hq slab, hitting + L2 instead of HBM. Hq layout is the natural einsum output + (..., T, H, d_i); per-head loads are strided in T (stride H*d_i). + """ + pid_s = tl.program_id(0) + pid_t = tl.program_id(1) + pid_bh = tl.program_id(2) + + b = pid_bh // oH + h_outer = pid_bh % oH + + rt = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) + rs = pid_s * BLOCK_S + tl.arange(0, BLOCK_S) + rdi = tl.arange(0, d_i) + + rt_mask = rt < T_t + rs_mask = rs < T_s + + hq_base = b * (oH * T_t * H * d_i) + h_outer * (T_t * H * d_i) + hk_base = b * (oH * T_s * d_i) + h_outer * (T_s * d_i) + wo_base = b * (oH * T_t * H) + h_outer * (T_t * H) + o_base = b * (oH * T_t * T_s) + h_outer * (T_t * T_s) + + # Load the (BLOCK_S, d_i) Hk slab once — it is loop-invariant over H. + hk_ptrs = Hk_ptr + hk_base + rs[:, None] * d_i + rdi[None, :] + Hk_tile = tl.load(hk_ptrs, mask=rs_mask[:, None], other=0.0) + Hk_T = tl.trans(Hk_tile) # (d_i, BLOCK_S) + + acc = tl.zeros((BLOCK_T, BLOCK_S), dtype=tl.float32) + + for h in range(H): + hq_ptrs = (Hq_ptr + hq_base + + rt[:, None] * (H * d_i) + h * d_i + rdi[None, :]) + Hq_h = tl.load(hq_ptrs, mask=rt_mask[:, None], other=0.0) + + wo_ptrs = W_o_ptr + wo_base + rt * H + h + w_h = tl.load(wo_ptrs, mask=rt_mask, other=0.0) + + score = tl.dot(Hq_h, Hk_T) + score = tl.maximum(score, 0.0) + acc += score * w_h[:, None].to(tl.float32) + + o_ptrs = O_ptr + o_base + rt[:, None] * T_s + rs[None, :] + tl.store(o_ptrs, acc.to(O_ptr.dtype.element_ty), + mask=rt_mask[:, None] & rs_mask[None, :]) + + +_score_reduce_p = extend_core.Primitive("te_indexer_score_reduce_triton") +_score_reduce_p.multiple_results = True + + +@_score_reduce_p.def_abstract_eval +def _score_reduce_abstract(Hq, Hk, W_o, *, out_dtype): + del W_o + # Hq layout: (B, oH, T_t, H, d_i) + B, oH, T_t, _H, _d_i = Hq.shape + T_s = Hk.shape[2] + return [core.ShapedArray((B, oH, T_t, T_s), 
out_dtype)] + + +_score_reduce_p.def_impl(functools.partial(xla.apply_primitive, _score_reduce_p)) + + +def _score_reduce_lowering(ctx, Hq, Hk, W_o, *, out_dtype): + del out_dtype + Hq_aval = ctx.avals_in[0] + Hk_aval = ctx.avals_in[1] + B, oH, T_t, H, d_i = Hq_aval.shape + T_s = Hk_aval.shape[2] + + def grid_fn(merged_kwargs): + bt = merged_kwargs.get("BLOCK_T", 64) + bs = merged_kwargs.get("BLOCK_S", 64) + # S as grid_x (fastest-dispatching) so per-(B*oH, T-tile) S workgroups + # cluster in time and hit L2 on the shared Hq slab. + return (triton.cdiv(T_s, bs), triton.cdiv(T_t, bt), B * oH) + + return triton_call_lowering( + ctx, + _score_reduce_kernel, + Hq, Hk, W_o, + grid=grid_fn, + num_warps=4, + num_stages=2, + constexprs={ + "B": B, + "oH": oH, + "T_t": T_t, + "T_s": T_s, + "H": H, + "d_i": d_i, + }, + ) + + +mlir.register_lowering(_score_reduce_p, _score_reduce_lowering, platform="rocm") +mlir.register_lowering(_score_reduce_p, _score_reduce_lowering, platform="cuda") + + +def score_reduce_triton(Hq, Hk, W_o, *, out_dtype=None): + """Triton fused score-matmul + relu + per-(t, h) weighted H-reduction. + + Replaces the pattern: + + scores = relu(jnp.einsum("...thi,...si->...ths", Hq, Hk)) # never write + O = jnp.einsum("...ths,...th->...ts", scores, W_o) + + with a single kernel that holds the per-head score tile in registers, + avoiding the (B, oH, T, H, S) HBM round-trip that an einsum+XLA chain + pays (the dominant cost in profile_indexer's einsum baseline). + + Args: + Hq: (B, oH, T_t, H, d_i) + Hk: (B, oH, T_s, d_i) + W_o: (B, oH, T_t, H) + out_dtype: defaults to Hq.dtype. 
+ + Returns: + O: (B, oH, T_t, T_s) + """ + if Hq.ndim != 5: + raise ValueError( + f"Hq must be rank-5 (B, oH, T_t, H, d_i); got shape {Hq.shape}" + ) + if Hk.ndim != 4: + raise ValueError( + f"Hk must be rank-4 (B, oH, T_s, d_i); got shape {Hk.shape}" + ) + if W_o.ndim != 4: + raise ValueError( + f"W_o must be rank-4 (B, oH, T_t, H); got shape {W_o.shape}" + ) + + B, oH, T_t, H, d_i = Hq.shape + Bk, oHk, T_s, d_i_k = Hk.shape + Bw, oHw, T_t_w, H_w = W_o.shape + if (Bk, oHk) != (B, oH): + raise ValueError( + f"(B, oH) mismatch: Hq has {(B, oH)}, Hk has {(Bk, oHk)}" + ) + if d_i != d_i_k: + raise ValueError(f"d_i mismatch: Hq has {d_i}, Hk has {d_i_k}") + if (Bw, oHw, T_t_w, H_w) != (B, oH, T_t, H): + raise ValueError( + f"W_o shape {W_o.shape} does not match expected " + f"(B={B}, oH={oH}, T_t={T_t}, H={H})" + ) + + if out_dtype is None: + out_dtype = Hq.dtype + + return _score_reduce_p.bind( + Hq, Hk, W_o, out_dtype=jnp.dtype(out_dtype) + )[0] + + +# --- Top-K fused variant ------------------------------------------------------- +# +# FlashAttention-style: each CTA owns one (b, h, t-tile) and serializes over s +# tiles, maintaining a running per-row top-k of the score matrix. Output is +# (B, H, T_t, k) values + (B, H, T_t, k) int32 indices — never materializes the +# full (T_t, T_s) score tensor. +# +# Top-k merge: pack (val_key << 32) | idx_u32 into uint64, build (BLOCK_T, +# k+BLOCK_S) via gather+where (tl.cat is 1D-only on this Triton), then reduce +# back to k entries with tl.topk. Constraints: k pow2, k+block_s pow2. +# +# Accumulated scores may be negative, so val_key is the monotonic uint32 +# encoding of the fp32 score rather than its raw bit pattern. Init sentinel: +# (key=0, idx=0xFFFFFFFF) — it packs below every non-NaN encoded score. 
+ +_DEFAULT_K = 64 + + +@triton.jit +def _indexer_topk_kernel( + Q_ptr, + K_ptr, + W_q_ptr, + W_k_ptr, + weights_ptr, + O_v_ptr, + O_i_ptr, + B: tl.constexpr, + H: tl.constexpr, + T_t: tl.constexpr, + T_s: tl.constexpr, + d: tl.constexpr, + I: tl.constexpr, + d_i: tl.constexpr, + K_TOPK: tl.constexpr, + BLOCK_T: tl.constexpr, + BLOCK_S: tl.constexpr, + KS_SUM: tl.constexpr, +): + pid_bh = tl.program_id(0) + pid_t = tl.program_id(1) + b = pid_bh // H + h = pid_bh % H + + rt = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) + rd = tl.arange(0, d) + rdi = tl.arange(0, d_i) + rt_mask = rt < T_t + + q_base = b * (H * T_t * d) + h * (T_t * d) + Q = tl.load(Q_ptr + q_base + rt[:, None] * d + rd[None, :], + mask=rt_mask[:, None], other=0.0) + + Wk = tl.load(W_k_ptr + rd[:, None] * d_i + rdi[None, :]) + + running_pack = tl.full((BLOCK_T, K_TOPK), 0xFFFFFFFF, dtype=tl.uint64) + + w_base = b * (H * T_t * I) + h * (T_t * I) + k_base = b * (H * T_s * d) + h * (T_s * d) + + n_s_tiles = T_s // BLOCK_S + for s_idx in range(n_s_tiles): + s_off = s_idx * BLOCK_S + rs = s_off + tl.arange(0, BLOCK_S) + + Kt = tl.load(K_ptr + k_base + rs[:, None] * d + rd[None, :]) + Hk = tl.dot(Kt, Wk).to(Q.dtype) + + acc = tl.zeros((BLOCK_T, BLOCK_S), dtype=tl.float32) + for i in range(I): + Wq_i = tl.load(W_q_ptr + i * (d * d_i) + rd[:, None] * d_i + rdi[None, :]) + Hq_i = tl.dot(Q, Wq_i).to(Q.dtype) + Hi_raw = tl.dot(Hq_i, tl.trans(Hk)) + Hi = tl.maximum(Hi_raw, 0.0) + w_i = tl.load(weights_ptr + w_base + rt * I + i, mask=rt_mask, other=0.0) + acc += Hi * w_i[:, None] + + # Encode fp32 -> monotonic uint32 (radix-sort fp32 trick) so negative + # acc values sort below positive ones. 
+ acc_bits = acc.to(tl.uint32, bitcast=True) + acc_sext = (acc.to(tl.int32, bitcast=True) >> 31).to(tl.uint32) + enc_mask = acc_sext | tl.cast(0x80000000, tl.uint32) + acc_key = acc_bits ^ enc_mask # (BLOCK_T, BLOCK_S) u32 + tile_idx = rs.to(tl.uint32) + tile_v_u = acc_key.to(tl.uint64) + tile_i_u = tile_idx.to(tl.uint64) + tile_pack = (tile_v_u << 32) | tile_i_u[None, :].broadcast_to((BLOCK_T, BLOCK_S)) + + pos = tl.arange(0, KS_SUM) + r_idx = tl.minimum(pos, K_TOPK - 1) + t_idx = tl.maximum(pos.to(tl.int32) - K_TOPK, 0).to(tl.int32) + r_ext = tl.gather(running_pack, r_idx[None, :].broadcast_to((BLOCK_T, KS_SUM)), axis=1) + t_ext = tl.gather(tile_pack, t_idx[None, :].broadcast_to((BLOCK_T, KS_SUM)), axis=1) + combined = tl.where((pos < K_TOPK)[None, :].broadcast_to((BLOCK_T, KS_SUM)), + r_ext, t_ext) + + running_pack = tl.topk(combined, K_TOPK, dim=1) + + # Decode monotonic uint32 key -> fp32 bits. + out_key = (running_pack >> 32).to(tl.uint32) + out_key_sext = (~out_key.to(tl.int32, bitcast=True) >> 31).to(tl.uint32) + dec_mask = out_key_sext | tl.cast(0x80000000, tl.uint32) + out_bits = out_key ^ dec_mask + out_vals_fp32 = out_bits.to(tl.float32, bitcast=True) + out_idxs = (running_pack & 0xFFFFFFFF).to(tl.uint32).to(tl.int32) + + rk = tl.arange(0, K_TOPK) + o_base = b * (H * T_t * K_TOPK) + h * (T_t * K_TOPK) + tl.store(O_v_ptr + o_base + rt[:, None] * K_TOPK + rk[None, :], + out_vals_fp32.to(O_v_ptr.dtype.element_ty), + mask=rt_mask[:, None]) + tl.store(O_i_ptr + o_base + rt[:, None] * K_TOPK + rk[None, :], + out_idxs, mask=rt_mask[:, None]) + + +_indexer_topk_p = extend_core.Primitive("te_indexer_topk_triton") +_indexer_topk_p.multiple_results = True + + +@_indexer_topk_p.def_abstract_eval +def _indexer_topk_abstract(Q, K, W_q, W_k, weights, *, + k, block_t, block_s, num_warps, num_stages): + del W_q, W_k, weights, block_t, block_s, num_warps, num_stages + B, H, T_t, _ = Q.shape + return [ + core.ShapedArray((B, H, T_t, k), Q.dtype), + core.ShapedArray((B, H, 
T_t, k), jnp.int32), + ] + + +_indexer_topk_p.def_impl(functools.partial(xla.apply_primitive, _indexer_topk_p)) + + +def _indexer_topk_lowering(ctx, Q, K, W_q, W_k, weights, *, + k, block_t, block_s, num_warps, num_stages): + Q_aval = ctx.avals_in[0] + K_aval = ctx.avals_in[1] + W_q_aval = ctx.avals_in[2] + B, H, T_t, d = Q_aval.shape + T_s = K_aval.shape[2] + I, _, d_i = W_q_aval.shape + + grid = (B * H, triton.cdiv(T_t, block_t)) + + return triton_call_lowering( + ctx, + _indexer_topk_kernel, + Q, + K, + W_q, + W_k, + weights, + grid=grid, + num_warps=num_warps, + num_stages=num_stages, + constexprs={ + "B": B, + "H": H, + "T_t": T_t, + "T_s": T_s, + "d": d, + "I": I, + "d_i": d_i, + "K_TOPK": k, + "BLOCK_T": block_t, + "BLOCK_S": block_s, + "KS_SUM": k + block_s, + }, + ) + + +mlir.register_lowering(_indexer_topk_p, _indexer_topk_lowering, platform="rocm") +mlir.register_lowering(_indexer_topk_p, _indexer_topk_lowering, platform="cuda") + + +def _is_pow2(n): + return n > 0 and (n & (n - 1)) == 0 + + +def indexer_fused_topk_triton( + Q, + K, + W_q, + W_k, + weights, + *, + k: int = _DEFAULT_K, + block_t: int = 128, + block_s: int = 64, + num_warps: int = 4, + num_stages: int = 1, +): + """Fused indexer + per-row top-k along T_s. Returns (vals, idxs). + + vals: (B, H, T_t, k) Q.dtype — descending top-k post-ReLU scores + idxs: (B, H, T_t, k) int32 — corresponding s positions in [0, T_s) + + Constraints: + * Q, K, weights are rank-4 BHSD. + * T_s % block_s == 0 (no masking inside inner loop). + * k and (k + block_s) are powers of 2 (tl.sort and tl.arange). + """ + if Q.ndim != 4 or K.ndim != 4 or weights.ndim != 4: + raise ValueError( + "indexer_fused_topk_triton expects rank-4 BHSD Q, K, weights. Got " + f"Q.shape={Q.shape}, K.shape={K.shape}, weights.shape={weights.shape}." 
+ ) + B, H, T_t, d = Q.shape + Bk, Hk, T_s, dk = K.shape + I, d2, d_i = W_q.shape + d3, d_i_k = W_k.shape + Bw, Hw, T_t_w, I_w = weights.shape + if (Bk, Hk) != (B, H): + raise ValueError(f"(B,H) mismatch: Q has {(B, H)}, K has {(Bk, Hk)}") + if not (d == dk == d2 == d3): + raise ValueError(f"d mismatch across Q/K/W_q/W_k: {d}, {dk}, {d2}, {d3}") + if d_i != d_i_k: + raise ValueError(f"d_i mismatch: W_q has {d_i}, W_k has {d_i_k}") + if (Bw, Hw, T_t_w, I_w) != (B, H, T_t, I): + raise ValueError( + f"weights shape {weights.shape} does not match expected " + f"(B={B}, H={H}, T_t={T_t}, I={I})" + ) + if k > T_s: + raise ValueError(f"k={k} exceeds T_s={T_s}") + + block_t = min(block_t, T_t) + block_s = min(block_s, T_s) + + if T_s % block_s != 0: + raise ValueError( + f"T_s={T_s} must be divisible by block_s={block_s} (kernel doesn't " + "mask invalid s positions in the inner loop)." + ) + if not _is_pow2(k): + raise ValueError(f"k={k} must be a power of 2 (tl.arange requirement)") + if not _is_pow2(k + block_s): + raise ValueError( + f"k + block_s = {k + block_s} must be a power of 2 " + f"(k={k}, block_s={block_s})" + ) + + return _indexer_topk_p.bind( + Q, K, W_q, W_k, weights, + k=k, + block_t=block_t, + block_s=block_s, + num_warps=num_warps, + num_stages=num_stages, + ) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 6ea4092cb..219b286b6 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -32,6 +32,7 @@ import hashlib import os +import tempfile import warnings from typing import Any, Callable, Mapping import zlib @@ -151,7 +152,6 @@ def _check_triton_compatibility(): try: from jax._src.lib import gpu_triton from triton.compiler import compiler as tc - from triton.backends.nvidia import compiler as cb from triton.runtime import autotuner except ImportError as e: raise ImportError( @@ -161,6 +161,24 @@ def 
_check_triton_compatibility(): ) from e +# Detect target platform once at import time. AMD/HIP returns an arch string +# like "gfx950:sramecc+:xnack-"; NVIDIA returns something else (or this call +# falls through to the CUDA path). +try: + _ARCH_DETAILS = gpu_triton.get_arch_details(0) +except Exception: # noqa: BLE001 + _ARCH_DETAILS = "" +_IS_HIP = _ARCH_DETAILS.startswith("gfx") + +# Lazy backend imports — only pull in what the active platform needs so that +# AMD-only or NVIDIA-only environments don't fail at module load. +if _IS_HIP: + from triton.backends.amd import compiler as cb_hip # noqa: E402 + from triton.backends.compiler import GPUTarget as _TritonGPUTarget # noqa: E402 +else: + from triton.backends.nvidia import compiler as cb # noqa: E402 + + __all__ = ["triton_call_lowering", "get_triton_info"] # Triton kernel cache (module-level, shared across all kernels) @@ -212,6 +230,9 @@ def get_triton_dtype(aval): jnp.dtype("float16"): "fp16", jnp.dtype("float8_e4m3fn"): "fp8e4nv", jnp.dtype("float8_e5m2"): "fp8e5", + # AMD MI300 (gfx942) "FNUZ" variants — Triton calls these fp8e4b8/fp8e5b16. + jnp.dtype("float8_e4m3fnuz"): "fp8e4b8", + jnp.dtype("float8_e5m2fnuz"): "fp8e5b16", jnp.dtype("int64"): "i64", jnp.dtype("int32"): "i32", jnp.dtype("int16"): "i16", @@ -273,7 +294,51 @@ def compile_triton( if cache_key in _TRITON_KERNEL_CACHE: return _TRITON_KERNEL_CACHE[cache_key] - # Compile kernel + # Mark constants as constexpr in signature (defensive — tensor signatures + # built by triton_call_lowering won't contain constexpr names, but other + # callers might). 
+ signature_with_constexpr = dict(signature) + for const_name in constants.keys(): + if const_name in signature_with_constexpr: + signature_with_constexpr[const_name] = "constexpr" + + if _IS_HIP: + kernel = _compile_triton_hip( + kernel_fn, + signature_with_constexpr, + constants, + num_warps, + num_stages, + num_ctas, + compute_capability, + enable_fp_fusion, + ) + else: + kernel = _compile_triton_cuda( + kernel_fn, + signature_with_constexpr, + constants, + num_warps, + num_stages, + num_ctas, + compute_capability, + enable_fp_fusion, + ) + + _TRITON_KERNEL_CACHE[cache_key] = kernel + return kernel + + +def _compile_triton_cuda( + kernel_fn, + signature, + constants, + num_warps, + num_stages, + num_ctas, + compute_capability, + enable_fp_fusion, +): options = cb.CUDAOptions( num_warps=num_warps, num_stages=num_stages, @@ -282,54 +347,106 @@ def compile_triton( debug=False, enable_fp_fusion=enable_fp_fusion, ) - - # Mark constants as constexpr in signature - signature_with_constexpr = dict(signature) - for const_name in constants.keys(): - if const_name in signature_with_constexpr: - signature_with_constexpr[const_name] = "constexpr" - - src = tc.ASTSource( - fn=kernel_fn, - constexprs=constants, - signature=signature_with_constexpr, - ) - + src = tc.ASTSource(fn=kernel_fn, constexprs=constants, signature=signature) compiled = tc.compile( src, target=tc.GPUTarget("cuda", compute_capability, 32), options=options.__dict__, ) - # Create kernel object for JAX - # From jax/jaxlib/gpu/triton_kernels.cc: from packaging import version if version.parse(jax.__version__) >= version.parse("0.8.2"): - kernel = gpu_triton.TritonKernel( - compiled.name, # arg0: kernel_name (str) - num_warps, # arg1: num_warps (int) - num_ctas, # arg2: num_ctas (int) - compiled.metadata.shared, # arg3: shared_mem_bytes (int) - compiled.asm["ptx"], # arg4: ptx (str) - "", # arg5: ttir (str) - empty - compute_capability, # arg6: compute_capability (int) - ) - else: - kernel = 
gpu_triton.TritonKernel( + return gpu_triton.TritonKernel( compiled.name, num_warps, + num_ctas, compiled.metadata.shared, compiled.asm["ptx"], - "", # ttir + "", compute_capability, - 1, - 1, - 1, ) + return gpu_triton.TritonKernel( + compiled.name, + num_warps, + compiled.metadata.shared, + compiled.asm["ptx"], + "", + compute_capability, + 1, + 1, + 1, + ) - _TRITON_KERNEL_CACHE[cache_key] = kernel - return kernel + +# Track HSACO temp files for the lifetime of the process so the kernel paths +# we hand to jaxlib don't get garbage-collected. +_HSACO_TEMP_FILES: list[str] = [] + + +def _compile_triton_hip( + kernel_fn, + signature, + constants, + num_warps, + num_stages, + num_ctas, + compute_capability, + enable_fp_fusion, +): + # Strip target-feature suffix: "gfx950:sramecc+:xnack-" -> "gfx950". + arch = _ARCH_DETAILS.split(":", 1)[0] + # Mirror what triton's parse_options would do per-arch: the default + # HIPOptions.supported_fp8_dtypes is just ("fp8e5",), and constructing + # HIPOptions directly bypasses the per-arch augmentation. Set it + # explicitly so FP8 e4m3 kernels compile on gfx942/gfx950. + if arch == "gfx942": + fp8_dtypes = ("fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16") + elif arch == "gfx950" or arch.startswith("gfx12"): + fp8_dtypes = ("fp8e4nv", "fp8e5") + else: + fp8_dtypes = ("fp8e5",) + options = cb_hip.HIPOptions( + num_warps=num_warps, + num_stages=num_stages, + num_ctas=num_ctas, + cluster_dims=(1, 1, 1), + debug=False, + enable_fp_fusion=enable_fp_fusion, + arch=arch, + supported_fp8_dtypes=fp8_dtypes, + ) + src = tc.ASTSource(fn=kernel_fn, constexprs=constants, signature=signature) + compiled = tc.compile( + src, + target=_TritonGPUTarget("hip", arch, warp_size=64), + options=options.__dict__, + ) + + # jaxlib's HIP TritonKernel ctor takes a path to an HSACO blob, not bytes. 
+ fd, hsaco_path = tempfile.mkstemp(suffix=".hsaco", prefix=f"te_{compiled.name}_") + with os.fdopen(fd, "wb") as f: + f.write(compiled.asm["hsaco"]) + _HSACO_TEMP_FILES.append(hsaco_path) + + # The HIP TritonKernel constructor on this jax/jaxlib (0.8.0) takes + # `shared_mem_bytes` in slot 2 — not slot 5 as the public sample code + # suggests. The sample only works for kernels whose `shared` is 0 + # (e.g. simple element-wise kernels), because there the misplaced 0 in + # slot 2 coincidentally matches the expected layout. Kernels using + # tl.dot need real LDS allocation and silently produce garbage when + # `shared` lands in the wrong constructor slot. + return gpu_triton.TritonKernel( + compiled.name, + num_warps, + compiled.metadata.shared, + hsaco_path, + str(compiled.asm.get("ttir", "")), + compute_capability, + 1, + 1, + 1, + ) def triton_call_lowering( @@ -339,6 +456,9 @@ def triton_call_lowering( grid, input_output_aliases: Mapping[int, int] = None, constexprs: Mapping[str, Any] = None, + num_warps: int = 32, + num_stages: int = 1, + num_ctas: int = 1, ): """Helper for MLIR lowering that calls a Triton kernel. @@ -348,7 +468,12 @@ def triton_call_lowering( ctx: MLIR lowering context kernel_fn: Triton kernel function *array_args: Input arrays (from ctx) - grid: Grid dimensions (int or tuple) + grid: Grid dimensions. Either: + * an int / 1-3 element tuple (fixed grid), OR + * a callable ``(merged_kwargs) -> tuple`` for autotuned kernels + whose grid depends on the autotune-selected meta-args + (e.g. BLOCK_T/BLOCK_S). ``merged_kwargs`` is the union of + ``constexprs`` and the per-config ``Config.kwargs``. input_output_aliases: Mapping of input to output aliases constexprs: Compile-time constants for the kernel. 
This includes both tl.constexpr arguments AND scalar runtime arguments (like @@ -389,23 +514,28 @@ def lowering(ctx, x, *, block_size): tensor_arg_names = [n for n in arg_names if n not in constexpr_names] signature = {n: get_triton_dtype(a) for n, a in zip(tensor_arg_names, all_avals)} - # Normalize grid to 3D - if isinstance(grid, int): - grid_tuple = (grid, 1, 1) - elif len(grid) == 1: - grid_tuple = (grid[0], 1, 1) - elif len(grid) == 2: - grid_tuple = (grid[0], grid[1], 1) + # Normalize grid to 3D. `grid` may be a callable for autotuned kernels + # whose grid depends on the per-config meta-args (BLOCK_T/BLOCK_S etc.). + grid_fn = grid if callable(grid) else None + + def _normalize_grid(g): + if isinstance(g, int): + return (g, 1, 1) + if len(g) == 1: + return (g[0], 1, 1) + if len(g) == 2: + return (g[0], g[1], 1) + return g[:3] + + if grid_fn is None: + grid_tuple = _normalize_grid(grid) else: - grid_tuple = grid[:3] + # For non-autotune fallback, evaluate with just the user constexprs. + grid_tuple = _normalize_grid(grid_fn(constexprs or {})) - # Default values for the kernel + # Caller-supplied num_warps/num_stages/num_ctas (defaults match the + # historical hardcoded values: 32/1/1). 
actual_kernel_fn = kernel_fn - num_warps = 32 - num_stages = ( - 1 # TODO(Phuong): consider if it is beneficial to expose num_warps, num_stages, num_ctas - ) - num_ctas = 1 kernel_constexprs = constexprs if constexprs is not None else {} # Handle autotuned kernels - compile all configs @@ -415,7 +545,8 @@ def lowering(ctx, x, *, block_size): kernel_calls = [] actual_kernel_fn = kernel_fn.fn - for config in kernel_fn.configs: + for idx, config in enumerate(kernel_fn.configs): + print(f"DEBUG *** Running config {idx+1}/{len(kernel_fn.configs)}") # Extract parameters from config config_num_warps = config.num_warps if config.num_warps is not None else num_warps config_num_stages = config.num_stages if config.num_stages is not None else num_stages @@ -424,6 +555,14 @@ def lowering(ctx, x, *, block_size): # Merge config kwargs with user constexprs config_constexprs = {**config.kwargs, **(constexprs if constexprs else {})} + # Per-config grid: re-evaluate grid_fn with this config's merged + # kwargs so configs that vary BLOCK_T/BLOCK_S launch at the right + # cdiv(T_t, BLOCK_T) etc. 
+ if grid_fn is not None: + config_grid = _normalize_grid(grid_fn(config_constexprs)) + else: + config_grid = grid_tuple + # Compile this config config_kernel = compile_triton( actual_kernel_fn, @@ -443,9 +582,9 @@ def lowering(ctx, x, *, block_size): config_call = gpu_triton.TritonKernelCall( config_kernel, - grid_tuple[0], - grid_tuple[1], - grid_tuple[2], + config_grid[0], + config_grid[1], + config_grid[2], config_params, ) From 1e182520ce45908f3b593273924c9c91a5e9c4b8 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 7 May 2026 19:01:20 +0000 Subject: [PATCH 02/17] Dirty commit --- transformer_engine/jax/indexer.py | 46 ++++++++---- .../jax/triton_extensions/indexer.py | 71 ++++++++++++++----- 2 files changed, 88 insertions(+), 29 deletions(-) diff --git a/transformer_engine/jax/indexer.py b/transformer_engine/jax/indexer.py index 141e14275..48b67ce48 100644 --- a/transformer_engine/jax/indexer.py +++ b/transformer_engine/jax/indexer.py @@ -139,16 +139,23 @@ def _indexer_impl_triton(Q, K, W_uq, W_dq, W_k, weights, **fp8_kwargs): def _indexer_impl_hybrid(Q, K, W_uq, W_dq, W_k, W_w, scale_q=None, scale_k=None, scale_wq=None, scale_wd=None, scale_wk=None, - out_dtype=None): + out_dtype=None, + fp8_score=False): """Einsum projections + Triton score-relu-reduce. Mirrors ``_indexer_impl_reference`` for the four projections (which - lower to hipBLASLt GEMMs), then hands Hq / Hk / W_o to a fused Triton - kernel that does score+relu+H-reduction in registers — eliminating the - 16+ GB pre-relu-score HBM round-trip the pure-einsum path pays. - - bf16 only for now. FP8 inputs are dequantized to bf16 just like the - reference; native FP8 GEMM is not available on ROCm anyway. + lower to hipBLASLt bf16 GEMMs), then hands Hq / Hk / W_o to a fused + Triton kernel that does score+relu+H-reduction in registers — + eliminating the (B, oH, T, H, S) pre-relu-score HBM round-trip the + pure-einsum path pays. 
+ + fp8_score (default True): per-tensor amax-quantize Hq and Hk to fp8 + e4m3 just before the kernel call. The kernel's score MFMA then runs + native fp8-fp8 (`v_mfma_f32_*_fp8_fp8` on gfx950) and Hq's HBM + footprint halves, which dominates the kernel's read bandwidth at + production sizes. The two scales fold into one fp32 multiply at the + end of the kernel (relu commutes with positive scaling). W_o stays + bf16 since it's tiny. """ from transformer_engine.jax.triton_extensions.indexer import score_reduce_triton @@ -178,8 +185,15 @@ def _dq(x, s): H_k = jnp.einsum("...sd,di->...si", K_d, W_k_d) # (..., S, d_i) W_o = jnp.einsum("...td,dh->...th", Q_d, W_w.astype(wp)) # (..., T, H) - O = score_reduce_triton(H_q, H_k, W_o, - out_dtype=out_dtype if out_dtype else wp) + if fp8_score: + H_q_fp8, sq = quantize_to_fp8(H_q, dtype=jnp.float8_e4m3fn) + H_k_fp8, sk = quantize_to_fp8(H_k, dtype=jnp.float8_e4m3fn) + O = score_reduce_triton(H_q_fp8, H_k_fp8, W_o, + scale_hq=sq, scale_hk=sk, + out_dtype=out_dtype if out_dtype else wp) + else: + O = score_reduce_triton(H_q, H_k, W_o, + out_dtype=out_dtype if out_dtype else wp) return O @@ -212,11 +226,11 @@ def indexer_topk(Q, K, W_uq, W_dq, W_k, weights, *, k, backend="triton"): # --- Top-level dispatch --------------------------------------------------------- -@functools.partial(jax.jit, static_argnames=("backend", "out_dtype")) +@functools.partial(jax.jit, static_argnames=("backend", "out_dtype", "fp8_score")) def indexer(Q, K, W_uq, W_dq, W_k, weights, *, scale_q=None, scale_k=None, scale_wq=None, scale_wd=None, scale_wk=None, - out_dtype=None, backend="reference"): + out_dtype=None, backend="reference", fp8_score=False): """Low-rank lightning-indexer. Args: @@ -233,7 +247,12 @@ def indexer(Q, K, W_uq, W_dq, W_k, weights, *, W_dq itself is FP8. out_dtype: output dtype override (defaults to Q.dtype, or weights.dtype in FP8 mode). - backend: "reference", "fused" (Pallas), or "triton". 
+ backend: "reference", "fused" (Pallas), "triton", or "hybrid". + fp8_score: hybrid backend only — when True, quantize Hq and Hk to + fp8 e4m3 before the score-reduce kernel so the score MFMA + runs native fp8 and Hq's HBM footprint halves. Pays off + once Hq is large enough that the savings exceed the + amax-quantize cost (typically production-sized shapes). Returns: O of shape (..., T, S). @@ -250,7 +269,8 @@ def indexer(Q, K, W_uq, W_dq, W_k, weights, *, if backend == "triton": return _indexer_impl_triton(Q, K, W_uq, W_dq, W_k, weights, **fp8_kwargs) if backend == "hybrid": - return _indexer_impl_hybrid(Q, K, W_uq, W_dq, W_k, weights, **fp8_kwargs) + return _indexer_impl_hybrid(Q, K, W_uq, W_dq, W_k, weights, + fp8_score=fp8_score, **fp8_kwargs) raise ValueError( f"unknown backend {backend!r}; expected 'reference', 'fused', 'triton', " f"or 'hybrid'" diff --git a/transformer_engine/jax/triton_extensions/indexer.py b/transformer_engine/jax/triton_extensions/indexer.py index 22f8de2ec..322530978 100644 --- a/transformer_engine/jax/triton_extensions/indexer.py +++ b/transformer_engine/jax/triton_extensions/indexer.py @@ -430,12 +430,13 @@ def _score_reduce_autotune_configs(): return cfgs -@triton.autotune(configs=_score_reduce_autotune_configs(), key=["H", "d_i"]) +@triton.autotune(configs=_score_reduce_autotune_configs(), key=["H", "d_i", "IS_FP8"]) @triton.jit def _score_reduce_kernel( - Hq_ptr, # (B, oH, T_t, H, d_i) — produced by einsum("...tc,hci->...thi") - Hk_ptr, # (B, oH, T_s, d_i) - W_o_ptr, # (B, oH, T_t, H) + Hq_ptr, # (B, oH, T_t, H, d_i) bf16 OR fp8 e4m3 + Hk_ptr, # (B, oH, T_s, d_i) same dtype as Hq + W_o_ptr, # (B, oH, T_t, H) bf16 always + scale_ptr, # 0-D fp32: combined scale_hq * scale_hk (1.0 in bf16 mode) O_ptr, # (B, oH, T_t, T_s) B: tl.constexpr, oH: tl.constexpr, @@ -445,6 +446,7 @@ def _score_reduce_kernel( d_i: tl.constexpr, BLOCK_T: tl.constexpr, BLOCK_S: tl.constexpr, + IS_FP8: tl.constexpr, ): """Compute one (BLOCK_T, BLOCK_S) tile of 
O for one (b, h_outer) slice. @@ -454,6 +456,10 @@ def _score_reduce_kernel( and vary only in S — they all read the same per-head Hq slab, hitting L2 instead of HBM. Hq layout is the natural einsum output (..., T, H, d_i); per-head loads are strided in T (stride H*d_i). + + FP8 mode (IS_FP8=True): Hq and Hk are e4m3 with per-tensor fp32 scales. + The two scales fold into one fp32 multiply at the end (relu commutes + with positive scaling). The score MFMA runs native fp8-fp8. """ pid_s = tl.program_id(0) pid_t = tl.program_id(1) @@ -489,10 +495,20 @@ def _score_reduce_kernel( wo_ptrs = W_o_ptr + wo_base + rt * H + h w_h = tl.load(wo_ptrs, mask=rt_mask, other=0.0) + # tl.dot lowers to native fp8-fp8 MFMA when both inputs are fp8; + # otherwise bf16-bf16 MFMA. Output is fp32 in both cases. score = tl.dot(Hq_h, Hk_T) score = tl.maximum(score, 0.0) acc += score * w_h[:, None].to(tl.float32) + # Apply the combined per-tensor dequant scale at the very end. relu is + # invariant under multiplication by a positive scalar (sq * sk > 0), + # so this is mathematically equivalent to scaling Hq_h and Hk_T per + # iteration but costs one fp32 multiply per output element instead of + # one per dot input. 
+ scale = tl.load(scale_ptr) + acc = acc * scale + o_ptrs = O_ptr + o_base + rt[:, None] * T_s + rs[None, :] tl.store(o_ptrs, acc.to(O_ptr.dtype.element_ty), mask=rt_mask[:, None] & rs_mask[None, :]) @@ -503,8 +519,8 @@ def _score_reduce_kernel( @_score_reduce_p.def_abstract_eval -def _score_reduce_abstract(Hq, Hk, W_o, *, out_dtype): - del W_o +def _score_reduce_abstract(Hq, Hk, W_o, scale, *, out_dtype): + del W_o, scale # Hq layout: (B, oH, T_t, H, d_i) B, oH, T_t, _H, _d_i = Hq.shape T_s = Hk.shape[2] @@ -514,12 +530,13 @@ def _score_reduce_abstract(Hq, Hk, W_o, *, out_dtype): _score_reduce_p.def_impl(functools.partial(xla.apply_primitive, _score_reduce_p)) -def _score_reduce_lowering(ctx, Hq, Hk, W_o, *, out_dtype): +def _score_reduce_lowering(ctx, Hq, Hk, W_o, scale, *, out_dtype): del out_dtype Hq_aval = ctx.avals_in[0] Hk_aval = ctx.avals_in[1] B, oH, T_t, H, d_i = Hq_aval.shape T_s = Hk_aval.shape[2] + is_fp8 = _is_fp8_dtype(Hq_aval.dtype) def grid_fn(merged_kwargs): bt = merged_kwargs.get("BLOCK_T", 64) @@ -531,7 +548,7 @@ def grid_fn(merged_kwargs): return triton_call_lowering( ctx, _score_reduce_kernel, - Hq, Hk, W_o, + Hq, Hk, W_o, scale, grid=grid_fn, num_warps=4, num_stages=2, @@ -542,6 +559,7 @@ def grid_fn(merged_kwargs): "T_s": T_s, "H": H, "d_i": d_i, + "IS_FP8": is_fp8, }, ) @@ -550,7 +568,8 @@ def grid_fn(merged_kwargs): mlir.register_lowering(_score_reduce_p, _score_reduce_lowering, platform="cuda") -def score_reduce_triton(Hq, Hk, W_o, *, out_dtype=None): +def score_reduce_triton(Hq, Hk, W_o, *, + scale_hq=None, scale_hk=None, out_dtype=None): """Triton fused score-matmul + relu + per-(t, h) weighted H-reduction. Replaces the pattern: @@ -563,10 +582,13 @@ def score_reduce_triton(Hq, Hk, W_o, *, out_dtype=None): pays (the dominant cost in profile_indexer's einsum baseline). Args: - Hq: (B, oH, T_t, H, d_i) - Hk: (B, oH, T_s, d_i) - W_o: (B, oH, T_t, H) - out_dtype: defaults to Hq.dtype. 
+ Hq: (B, oH, T_t, H, d_i) bf16 OR fp8 e4m3 + Hk: (B, oH, T_s, d_i) must match Hq.dtype + W_o: (B, oH, T_t, H) bf16 + scale_hq, scale_hk: + per-tensor fp32 dequant scales for Hq / Hk. Required when + Hq is FP8; ignored otherwise. + out_dtype: defaults to Hq.dtype (or W_o.dtype in FP8 mode). Returns: O: (B, oH, T_t, T_s) @@ -599,11 +621,28 @@ def score_reduce_triton(Hq, Hk, W_o, *, out_dtype=None): f"(B={B}, oH={oH}, T_t={T_t}, H={H})" ) - if out_dtype is None: - out_dtype = Hq.dtype + is_fp8 = _is_fp8_dtype(Hq.dtype) + if is_fp8: + if Hk.dtype != Hq.dtype: + raise ValueError( + f"FP8 mode requires Hk.dtype == Hq.dtype; " + f"Hq is {Hq.dtype} but Hk is {Hk.dtype}." + ) + if scale_hq is None or scale_hk is None: + raise ValueError("FP8 mode requires scale_hq and scale_hk.") + scale_combined = jnp.asarray( + jnp.float32(scale_hq) * jnp.float32(scale_hk), + dtype=jnp.float32, + ) + if out_dtype is None: + out_dtype = W_o.dtype + else: + scale_combined = jnp.asarray(1.0, dtype=jnp.float32) + if out_dtype is None: + out_dtype = Hq.dtype return _score_reduce_p.bind( - Hq, Hk, W_o, out_dtype=jnp.dtype(out_dtype) + Hq, Hk, W_o, scale_combined, out_dtype=jnp.dtype(out_dtype) )[0] From fdfc1a410d58b28284033cd7b2b80b1ace737518 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 7 May 2026 19:38:25 +0000 Subject: [PATCH 03/17] Trimmed API --- transformer_engine/jax/indexer.py | 194 ++--- .../jax/pallas_kernels/__init__.py | 16 - .../jax/pallas_kernels/indexer.py | 271 ------- .../jax/triton_extensions/__init__.py | 6 +- .../jax/triton_extensions/indexer.py | 726 +----------------- 5 files changed, 83 insertions(+), 1130 deletions(-) delete mode 100644 transformer_engine/jax/pallas_kernels/__init__.py delete mode 100644 transformer_engine/jax/pallas_kernels/indexer.py diff --git a/transformer_engine/jax/indexer.py b/transformer_engine/jax/indexer.py index 48b67ce48..d6021f447 100644 --- a/transformer_engine/jax/indexer.py +++ b/transformer_engine/jax/indexer.py @@ -1,36 +1,33 
@@ """Indexer op (forward only). -Two backends: - * "reference" - jnp/einsum, accepts arbitrary leading dims (..., T, d). - * "fused" - Pallas kernel, strict BHSD (B, H, T, d). Lives in - transformer_engine/jax/pallas_kernels/indexer.py. +Two canonical backends: + * ``"reference"`` — pure ``jnp.einsum``. Materializes the + (B, oH, T, H, S) pre-relu score tensor in HBM via hipBLASLt. + * ``"hybrid"`` — same einsum projections (C_q, H_q, H_k, W_o) followed + by a fused Triton kernel that does score+relu+H-reduction in + registers. Avoids the score-tensor HBM round-trip that dominates the + reference path. -Top-level entry point: ``indexer(Q, K, W_uq, W_dq, W_k, weights, *, backend=...)``. +Top-level entry point: ``indexer(Q, K, W_uq, W_dq, W_k, W_w, *, backend=...)``. Math (low-rank form: Q is hidden state; query heads are produced by a -down-projection (d -> d_c) followed by an up-projection (d_c -> H * d_i)): +down-projection (d -> d_c) followed by an up-projection (d_c -> H * d_i); +output weights are produced from Q via a learnable d -> H projection): C_q = Q @ W_dq # (..., T, d_c) H_q = einsum("...tc,hci->...thi", C_q, W_uq) # (..., T, H, d_i) H_k = K @ W_k # (..., S, d_i) + W_o = Q @ W_w # (..., T, H) H = relu(einsum("...thi,...si->...ths", H_q, H_k)) # (..., T, H, S) - O = einsum("...ths,...ht->...ts", H, weights) # (..., T, S) - -``weights`` is the precomputed per-(indexer-head, token) weight (DeepSeek's -``weights_proj(x)`` term, transposed for kernel-friendly layout). Its leading -dims must broadcast against Q's. - -FP8 mode: any of Q / K / W_uq / W_dq / W_k may be FP8 (e4m3) tensors. Each -FP8 operand needs a per-tensor fp32 scale (scale_q, scale_k, scale_wq, -scale_wd, scale_wk). ReLU commutes with positive scaling so the active -scales fold into a single fp32 scalar applied once at the end. 
Letting W_dq -go FP8 unlocks a native FP8 MFMA on the Q @ W_dq down-projection (and saves -half the bytes for that weight) at the cost of additional quantization noise -in the bottleneck of the low-rank decomposition. + O = einsum("...ths,...th->...ts", H, W_o) # (..., T, S) + +FP8 mode: any of Q / K / W_uq / W_dq / W_k may be FP8 (e4m3) tensors with +per-tensor fp32 scales. They are dequantized to bf16 inside both backends +before the projections — XLA on ROCm has no fp8 GEMM rewriter, and the +hybrid kernel itself runs in bf16. """ import functools -import math import jax import jax.numpy as jnp @@ -121,26 +118,12 @@ def _dq(x, s): return O -# --- Fused implementation (Pallas) ---------------------------------------------- -# Imported lazily so callers without Triton/Pallas can still use the reference. - -def _indexer_impl_fused(Q, K, W_uq, W_dq, W_k, weights, **fp8_kwargs): - raise NotImplementedError( - "Pallas backend has not yet been updated for the low-rank indexer form " - "(W_uq + W_dq). Use backend='triton' or backend='reference'." - ) - - -def _indexer_impl_triton(Q, K, W_uq, W_dq, W_k, weights, **fp8_kwargs): - from transformer_engine.jax.triton_extensions.indexer import indexer_fused_triton - return indexer_fused_triton(Q, K, W_uq, W_dq, W_k, weights, **fp8_kwargs) - +# --- Hybrid implementation (einsum projections + Triton score-reduce) --------- def _indexer_impl_hybrid(Q, K, W_uq, W_dq, W_k, W_w, scale_q=None, scale_k=None, scale_wq=None, scale_wd=None, scale_wk=None, - out_dtype=None, - fp8_score=False): + out_dtype=None): """Einsum projections + Triton score-relu-reduce. Mirrors ``_indexer_impl_reference`` for the four projections (which @@ -148,14 +131,6 @@ def _indexer_impl_hybrid(Q, K, W_uq, W_dq, W_k, W_w, Triton kernel that does score+relu+H-reduction in registers — eliminating the (B, oH, T, H, S) pre-relu-score HBM round-trip the pure-einsum path pays. 
- - fp8_score (default True): per-tensor amax-quantize Hq and Hk to fp8 - e4m3 just before the kernel call. The kernel's score MFMA then runs - native fp8-fp8 (`v_mfma_f32_*_fp8_fp8` on gfx950) and Hq's HBM - footprint halves, which dominates the kernel's read bandwidth at - production sizes. The two scales fold into one fp32 multiply at the - end of the kernel (relu commutes with positive scaling). W_o stays - bf16 since it's tiny. """ from transformer_engine.jax.triton_extensions.indexer import score_reduce_triton @@ -185,52 +160,18 @@ def _dq(x, s): H_k = jnp.einsum("...sd,di->...si", K_d, W_k_d) # (..., S, d_i) W_o = jnp.einsum("...td,dh->...th", Q_d, W_w.astype(wp)) # (..., T, H) - if fp8_score: - H_q_fp8, sq = quantize_to_fp8(H_q, dtype=jnp.float8_e4m3fn) - H_k_fp8, sk = quantize_to_fp8(H_k, dtype=jnp.float8_e4m3fn) - O = score_reduce_triton(H_q_fp8, H_k_fp8, W_o, - scale_hq=sq, scale_hk=sk, - out_dtype=out_dtype if out_dtype else wp) - else: - O = score_reduce_triton(H_q, H_k, W_o, - out_dtype=out_dtype if out_dtype else wp) + O = score_reduce_triton(H_q, H_k, W_o, + out_dtype=out_dtype if out_dtype else wp) return O -def _indexer_topk_impl_reference(Q, K, W_uq, W_dq, W_k, weights, k): - scores = _indexer_impl_reference(Q, K, W_uq, W_dq, W_k, weights) - return jax.lax.top_k(scores, k) - - -def _indexer_topk_impl_triton(Q, K, W_uq, W_dq, W_k, weights, k): - from transformer_engine.jax.triton_extensions.indexer import indexer_fused_topk_triton - return indexer_fused_topk_triton(Q, K, W_uq, W_dq, W_k, weights, k=k) - - -@functools.partial(jax.jit, static_argnames=("k", "backend")) -def indexer_topk(Q, K, W_uq, W_dq, W_k, weights, *, k, backend="triton"): - """Indexer fused with per-row top-k along T_s. - - Returns (vals, idxs): - vals: (..., T, k) Q.dtype - idxs: (..., T, k) int32 - - backend: "reference" (full score then jax.lax.top_k) or "triton" (fused). 
- """ - if backend == "reference": - return _indexer_topk_impl_reference(Q, K, W_uq, W_dq, W_k, weights, k) - if backend == "triton": - return _indexer_topk_impl_triton(Q, K, W_uq, W_dq, W_k, weights, k) - raise ValueError(f"unknown backend {backend!r}; expected 'reference' or 'triton'") - - # --- Top-level dispatch --------------------------------------------------------- -@functools.partial(jax.jit, static_argnames=("backend", "out_dtype", "fp8_score")) +@functools.partial(jax.jit, static_argnames=("backend", "out_dtype")) def indexer(Q, K, W_uq, W_dq, W_k, weights, *, scale_q=None, scale_k=None, scale_wq=None, scale_wd=None, scale_wk=None, - out_dtype=None, backend="reference", fp8_score=False): + out_dtype=None, backend="reference"): """Low-rank lightning-indexer. Args: @@ -239,20 +180,17 @@ def indexer(Q, K, W_uq, W_dq, W_k, weights, *, W_uq: (H, d_c, d_i) up-projection: d_c -> d_i (per head) W_dq: (d, d_c) down-projection: d -> d_c W_k: (d, d_i) key projection - weights: (..., H, T) per-(indexer-head, token) weight + weights: (d, H) learnable output-weight projection + (W_o = Q @ weights inside the impl) scale_q, scale_k, scale_wq, scale_wk: per-tensor fp32 dequant scales. Required when Q is FP8. scale_wd: per-tensor fp32 dequant scale for W_dq. Required only when W_dq itself is FP8. - out_dtype: output dtype override (defaults to Q.dtype, or weights.dtype - in FP8 mode). - backend: "reference", "fused" (Pallas), "triton", or "hybrid". - fp8_score: hybrid backend only — when True, quantize Hq and Hk to - fp8 e4m3 before the score-reduce kernel so the score MFMA - runs native fp8 and Hq's HBM footprint halves. Pays off - once Hq is large enough that the savings exceed the - amax-quantize cost (typically production-sized shapes). + out_dtype: output dtype override (defaults to Q.dtype, or bf16 for + the hybrid backend). + backend: "reference" (pure einsum) or "hybrid" (einsum projections + + Triton score-relu-reduce kernel). Returns: O of shape (..., T, S). 
@@ -264,76 +202,48 @@ def indexer(Q, K, W_uq, W_dq, W_k, weights, *, ) if backend == "reference": return _indexer_impl_reference(Q, K, W_uq, W_dq, W_k, weights, **fp8_kwargs) - if backend == "fused": - return _indexer_impl_fused(Q, K, W_uq, W_dq, W_k, weights, **fp8_kwargs) - if backend == "triton": - return _indexer_impl_triton(Q, K, W_uq, W_dq, W_k, weights, **fp8_kwargs) if backend == "hybrid": - return _indexer_impl_hybrid(Q, K, W_uq, W_dq, W_k, weights, - fp8_score=fp8_score, **fp8_kwargs) + return _indexer_impl_hybrid(Q, K, W_uq, W_dq, W_k, weights, **fp8_kwargs) raise ValueError( - f"unknown backend {backend!r}; expected 'reference', 'fused', 'triton', " - f"or 'hybrid'" + f"unknown backend {backend!r}; expected 'reference' or 'hybrid'" ) # --- Tests ---------------------------------------------------------------------- -def _reference_nobatch(Q, K, W_uq, W_dq, W_k, weights): - """Rank-2 reference (no leading dims) used as the cross-check.""" - C_q = Q @ W_dq - H_q = jnp.einsum("tc,hci->thi", C_q, W_uq) - H_k = K @ W_k - H = jax.nn.relu(jnp.einsum("thi,si->ths", H_q, H_k)) - return jnp.einsum("ths,ht->ts", H, weights) - - def _run_test(leading_shape, seed, backend): - # Power-of-2 shapes, all matmul dims >= 16 so the Pallas backend accepts. - T_t, T_s, d, d_c, H, d_i = 16, 16, 16, 16, 4, 16 + # The hybrid backend's Triton kernel requires rank-4 BHSD inputs, so this + # smoke test only exercises that shape (and the reference-vs-hybrid agreement). 
+ T_t, T_s, d, d_c, H, d_i = 64, 64, 32, 32, 8, 32 keys = jax.random.split(jax.random.PRNGKey(seed), 6) - Q = jax.random.normal(keys[0], (*leading_shape, T_t, d)) - K = jax.random.normal(keys[1], (*leading_shape, T_s, d)) - W_uq = jax.random.normal(keys[2], (H, d_c, d_i)) - W_dq = jax.random.normal(keys[3], (d, d_c)) - W_k = jax.random.normal(keys[4], (d, d_i)) - weights = jax.random.normal(keys[5], (*leading_shape, H, T_t)) + Q = jax.random.normal(keys[0], (*leading_shape, T_t, d), dtype=jnp.bfloat16) + K = jax.random.normal(keys[1], (*leading_shape, T_s, d), dtype=jnp.bfloat16) + W_uq = jax.random.normal(keys[2], (H, d_c, d_i), dtype=jnp.bfloat16) + W_dq = jax.random.normal(keys[3], (d, d_c), dtype=jnp.bfloat16) + W_k = jax.random.normal(keys[4], (d, d_i), dtype=jnp.bfloat16) + W_w = jax.random.normal(keys[5], (d, H), dtype=jnp.bfloat16) try: - O = indexer(Q, K, W_uq, W_dq, W_k, weights, backend=backend) + O_ref = indexer(Q, K, W_uq, W_dq, W_k, W_w, backend="reference") + O_b = indexer(Q, K, W_uq, W_dq, W_k, W_w, backend=backend) except Exception as e: # noqa: BLE001 print(f" backend={backend:<10s} leading={str(leading_shape):10s} " f"SKIP: {type(e).__name__}: {str(e).splitlines()[0]}") return - flat = math.prod(leading_shape) if leading_shape else 1 - Q_f = Q.reshape(flat, T_t, d) - K_f = K.reshape(flat, T_s, d) - weights_f = weights.reshape(flat, T_t, H) - O_ref = jax.vmap(lambda q, k, w: _reference_nobatch(q, k, W_uq, W_dq, W_k, w))( - Q_f, K_f, weights_f - ) - O_ref = O_ref.reshape(*leading_shape, T_t, T_s) - - expected_shape = (*leading_shape, T_t, T_s) - shape_ok = O.shape == expected_shape - max_err = float(jnp.max(jnp.abs(O - O_ref))) - tag = "OK" if shape_ok and max_err < 1e-4 else "FAIL" + diff = (O_ref.astype(jnp.float32) - O_b.astype(jnp.float32)) + rel_err = float(jnp.linalg.norm(diff) / + (jnp.linalg.norm(O_ref.astype(jnp.float32)) + 1e-30)) + tag = "OK" if rel_err < 5e-3 else "FAIL" print(f" backend={backend:<10s} leading={str(leading_shape):10s} " 
- f"O.shape={O.shape} max abs err={max_err:.2e} [{tag}]") + f"O.shape={O_b.shape} rel.err={rel_err:.2e} [{tag}]") if __name__ == "__main__": - print("=== reference backend ===") - for i, leading in enumerate([(), (2,), (2, 3)]): + print("=== reference vs reference (sanity) ===") + for i, leading in enumerate([(2, 3),]): _run_test(leading, seed=i, backend="reference") - # Fused: strictly BHSD (rank-4 Q/K), so only the (2, 3) case applies. - print("\n=== fused backend ===") - for i, leading in enumerate([(), (2,), (2, 3)]): - _run_test(leading, seed=100 + i, backend="fused") - - # Triton: same BHSD-only contract as fused. - print("\n=== triton backend ===") - for i, leading in enumerate([(), (2,), (2, 3)]): - _run_test(leading, seed=200 + i, backend="triton") + print("\n=== hybrid vs reference ===") + for i, leading in enumerate([(2, 3),]): + _run_test(leading, seed=100 + i, backend="hybrid") diff --git a/transformer_engine/jax/pallas_kernels/__init__.py b/transformer_engine/jax/pallas_kernels/__init__.py deleted file mode 100644 index 7cdb7834b..000000000 --- a/transformer_engine/jax/pallas_kernels/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Pallas-backed kernels for Transformer Engine JAX. - -Pallas is the JAX-native kernel-authoring API. Its Triton lowering uses the -new ``__gpu$xla.gpu.triton`` custom call, which works on both NVIDIA and -AMD/ROCm. The kernel body looks essentially like a Triton kernel - same -program_id / block-pointer model - but uses ``pl.*`` primitives instead of -``tl.*`` so it composes cleanly with JAX (jit, sharding, dtype rules). 
-""" - -from .indexer import indexer_fused - -__all__ = ["indexer_fused"] diff --git a/transformer_engine/jax/pallas_kernels/indexer.py b/transformer_engine/jax/pallas_kernels/indexer.py deleted file mode 100644 index 7ee3508cc..000000000 --- a/transformer_engine/jax/pallas_kernels/indexer.py +++ /dev/null @@ -1,271 +0,0 @@ -# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Pallas kernel for the fused indexer (BHSD layout). - -Reference math (see transformer_engine.jax.indexer._indexer_impl_reference): - - H_q = einsum("bhtd,dij->bhtij", Q, W_q) # (B, H, T_t, I, d_i) - H_k = K @ W_k # (B, H, T_s, d_i) - H = relu(einsum("bhtij,bhsj->bhtis", H_q, H_k)) # (B, H, T_t, I, T_s) - O = einsum("bhtis,bhti->bhts", H, weights) # (B, H, T_t, T_s) - -``weights`` is precomputed per-(token, indexer-head) weighting with shape -matching Q's leading dims (B, H, T_t, I). It plays the role DeepSeek's -``weights_proj(x)`` plays in their lightning-indexer: a learned, data- -dependent per-head weight, not a static parameter. - -The launcher (``indexer_fused`` at the bottom of this file) wires up grid -size, BlockSpecs, and shape inference. The kernel body itself -(``_indexer_pallas_kernel_body``) is the part you fill in. -""" - -import jax -import jax.numpy as jnp -from jax.experimental import pallas as pl - - -# Preferred tile sizes (used when the LDS budget allows). Pallas-Triton stages -# Q and K tiles to LDS via BlockSpec, so the dominant LDS cost is -# (BLOCK_T + BLOCK_S) * d * dtype_bytes. Auto-shrink keeps that under -# _LDS_BUDGET_BYTES per workgroup. -_PREFERRED_BLOCK_T = 128 -_PREFERRED_BLOCK_S = 64 - -# gfx950 reports 160 KB LDS; leave ~10% headroom for compiler-staged -# intermediates (Hk, Hq_i, accumulators). 
-_LDS_BUDGET_BYTES = 144 * 1024 -_MIN_BLOCK = 16 # Triton requires matmul dims >= 16 - -_FP8_DTYPES = frozenset([ - jnp.dtype("float8_e4m3fn"), - jnp.dtype("float8_e5m2"), - jnp.dtype("float8_e4m3fnuz"), - jnp.dtype("float8_e5m2fnuz"), -]) - - -def _is_fp8_dtype(dt): - return jnp.dtype(dt) in _FP8_DTYPES - - -def _pick_tiles(T_t, T_s, d, dtype): - """Return (BLOCK_T, BLOCK_S) that fit the LDS budget for Q+K staging. - - Halve BLOCK_T first (less critical for inner-loop reuse), then BLOCK_S. - """ - elem_bytes = jnp.dtype(dtype).itemsize - bt = min(_PREFERRED_BLOCK_T, T_t) - bs = min(_PREFERRED_BLOCK_S, T_s) - - def cost(bt, bs): - return (bt + bs) * d * elem_bytes - - while cost(bt, bs) > _LDS_BUDGET_BYTES and bt > _MIN_BLOCK: - bt //= 2 - while cost(bt, bs) > _LDS_BUDGET_BYTES and bs > _MIN_BLOCK: - bs //= 2 - return bt, bs - - -def _estimate_lds_bytes(BLOCK_T, BLOCK_S, d, d_i, dtype): - """Worst-case LDS estimate. The dominant cost is usually K_tile + W_q[i] - slice (Pallas-Triton stages the per-iteration W_q[i] of shape (d, d_i)). - """ - elem_bytes = jnp.dtype(dtype).itemsize - k_tile = BLOCK_S * d * elem_bytes - q_tile = BLOCK_T * d * elem_bytes - w_q_slice = d * d_i * elem_bytes - # The two pairs that have actually been observed empirically: - return max(k_tile + w_q_slice, q_tile + w_q_slice) - - -class PallasIndexerInfeasible(RuntimeError): - """Raised when no valid (BLOCK_T, BLOCK_S) fits the LDS budget for the - given (d, d_i, dtype). The W_q[i] slice (size d*d_i*dtype_bytes) is the - typical culprit; mitigation requires d-tiling the inner matmul.""" - - -def _dot_fp32(a, b): - """`jnp.dot` with the fp32 accumulator made explicit. - - Without `preferred_element_type`, JAX promotion picks the input dtype as - the dot output dtype. For FP8 inputs that means the accumulated dot is - clamped to fp8 max (~448) BEFORE the fp32 cast — so any real workload - silently saturates. Force fp32 accumulation everywhere. 
- """ - return jax.lax.dot_general( - a, b, (((a.ndim - 1,), (0,)), ((), ())), - preferred_element_type=jnp.float32, - ) - - -def _make_kernel_body(BLOCK_T, BLOCK_S, d, I, d_i, is_fp8): - """Closure that bakes the static shape constants into the kernel body. - - Pallas kernel bodies trace under jit, so values referenced by Python-level - control flow (``range(I)`` etc.) must be static. The simplest way to make - them static is to capture them in a closure here. - - For FP8 inputs, the outer two dots (K@W_k, Q@W_q[i]) consume FP8 directly - via _dot_fp32; their fp32 outputs are downcast to bf16 for the inner - (Hq_i @ Hk^T) matmul. The combined per-tensor dequant scale is applied - to the fp32 accumulator at the very end; ReLU commutes with positive - scaling so this is exact. - """ - inter_dtype = jnp.bfloat16 if is_fp8 else None # None = preserve dtype - - def _indexer_pallas_kernel_body( - Q_ref, # (1, 1, BLOCK_T, d) - one (b, h) slice, T_t-tile - K_ref, # (1, 1, BLOCK_S, d) - one (b, h) slice, T_s-tile - W_q_ref, # (I, d, d_i) - whole tensor, replicated - W_k_ref, # (d, d_i) - whole tensor, replicated - weights_ref, # (1, 1, BLOCK_T, I) - one (b, h) slice, T_t-tile - scale_ref, # (1,) - combined fp32 scale - O_ref, # (1, 1, BLOCK_T, BLOCK_S) - one tile of the output - ): - """ - Compute one (BLOCK_T, BLOCK_S) tile of O for one (b, h). - """ - Q = Q_ref[0, 0] # (BLOCK_T, d) - K = K_ref[0, 0] # (BLOCK_S, d) - Wk = W_k_ref[...] 
# (d, d_i) - Hk = _dot_fp32(K, Wk) # (BLOCK_S, d_i) - if inter_dtype is not None: - Hk = Hk.astype(inter_dtype) - - acc = jnp.zeros((BLOCK_T, BLOCK_S), dtype=jnp.float32) - for i in range(I): - Wq_i = W_q_ref[i] # (d, d_i) - Hq_i = _dot_fp32(Q, Wq_i) # (BLOCK_T, d_i) - if inter_dtype is not None: - Hq_i = Hq_i.astype(inter_dtype) - Hi = jax.nn.relu(_dot_fp32(Hq_i, Hk.T)) # (BLOCK_T, BLOCK_S) - w_i = weights_ref[0, 0, :, i] # (BLOCK_T,) - acc = acc + Hi * w_i[:, None] - - acc = acc * scale_ref[0] - O_ref[0, 0] = acc.astype(O_ref.dtype) - - return _indexer_pallas_kernel_body - - -def indexer_fused( - Q, K, W_q, W_k, weights, - *, - scale_q=None, scale_k=None, scale_wq=None, scale_wk=None, - out_dtype=None, -): - """Pallas-backed fused indexer. Strict BHSD. - - Args: - Q: (B, H, T_t, d) high-precision (bf16/fp32) or FP8 e4m3 - K: (B, H, T_s, d) must match Q's dtype - W_q: (I, d, d_i) must match Q's dtype - W_k: (d, d_i) must match Q's dtype - weights: (B, H, T_t, I) high-precision regardless of Q dtype - scale_q, scale_k, scale_wq, scale_wk: - per-tensor fp32 dequant scales. Required when Q is FP8. - out_dtype: defaults to Q.dtype for non-FP8, weights.dtype for FP8. - - Returns: - O: (B, H, T_t, T_s) - """ - if Q.ndim != 4 or K.ndim != 4 or weights.ndim != 4: - raise ValueError( - f"indexer_fused (pallas) expects rank-4 BHSD Q, K and weights. Got " - f"Q.shape={Q.shape}, K.shape={K.shape}, weights.shape={weights.shape}. " - "Reshape (or add singleton head/batch axes) before calling the fused path." 
- ) - - B, H, T_t, d = Q.shape - Bk, Hk, T_s, dk = K.shape - I, d2, d_i = W_q.shape - d3, d_i_k = W_k.shape - Bw, Hw, T_t_w, I_w = weights.shape - - if (Bk, Hk) != (B, H): - raise ValueError(f"(B,H) mismatch: Q has {(B, H)}, K has {(Bk, Hk)}") - if not (d == dk == d2 == d3): - raise ValueError(f"d mismatch across Q/K/W_q/W_k: {d}, {dk}, {d2}, {d3}") - if d_i != d_i_k: - raise ValueError(f"d_i mismatch: W_q has {d_i}, W_k has {d_i_k}") - if (Bw, Hw, T_t_w, I_w) != (B, H, T_t, I): - raise ValueError( - f"weights shape {weights.shape} does not match expected " - f"(B={B}, H={H}, T_t={T_t}, I={I})" - ) - - is_fp8 = _is_fp8_dtype(Q.dtype) - if is_fp8: - for nm, t in (("K", K), ("W_q", W_q), ("W_k", W_k)): - if t.dtype != Q.dtype: - raise ValueError( - f"FP8 mode requires Q/K/W_q/W_k all match dtype; " - f"Q is {Q.dtype} but {nm} is {t.dtype}." - ) - if any(s is None for s in (scale_q, scale_k, scale_wq, scale_wk)): - raise ValueError( - "FP8 mode requires scale_q, scale_k, scale_wq, scale_wk." - ) - scale_combined = jnp.asarray( - jnp.float32(scale_q) * jnp.float32(scale_k) - * jnp.float32(scale_wq) * jnp.float32(scale_wk), - dtype=jnp.float32, - ).reshape((1,)) - if out_dtype is None: - out_dtype = weights.dtype - else: - scale_combined = jnp.asarray(1.0, dtype=jnp.float32).reshape((1,)) - if out_dtype is None: - out_dtype = Q.dtype - - BLOCK_T, BLOCK_S = _pick_tiles(T_t, T_s, d, Q.dtype) - lds = _estimate_lds_bytes(BLOCK_T, BLOCK_S, d, d_i, Q.dtype) - if lds > _LDS_BUDGET_BYTES: - raise PallasIndexerInfeasible( - f"Pallas indexer infeasible for this config: estimated LDS " - f"{lds // 1024} KB > budget {_LDS_BUDGET_BYTES // 1024} KB. " - f"Dominant cost is W_q[i] slice = d*d_i*dtype = " - f"{d * d_i * jnp.dtype(Q.dtype).itemsize // 1024} KB. " - f"Mitigation: d-tile the inner matmul (not implemented). " - f"For this config use the Triton backend instead." 
- ) - - grid = (B * H, pl.cdiv(T_t, BLOCK_T), pl.cdiv(T_s, BLOCK_S)) - - # BlockSpecs: each input/output is sliced based on (program_id_0, - # program_id_1, program_id_2). index_map returns the *block index* per - # axis (Pallas multiplies by block_shape internally). - def q_idx(bh, tt, ts): return (bh // H, bh % H, tt, 0) - def k_idx(bh, tt, ts): return (bh // H, bh % H, ts, 0) - def wq_idx(bh, tt, ts): return (0, 0, 0) - def wk_idx(bh, tt, ts): return (0, 0) - def weights_idx(bh, tt, ts): return (bh // H, bh % H, tt, 0) - def scale_idx(bh, tt, ts): return (0,) - def o_idx(bh, tt, ts): return (bh // H, bh % H, tt, ts) - - in_specs = [ - pl.BlockSpec(block_shape=(1, 1, BLOCK_T, d), index_map=q_idx), - pl.BlockSpec(block_shape=(1, 1, BLOCK_S, d), index_map=k_idx), - pl.BlockSpec(block_shape=(I, d, d_i), index_map=wq_idx), - pl.BlockSpec(block_shape=(d, d_i), index_map=wk_idx), - pl.BlockSpec(block_shape=(1, 1, BLOCK_T, I), index_map=weights_idx), - pl.BlockSpec(block_shape=(1,), index_map=scale_idx), - ] - out_spec = pl.BlockSpec( - block_shape=(1, 1, BLOCK_T, BLOCK_S), - index_map=o_idx, - ) - out_shape = jax.ShapeDtypeStruct((B, H, T_t, T_s), out_dtype) - - kernel_body = _make_kernel_body(BLOCK_T, BLOCK_S, d, I, d_i, is_fp8) - - return pl.pallas_call( - kernel_body, - grid=grid, - in_specs=in_specs, - out_specs=out_spec, - out_shape=out_shape, - )(Q, K, W_q, W_k, weights, scale_combined) diff --git a/transformer_engine/jax/triton_extensions/__init__.py b/transformer_engine/jax/triton_extensions/__init__.py index 79a4dd733..1a9c517a2 100644 --- a/transformer_engine/jax/triton_extensions/__init__.py +++ b/transformer_engine/jax/triton_extensions/__init__.py @@ -58,8 +58,4 @@ def lowering(ctx, x, **kwargs): from .utils import * from .permutation import * -from .indexer import ( - indexer_fused_triton, - indexer_fused_topk_triton, - score_reduce_triton, -) +from .indexer import score_reduce_triton diff --git a/transformer_engine/jax/triton_extensions/indexer.py 
b/transformer_engine/jax/triton_extensions/indexer.py index 322530978..e16838cd1 100644 --- a/transformer_engine/jax/triton_extensions/indexer.py +++ b/transformer_engine/jax/triton_extensions/indexer.py @@ -2,31 +2,23 @@ # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -"""Raw-Triton low-rank indexer kernel + JAX primitive. - -Math (matches the reference in transformer_engine.jax.indexer): - - C_q = Q @ W_dq # (..., T, d_c) - H_q = einsum("...tc,hci->...thi", C_q, W_uq) # (..., T, H, d_i) - H_k = K @ W_k # (..., S, d_i) - H = relu(einsum("...thi,...si->...ths", H_q, H_k)) # (..., T, H, S) - O = einsum("...ths,...ht->...ts", H, weights) # (..., T, S) - -Q is the hidden state (rank-4 BHSD: B × outer-H × T × d). W_dq is a low-rank -down-projection (d → d_c) and W_uq is the per-(indexer-head) up-projection. -The kernel loops over indexer heads internally; the outer (B, outer-H) dims -are flattened into the grid's first axis. - -FP8 mode: Q / K / W_uq / W_dq / W_k are all FP8 e4m3 (same dtype). The -five per-tensor scales (scale_q, scale_k, scale_wq, scale_wd, scale_wk) -fold into a single fp32 scalar applied at the end (ReLU is scale-invariant -under positive scaling). Three intermediate amax-based re-quantizations -(Cq, Hk, Hq per-head) keep the inner matmuls in fp8 too. +"""Triton score-relu-reduce kernel for the lightning-indexer hybrid backend. + +The hybrid backend computes the four projections (C_q, H_q, H_k, W_o) via +``jnp.einsum`` (which lowers to hipBLASLt bf16 GEMMs) and then hands the +results to this kernel for the score matmul + ReLU + per-(t, h) weighted +H-reduction: + + scores = relu(einsum("...thi,...si->...ths", H_q, H_k)) # never written + O = einsum("...ths,...th->...ts", scores, W_o) + +The kernel keeps each per-head score tile in registers, avoiding the +(B, oH, T, H, S) HBM round-trip that an einsum-only implementation pays +on the pre-relu score tensor. 
""" import functools -import jax import jax.numpy as jnp import triton import triton.language as tl @@ -38,372 +30,6 @@ from .utils import triton_call_lowering -# Autotune sweep: BLOCK_T × BLOCK_S × num_warps × num_stages. Profiling -# showed num_warps=4 with the prior default (BLOCK_T=128) saturated VGPR -# (256/thread), forcing 1 wave/SIMD; smaller tiles or num_warps=8 cut VGPR -# in half and gave a 6× speedup at the d=512 fp8 config. Each config below -# launches at its own grid (cdiv(T_t, BLOCK_T) × cdiv(T_s, BLOCK_S)) — the -# triton_call_lowering helper supports per-config grids via a callable -# `grid` argument. -def _autotune_configs(): - configs = [] - for block_t in (16, 32): - for block_s in (16, 32): - for block_d in (16, 32): - for num_warps in (4, 8): - for num_stages in (1, 2): - configs.append(triton.Config( - {"BLOCK_T": block_t, "BLOCK_S": block_s, - "BLOCK_D": block_d}, - num_warps=num_warps, num_stages=num_stages, - )) - return configs - -_AUTOTUNE_CONFIGS = _autotune_configs() -# Re-run the benchmark when any of these constexprs change. T_t/T_s only -# affect grid size; their optimal config is dominated by per-CTA shape and -# the precision (IS_FP8). -_AUTOTUNE_KEY = ["IS_FP8", "d", "d_c", "H", "d_i"] - - -# Max representable value of FP8 e4m3 (used for per-tile inter-quantization). -# Triton requires module-level constants referenced inside @jit kernels to be -# wrapped in tl.constexpr explicitly. -_FP8_E4M3_MAX = tl.constexpr(448.0) -# Floor on per-tile amax to avoid divide-by-zero when a tile is all-zero. 
-_FP8_AMAX_EPS = tl.constexpr(1e-30) - - -@triton.autotune(configs=_AUTOTUNE_CONFIGS, key=_AUTOTUNE_KEY) -@triton.jit -def _indexer_kernel( - Q_ptr, - K_ptr, - W_uq_ptr, # (H, d_c, d_i) - replicated across (B, oH) - W_dq_ptr, # (d, d_c) - replicated; same dtype as Q - W_k_ptr, # (d, d_i) - replicated - weights_ptr, - scale_ptr, # 0-D fp32 tensor: combined scale sq*sk*swq*swk (1.0 if non-FP8) - O_ptr, - B: tl.constexpr, - oH: tl.constexpr, - T_t: tl.constexpr, - T_s: tl.constexpr, - d: tl.constexpr, - d_c: tl.constexpr, - H: tl.constexpr, - d_i: tl.constexpr, - BLOCK_T: tl.constexpr, - BLOCK_S: tl.constexpr, - BLOCK_D: tl.constexpr, - IS_FP8: tl.constexpr, -): - """Compute one (BLOCK_T, BLOCK_S) tile of O for one (b, h_outer) slice. - - Grid: (B * oH, cdiv(T_t, BLOCK_T), cdiv(T_s, BLOCK_S)) - - Pipeline: - C_q = Q @ W_dq (down-projection, d-tiled) - Hk = K @ W_k (key projection, d-tiled) - for h in range(H): (loop over indexer heads) - Hq = C_q @ W_uq[:, h, :] (per-head up-projection) - Hi = relu(Hq @ Hk^T) (per-head score) - acc += Hi * weights[:, h] (weighted accumulate) - - The two d-contracting GEMMs (Q@W_dq and K@W_k) are tiled along d in - chunks of BLOCK_D. This keeps the W_dq / W_k tiles loaded into LDS at - BLOCK_D × {d_c, d_i} instead of d × {d_c, d_i}, freeing registers / - LDS for the inner per-head loop. - - FP8 mode (IS_FP8=True): all five matrices share the fp8 dtype. Every - MFMA is native fp8: the d-tiled Q@W_dq and K@W_k dots, then the inner - C_q@W_uq[h] and Hq@Hk^T dots after per-tile amax re-quantization of - Cq/Hk/Hq. The per-tile amax scales fold into the accumulator (Hq inside - the loop, Cq/Hk after) along with the user's combined per-tensor scale. 
- """ - pid_bh = tl.program_id(0) - pid_t = tl.program_id(1) - pid_s = tl.program_id(2) - - b = pid_bh // oH - h_outer = pid_bh % oH - - rt = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) - rs = pid_s * BLOCK_S + tl.arange(0, BLOCK_S) - rdc = tl.arange(0, d_c) - rdi = tl.arange(0, d_i) - - rt_mask = rt < T_t - rs_mask = rs < T_s - - in_dtype = Q_ptr.dtype.element_ty - q_base = b * (oH * T_t * d) + h_outer * (T_t * d) - k_base = b * (oH * T_s * d) + h_outer * (T_s * d) - - # d-tiled accumulators for Q @ W_dq → (BLOCK_T, d_c) and K @ W_k → - # (BLOCK_S, d_i). fp32 accumulators; quantization happens after the loop. - # Requires d % BLOCK_D == 0. - Cq_dot = tl.zeros((BLOCK_T, d_c), dtype=tl.float32) - Hk_dot = tl.zeros((BLOCK_S, d_i), dtype=tl.float32) - for d_off in range(0, d, BLOCK_D): - rd = d_off + tl.arange(0, BLOCK_D) - - q_ptrs = Q_ptr + q_base + rt[:, None] * d + rd[None, :] - Q_chunk = tl.load(q_ptrs, mask=rt_mask[:, None], other=0.0) - k_ptrs = K_ptr + k_base + rs[:, None] * d + rd[None, :] - K_chunk = tl.load(k_ptrs, mask=rs_mask[:, None], other=0.0) - wdq_ptrs = W_dq_ptr + rd[:, None] * d_c + rdc[None, :] - Wdq_chunk = tl.load(wdq_ptrs) - wk_ptrs = W_k_ptr + rd[:, None] * d_i + rdi[None, :] - Wk_chunk = tl.load(wk_ptrs) - - Cq_dot = tl.dot(Q_chunk, Wdq_chunk, acc=Cq_dot) - Hk_dot = tl.dot(K_chunk, Wk_chunk, acc=Hk_dot) - - # Quantize Cq and Hk for the inner up-projection. 
- if IS_FP8: - Cq_amax = tl.maximum(tl.max(tl.abs(Cq_dot)), _FP8_AMAX_EPS) - Cq_inter = Cq_amax / _FP8_E4M3_MAX - C_q = (Cq_dot / Cq_inter).to(in_dtype) - Hk_amax = tl.maximum(tl.max(tl.abs(Hk_dot)), _FP8_AMAX_EPS) - Hk_inter = Hk_amax / _FP8_E4M3_MAX - Hk_T = tl.trans((Hk_dot / Hk_inter).to(in_dtype)) - else: - C_q = Cq_dot.to(in_dtype) - Hk_T = tl.trans(Hk_dot.to(in_dtype)) - Cq_inter = 1.0 - Hk_inter = 1.0 - - acc = tl.zeros((BLOCK_T, BLOCK_S), dtype=tl.float32) - - w_base = b * (oH * H * T_t) + h_outer * (H * T_t) - for h_idx in range(H): - # W_uq[h_idx, :, :] is a contiguous (d_c, d_i) block of W_uq (H, d_c, d_i). - wuq_ptrs = W_uq_ptr + h_idx * (d_c * d_i) + rdc[:, None] * d_i + rdi[None, :] - Wuq_h = tl.load(wuq_ptrs) - - # Hq = C_q @ W_uq[h_idx]: (BLOCK_T, d_i) - Hq_dot = tl.dot(C_q, Wuq_h) - if IS_FP8: - Hq_amax = tl.maximum(tl.max(tl.abs(Hq_dot)), _FP8_AMAX_EPS) - Hq_inter = Hq_amax / _FP8_E4M3_MAX - Hq_h = (Hq_dot / Hq_inter).to(in_dtype) - else: - Hq_h = Hq_dot.to(in_dtype) - Hq_inter = 1.0 - - # Hi = relu(Hq @ Hk^T): (BLOCK_T, BLOCK_S). FP8 MFMA in FP8 mode. - Hi_raw = tl.dot(Hq_h, Hk_T) - Hi = tl.maximum(Hi_raw, 0.0) - - # weights[b, h_outer, h_idx, t]: contiguous BLOCK_T-vector. - w_ptrs = weights_ptr + w_base + h_idx * T_t + rt - w_i = tl.load(w_ptrs, mask=rt_mask, other=0.0) - - if IS_FP8: - acc += Hi * (Hq_inter * w_i)[:, None] - else: - acc += Hi * w_i[:, None] - - # Apply combined per-tensor scale + carried-out intermediate scales. - scale = tl.load(scale_ptr) - if IS_FP8: - acc = acc * (scale * Cq_inter * Hk_inter) - else: - acc = acc * scale - - # Store O tile: (BLOCK_T, BLOCK_S). O has shape (B, oH, T, S). 
- o_base = b * (oH * T_t * T_s) + h_outer * (T_t * T_s) - o_ptrs = O_ptr + o_base + rt[:, None] * T_s + rs[None, :] - tl.store(o_ptrs, acc.to(O_ptr.dtype.element_ty), - mask=rt_mask[:, None] & rs_mask[None, :]) - - -# --- JAX primitive --------------------------------------------------------------- - -_indexer_p = extend_core.Primitive("te_indexer_triton") -_indexer_p.multiple_results = True - - -_FP8_DTYPES = frozenset([ - jnp.dtype("float8_e4m3fn"), - jnp.dtype("float8_e5m2"), - jnp.dtype("float8_e4m3fnuz"), - jnp.dtype("float8_e5m2fnuz"), -]) - - -def _is_fp8_dtype(dt): - return jnp.dtype(dt) in _FP8_DTYPES - - -@_indexer_p.def_abstract_eval -def _indexer_abstract(Q, K, W_uq, W_dq, W_k, weights, scale, *, out_dtype): - del W_uq, W_dq, W_k, weights, scale - B, oH, T_t, _ = Q.shape - _, _, T_s, _ = K.shape - return [core.ShapedArray((B, oH, T_t, T_s), out_dtype)] - - -_indexer_p.def_impl(functools.partial(xla.apply_primitive, _indexer_p)) - - -def _indexer_lowering(ctx, Q, K, W_uq, W_dq, W_k, weights, scale, *, out_dtype): - del out_dtype # baked into the output aval - Q_aval = ctx.avals_in[0] - K_aval = ctx.avals_in[1] - W_uq_aval = ctx.avals_in[2] - B, oH, T_t, d = Q_aval.shape - T_s = K_aval.shape[2] - H, d_c, d_i = W_uq_aval.shape - - is_fp8 = _is_fp8_dtype(Q_aval.dtype) - - # Per-config grid: BLOCK_T/BLOCK_S come from the autotuned config kwargs - # (or fall back to a sensible default if autotune is not active). 
- def grid_fn(merged_kwargs): - bt = merged_kwargs.get("BLOCK_T", 128) - bs = merged_kwargs.get("BLOCK_S", 64) - return (B * oH, triton.cdiv(T_t, bt), triton.cdiv(T_s, bs)) - - return triton_call_lowering( - ctx, - _indexer_kernel, - Q, - K, - W_uq, - W_dq, - W_k, - weights, - scale, - grid=grid_fn, - num_warps=4, - num_stages=1, - constexprs={ - "B": B, - "oH": oH, - "T_t": T_t, - "T_s": T_s, - "d": d, - "d_c": d_c, - "H": H, - "d_i": d_i, - "IS_FP8": is_fp8, - }, - ) - - -mlir.register_lowering(_indexer_p, _indexer_lowering, platform="rocm") -mlir.register_lowering(_indexer_p, _indexer_lowering, platform="cuda") - - -def indexer_fused_triton( - Q, - K, - W_uq, - W_dq, - W_k, - weights, - *, - scale_q=None, - scale_k=None, - scale_wq=None, - scale_wd=None, - scale_wk=None, - out_dtype=None, -): - """Raw-Triton low-rank indexer (BHSD). - - Args: - Q: (B, oH, T, d) high-precision (bf16/fp32) or FP8 e4m3 - K: (B, oH, S, d) must match Q's dtype - W_uq: (H, d_c, d_i) up-projection; must match Q's dtype - W_dq: (d, d_c) down-projection; must match Q's dtype - W_k: (d, d_i) key projection; must match Q's dtype - weights: (B, oH, H, T) high-precision regardless of Q dtype - scale_q, scale_k, scale_wq, scale_wd, scale_wk: - per-tensor fp32 dequant scales. All five required when Q is FP8. - out_dtype: dtype of the output O. Defaults to Q.dtype for non-FP8 and - weights.dtype (typically bf16) for FP8. - - BLOCK_T / BLOCK_S / BLOCK_D / num_warps / num_stages are autotuned at - first invocation per (IS_FP8, d, d_c, H, d_i) key. - - Returns: - O of shape (B, oH, T, S) - """ - if Q.ndim != 4 or K.ndim != 4 or weights.ndim != 4: - raise ValueError( - "indexer_fused_triton expects rank-4 BHSD Q, K, weights. Got " - f"Q.shape={Q.shape}, K.shape={K.shape}, weights.shape={weights.shape}." 
- ) - B, oH, T_t, d = Q.shape - Bk, oHk, T_s, dk = K.shape - H, d_c_uq, d_i = W_uq.shape - d_dq, d_c_dq = W_dq.shape - d_wk, d_i_wk = W_k.shape - Bw, oHw, Hw, T_w = weights.shape - if (Bk, oHk) != (B, oH): - raise ValueError(f"(B,oH) mismatch: Q has {(B, oH)}, K has {(Bk, oHk)}") - if not (d == dk == d_dq == d_wk): - raise ValueError(f"d mismatch across Q/K/W_dq/W_k: {d}, {dk}, {d_dq}, {d_wk}") - if d_c_uq != d_c_dq: - raise ValueError(f"d_c mismatch: W_uq has {d_c_uq}, W_dq has {d_c_dq}") - if d_i != d_i_wk: - raise ValueError(f"d_i mismatch: W_uq has {d_i}, W_k has {d_i_wk}") - if (Bw, oHw, Hw, T_w) != (B, oH, H, T_t): - raise ValueError( - f"weights shape {weights.shape} does not match expected " - f"(B={B}, oH={oH}, H={H}, T={T_t})" - ) - - is_fp8 = _is_fp8_dtype(Q.dtype) - if is_fp8: - for nm, t in (("K", K), ("W_uq", W_uq), ("W_dq", W_dq), ("W_k", W_k)): - if t.dtype != Q.dtype: - raise ValueError( - f"FP8 mode requires Q/K/W_uq/W_dq/W_k all match dtype; " - f"Q is {Q.dtype} but {nm} is {t.dtype}." - ) - scales = (scale_q, scale_k, scale_wq, scale_wd, scale_wk) - if any(s is None for s in scales): - raise ValueError( - "FP8 mode requires scale_q, scale_k, scale_wq, scale_wd, scale_wk." - ) - scale_combined = jnp.asarray( - jnp.float32(scale_q) * jnp.float32(scale_k) - * jnp.float32(scale_wq) * jnp.float32(scale_wd) - * jnp.float32(scale_wk), - dtype=jnp.float32, - ) - if out_dtype is None: - out_dtype = weights.dtype - else: - scale_combined = jnp.asarray(1.0, dtype=jnp.float32) - if out_dtype is None: - out_dtype = Q.dtype - - return _indexer_p.bind( - Q, - K, - W_uq, - W_dq, - W_k, - weights, - scale_combined, - out_dtype=jnp.dtype(out_dtype), - )[0] - - -# --- Score+ReLU+H-reduce fused kernel (hybrid backend) ------------------------- -# -# Inputs are *already projected*: Hq, Hk, W_o all come from upstream einsum -# calls (hipBLASLt). 
This kernel does only the score matmul, the relu, and the -# per-token-per-head weighted sum over H — the pieces that have no efficient -# einsum/HLO equivalent because they'd require materializing the (B, oH, T, H, S) -# pre-relu score tensor in HBM. By fusing them in registers, we eliminate that -# round-trip entirely. - def _score_reduce_autotune_configs(): # The kernel is dominated by Hq reads (one (BLOCK_T, d_i) load per H # iteration). Bigger BLOCK_T ⇒ fewer T tiles ⇒ less total Hq traffic. @@ -430,13 +56,12 @@ def _score_reduce_autotune_configs(): return cfgs -@triton.autotune(configs=_score_reduce_autotune_configs(), key=["H", "d_i", "IS_FP8"]) +@triton.autotune(configs=_score_reduce_autotune_configs(), key=["H", "d_i"]) @triton.jit def _score_reduce_kernel( - Hq_ptr, # (B, oH, T_t, H, d_i) bf16 OR fp8 e4m3 - Hk_ptr, # (B, oH, T_s, d_i) same dtype as Hq - W_o_ptr, # (B, oH, T_t, H) bf16 always - scale_ptr, # 0-D fp32: combined scale_hq * scale_hk (1.0 in bf16 mode) + Hq_ptr, # (B, oH, T_t, H, d_i) — produced by einsum("...tc,hci->...thi") + Hk_ptr, # (B, oH, T_s, d_i) + W_o_ptr, # (B, oH, T_t, H) O_ptr, # (B, oH, T_t, T_s) B: tl.constexpr, oH: tl.constexpr, @@ -446,7 +71,6 @@ def _score_reduce_kernel( d_i: tl.constexpr, BLOCK_T: tl.constexpr, BLOCK_S: tl.constexpr, - IS_FP8: tl.constexpr, ): """Compute one (BLOCK_T, BLOCK_S) tile of O for one (b, h_outer) slice. @@ -456,10 +80,6 @@ def _score_reduce_kernel( and vary only in S — they all read the same per-head Hq slab, hitting L2 instead of HBM. Hq layout is the natural einsum output (..., T, H, d_i); per-head loads are strided in T (stride H*d_i). - - FP8 mode (IS_FP8=True): Hq and Hk are e4m3 with per-tensor fp32 scales. - The two scales fold into one fp32 multiply at the end (relu commutes - with positive scaling). The score MFMA runs native fp8-fp8. 
""" pid_s = tl.program_id(0) pid_t = tl.program_id(1) @@ -495,20 +115,10 @@ def _score_reduce_kernel( wo_ptrs = W_o_ptr + wo_base + rt * H + h w_h = tl.load(wo_ptrs, mask=rt_mask, other=0.0) - # tl.dot lowers to native fp8-fp8 MFMA when both inputs are fp8; - # otherwise bf16-bf16 MFMA. Output is fp32 in both cases. score = tl.dot(Hq_h, Hk_T) score = tl.maximum(score, 0.0) acc += score * w_h[:, None].to(tl.float32) - # Apply the combined per-tensor dequant scale at the very end. relu is - # invariant under multiplication by a positive scalar (sq * sk > 0), - # so this is mathematically equivalent to scaling Hq_h and Hk_T per - # iteration but costs one fp32 multiply per output element instead of - # one per dot input. - scale = tl.load(scale_ptr) - acc = acc * scale - o_ptrs = O_ptr + o_base + rt[:, None] * T_s + rs[None, :] tl.store(o_ptrs, acc.to(O_ptr.dtype.element_ty), mask=rt_mask[:, None] & rs_mask[None, :]) @@ -519,8 +129,8 @@ def _score_reduce_kernel( @_score_reduce_p.def_abstract_eval -def _score_reduce_abstract(Hq, Hk, W_o, scale, *, out_dtype): - del W_o, scale +def _score_reduce_abstract(Hq, Hk, W_o, *, out_dtype): + del W_o # Hq layout: (B, oH, T_t, H, d_i) B, oH, T_t, _H, _d_i = Hq.shape T_s = Hk.shape[2] @@ -530,13 +140,12 @@ def _score_reduce_abstract(Hq, Hk, W_o, scale, *, out_dtype): _score_reduce_p.def_impl(functools.partial(xla.apply_primitive, _score_reduce_p)) -def _score_reduce_lowering(ctx, Hq, Hk, W_o, scale, *, out_dtype): +def _score_reduce_lowering(ctx, Hq, Hk, W_o, *, out_dtype): del out_dtype Hq_aval = ctx.avals_in[0] Hk_aval = ctx.avals_in[1] B, oH, T_t, H, d_i = Hq_aval.shape T_s = Hk_aval.shape[2] - is_fp8 = _is_fp8_dtype(Hq_aval.dtype) def grid_fn(merged_kwargs): bt = merged_kwargs.get("BLOCK_T", 64) @@ -548,7 +157,7 @@ def grid_fn(merged_kwargs): return triton_call_lowering( ctx, _score_reduce_kernel, - Hq, Hk, W_o, scale, + Hq, Hk, W_o, grid=grid_fn, num_warps=4, num_stages=2, @@ -559,7 +168,6 @@ def grid_fn(merged_kwargs): 
"T_s": T_s, "H": H, "d_i": d_i, - "IS_FP8": is_fp8, }, ) @@ -568,8 +176,7 @@ def grid_fn(merged_kwargs): mlir.register_lowering(_score_reduce_p, _score_reduce_lowering, platform="cuda") -def score_reduce_triton(Hq, Hk, W_o, *, - scale_hq=None, scale_hk=None, out_dtype=None): +def score_reduce_triton(Hq, Hk, W_o, *, out_dtype=None): """Triton fused score-matmul + relu + per-(t, h) weighted H-reduction. Replaces the pattern: @@ -578,17 +185,13 @@ def score_reduce_triton(Hq, Hk, W_o, *, O = jnp.einsum("...ths,...th->...ts", scores, W_o) with a single kernel that holds the per-head score tile in registers, - avoiding the (B, oH, T, H, S) HBM round-trip that an einsum+XLA chain - pays (the dominant cost in profile_indexer's einsum baseline). + avoiding the (B, oH, T, H, S) HBM round-trip an einsum+XLA chain pays. Args: - Hq: (B, oH, T_t, H, d_i) bf16 OR fp8 e4m3 - Hk: (B, oH, T_s, d_i) must match Hq.dtype - W_o: (B, oH, T_t, H) bf16 - scale_hq, scale_hk: - per-tensor fp32 dequant scales for Hq / Hk. Required when - Hq is FP8; ignored otherwise. - out_dtype: defaults to Hq.dtype (or W_o.dtype in FP8 mode). + Hq: (B, oH, T_t, H, d_i) + Hk: (B, oH, T_s, d_i) + W_o: (B, oH, T_t, H) + out_dtype: defaults to Hq.dtype. Returns: O: (B, oH, T_t, T_s) @@ -621,278 +224,9 @@ def score_reduce_triton(Hq, Hk, W_o, *, f"(B={B}, oH={oH}, T_t={T_t}, H={H})" ) - is_fp8 = _is_fp8_dtype(Hq.dtype) - if is_fp8: - if Hk.dtype != Hq.dtype: - raise ValueError( - f"FP8 mode requires Hk.dtype == Hq.dtype; " - f"Hq is {Hq.dtype} but Hk is {Hk.dtype}." 
- ) - if scale_hq is None or scale_hk is None: - raise ValueError("FP8 mode requires scale_hq and scale_hk.") - scale_combined = jnp.asarray( - jnp.float32(scale_hq) * jnp.float32(scale_hk), - dtype=jnp.float32, - ) - if out_dtype is None: - out_dtype = W_o.dtype - else: - scale_combined = jnp.asarray(1.0, dtype=jnp.float32) - if out_dtype is None: - out_dtype = Hq.dtype + if out_dtype is None: + out_dtype = Hq.dtype return _score_reduce_p.bind( - Hq, Hk, W_o, scale_combined, out_dtype=jnp.dtype(out_dtype) + Hq, Hk, W_o, out_dtype=jnp.dtype(out_dtype) )[0] - - -# --- Top-K fused variant ------------------------------------------------------- -# -# FlashAttention-style: each CTA owns one (b, h, t-tile) and serializes over s -# tiles, maintaining a running per-row top-k of the score matrix. Output is -# (B, H, T_t, k) values + (B, H, T_t, k) int32 indices — never materializes the -# full (T_t, T_s) score tensor. -# -# Top-k merge: pack (val_bits << 32) | idx_u32 into uint64, build (BLOCK_T, -# k+BLOCK_S) via gather+where (tl.cat is 1D-only on this Triton), sort -# descending, take first k. Constraints: k pow2, k+block_s pow2. -# -# Score values are post-ReLU (≥ 0) so the fp32 bit pattern sorts correctly as -# uint32. Init sentinel: (val=0.0, idx=0xFFFFFFFF) — real positive values -# displace it; rows with fewer than k positive scores trail with idx=-1. 
- -_DEFAULT_K = 64 - - -@triton.jit -def _indexer_topk_kernel( - Q_ptr, - K_ptr, - W_q_ptr, - W_k_ptr, - weights_ptr, - O_v_ptr, - O_i_ptr, - B: tl.constexpr, - H: tl.constexpr, - T_t: tl.constexpr, - T_s: tl.constexpr, - d: tl.constexpr, - I: tl.constexpr, - d_i: tl.constexpr, - K_TOPK: tl.constexpr, - BLOCK_T: tl.constexpr, - BLOCK_S: tl.constexpr, - KS_SUM: tl.constexpr, -): - pid_bh = tl.program_id(0) - pid_t = tl.program_id(1) - b = pid_bh // H - h = pid_bh % H - - rt = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) - rd = tl.arange(0, d) - rdi = tl.arange(0, d_i) - rt_mask = rt < T_t - - q_base = b * (H * T_t * d) + h * (T_t * d) - Q = tl.load(Q_ptr + q_base + rt[:, None] * d + rd[None, :], - mask=rt_mask[:, None], other=0.0) - - Wk = tl.load(W_k_ptr + rd[:, None] * d_i + rdi[None, :]) - - running_pack = tl.full((BLOCK_T, K_TOPK), 0xFFFFFFFF, dtype=tl.uint64) - - w_base = b * (H * T_t * I) + h * (T_t * I) - k_base = b * (H * T_s * d) + h * (T_s * d) - - n_s_tiles = T_s // BLOCK_S - for s_idx in range(n_s_tiles): - s_off = s_idx * BLOCK_S - rs = s_off + tl.arange(0, BLOCK_S) - - Kt = tl.load(K_ptr + k_base + rs[:, None] * d + rd[None, :]) - Hk = tl.dot(Kt, Wk).to(Q.dtype) - - acc = tl.zeros((BLOCK_T, BLOCK_S), dtype=tl.float32) - for i in range(I): - Wq_i = tl.load(W_q_ptr + i * (d * d_i) + rd[:, None] * d_i + rdi[None, :]) - Hq_i = tl.dot(Q, Wq_i).to(Q.dtype) - Hi_raw = tl.dot(Hq_i, tl.trans(Hk)) - Hi = tl.maximum(Hi_raw, 0.0) - w_i = tl.load(weights_ptr + w_base + rt * I + i, mask=rt_mask, other=0.0) - acc += Hi * w_i[:, None] - - # Encode fp32 -> monotonic uint32 (radix-sort fp32 trick) so negative - # acc values sort below positive ones. 
- acc_bits = acc.to(tl.uint32, bitcast=True) - acc_sext = (acc.to(tl.int32, bitcast=True) >> 31).to(tl.uint32) - enc_mask = acc_sext | tl.cast(0x80000000, tl.uint32) - acc_key = acc_bits ^ enc_mask # (BLOCK_T, BLOCK_S) u32 - tile_idx = rs.to(tl.uint32) - tile_v_u = acc_key.to(tl.uint64) - tile_i_u = tile_idx.to(tl.uint64) - tile_pack = (tile_v_u << 32) | tile_i_u[None, :].broadcast_to((BLOCK_T, BLOCK_S)) - - pos = tl.arange(0, KS_SUM) - r_idx = tl.minimum(pos, K_TOPK - 1) - t_idx = tl.maximum(pos.to(tl.int32) - K_TOPK, 0).to(tl.int32) - r_ext = tl.gather(running_pack, r_idx[None, :].broadcast_to((BLOCK_T, KS_SUM)), axis=1) - t_ext = tl.gather(tile_pack, t_idx[None, :].broadcast_to((BLOCK_T, KS_SUM)), axis=1) - combined = tl.where((pos < K_TOPK)[None, :].broadcast_to((BLOCK_T, KS_SUM)), - r_ext, t_ext) - - running_pack = tl.topk(combined, K_TOPK, dim=1) - - # Decode monotonic uint32 key -> fp32 bits. - out_key = (running_pack >> 32).to(tl.uint32) - out_key_sext = (~out_key.to(tl.int32, bitcast=True) >> 31).to(tl.uint32) - dec_mask = out_key_sext | tl.cast(0x80000000, tl.uint32) - out_bits = out_key ^ dec_mask - out_vals_fp32 = out_bits.to(tl.float32, bitcast=True) - out_idxs = (running_pack & 0xFFFFFFFF).to(tl.uint32).to(tl.int32) - - rk = tl.arange(0, K_TOPK) - o_base = b * (H * T_t * K_TOPK) + h * (T_t * K_TOPK) - tl.store(O_v_ptr + o_base + rt[:, None] * K_TOPK + rk[None, :], - out_vals_fp32.to(O_v_ptr.dtype.element_ty), - mask=rt_mask[:, None]) - tl.store(O_i_ptr + o_base + rt[:, None] * K_TOPK + rk[None, :], - out_idxs, mask=rt_mask[:, None]) - - -_indexer_topk_p = extend_core.Primitive("te_indexer_topk_triton") -_indexer_topk_p.multiple_results = True - - -@_indexer_topk_p.def_abstract_eval -def _indexer_topk_abstract(Q, K, W_q, W_k, weights, *, - k, block_t, block_s, num_warps, num_stages): - del W_q, W_k, weights, block_t, block_s, num_warps, num_stages - B, H, T_t, _ = Q.shape - return [ - core.ShapedArray((B, H, T_t, k), Q.dtype), - core.ShapedArray((B, H, 
T_t, k), jnp.int32), - ] - - -_indexer_topk_p.def_impl(functools.partial(xla.apply_primitive, _indexer_topk_p)) - - -def _indexer_topk_lowering(ctx, Q, K, W_q, W_k, weights, *, - k, block_t, block_s, num_warps, num_stages): - Q_aval = ctx.avals_in[0] - K_aval = ctx.avals_in[1] - W_q_aval = ctx.avals_in[2] - B, H, T_t, d = Q_aval.shape - T_s = K_aval.shape[2] - I, _, d_i = W_q_aval.shape - - grid = (B * H, triton.cdiv(T_t, block_t)) - - return triton_call_lowering( - ctx, - _indexer_topk_kernel, - Q, - K, - W_q, - W_k, - weights, - grid=grid, - num_warps=num_warps, - num_stages=num_stages, - constexprs={ - "B": B, - "H": H, - "T_t": T_t, - "T_s": T_s, - "d": d, - "I": I, - "d_i": d_i, - "K_TOPK": k, - "BLOCK_T": block_t, - "BLOCK_S": block_s, - "KS_SUM": k + block_s, - }, - ) - - -mlir.register_lowering(_indexer_topk_p, _indexer_topk_lowering, platform="rocm") -mlir.register_lowering(_indexer_topk_p, _indexer_topk_lowering, platform="cuda") - - -def _is_pow2(n): - return n > 0 and (n & (n - 1)) == 0 - - -def indexer_fused_topk_triton( - Q, - K, - W_q, - W_k, - weights, - *, - k: int = _DEFAULT_K, - block_t: int = 128, - block_s: int = 64, - num_warps: int = 4, - num_stages: int = 1, -): - """Fused indexer + per-row top-k along T_s. Returns (vals, idxs). - - vals: (B, H, T_t, k) Q.dtype — descending top-k post-ReLU scores - idxs: (B, H, T_t, k) int32 — corresponding s positions in [0, T_s) - - Constraints: - * Q, K, weights are rank-4 BHSD. - * T_s % block_s == 0 (no masking inside inner loop). - * k and (k + block_s) are powers of 2 (tl.sort and tl.arange). - """ - if Q.ndim != 4 or K.ndim != 4 or weights.ndim != 4: - raise ValueError( - "indexer_fused_topk_triton expects rank-4 BHSD Q, K, weights. Got " - f"Q.shape={Q.shape}, K.shape={K.shape}, weights.shape={weights.shape}." 
- ) - B, H, T_t, d = Q.shape - Bk, Hk, T_s, dk = K.shape - I, d2, d_i = W_q.shape - d3, d_i_k = W_k.shape - Bw, Hw, T_t_w, I_w = weights.shape - if (Bk, Hk) != (B, H): - raise ValueError(f"(B,H) mismatch: Q has {(B, H)}, K has {(Bk, Hk)}") - if not (d == dk == d2 == d3): - raise ValueError(f"d mismatch across Q/K/W_q/W_k: {d}, {dk}, {d2}, {d3}") - if d_i != d_i_k: - raise ValueError(f"d_i mismatch: W_q has {d_i}, W_k has {d_i_k}") - if (Bw, Hw, T_t_w, I_w) != (B, H, T_t, I): - raise ValueError( - f"weights shape {weights.shape} does not match expected " - f"(B={B}, H={H}, T_t={T_t}, I={I})" - ) - if k > T_s: - raise ValueError(f"k={k} exceeds T_s={T_s}") - - block_t = min(block_t, T_t) - block_s = min(block_s, T_s) - - if T_s % block_s != 0: - raise ValueError( - f"T_s={T_s} must be divisible by block_s={block_s} (kernel doesn't " - "mask invalid s positions in the inner loop)." - ) - if not _is_pow2(k): - raise ValueError(f"k={k} must be a power of 2 (tl.arange requirement)") - if not _is_pow2(k + block_s): - raise ValueError( - f"k + block_s = {k + block_s} must be a power of 2 " - f"(k={k}, block_s={block_s})" - ) - - return _indexer_topk_p.bind( - Q, K, W_q, W_k, weights, - k=k, - block_t=block_t, - block_s=block_s, - num_warps=num_warps, - num_stages=num_stages, - ) From 9335ef0776f29b120ae64962ce134cf42b14c775 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 11 May 2026 17:59:52 +0000 Subject: [PATCH 04/17] Updated benchmarks --- benchmarks/profile_indexer.py | 11 +- benchmarks/profile_indexer_topk.py | 312 ++++++++++++----------------- benchmarks/run_indexer_kernel.py | 117 ----------- 3 files changed, 128 insertions(+), 312 deletions(-) delete mode 100644 benchmarks/run_indexer_kernel.py diff --git a/benchmarks/profile_indexer.py b/benchmarks/profile_indexer.py index f0e019305..8786be47c 100644 --- a/benchmarks/profile_indexer.py +++ b/benchmarks/profile_indexer.py @@ -108,12 +108,7 @@ def _is_fp8(dt): def _bind_scales(fn, scales, *, backend=None): - 
"""Return a 6-arg jit-able function that internally adds scale kwargs. - - If ``backend`` is given, it is forwarded as a kwarg to ``fn`` (used to - select between einsum / hybrid / pure-triton via the same ``indexer`` - entry point). - """ + """Return a 6-arg jit-able function that internally adds scale kwargs.""" extra = {} if backend is not None: extra["backend"] = backend @@ -133,7 +128,9 @@ def _build_impls(scales): ("baseline", _bind_scales(indexer, scales, backend="reference")), ] if _HAVE_HYBRID: - impls.append(("hybrid", _bind_scales(indexer, scales, backend="hybrid"))) + impls.append( + ("hybrid", _bind_scales(indexer, scales, backend="hybrid")) + ) return impls diff --git a/benchmarks/profile_indexer_topk.py b/benchmarks/profile_indexer_topk.py index bc1e27d8d..1b74d54b4 100644 --- a/benchmarks/profile_indexer_topk.py +++ b/benchmarks/profile_indexer_topk.py @@ -1,78 +1,75 @@ -"""Benchmark fused indexer+topk vs reference (full score then jax.lax.top_k). +"""Profile indexer + per-row top-k along T_s. -Production config: B=4, H=16, T_t=T_s=4096, d=128, I=4, d_i=64, k=64, bf16. +Same canonical backends as ``profile_indexer.py`` (reference einsum vs +hybrid einsum+Triton score-reduce), with ``jax.lax.top_k`` applied to the +score matrix. Reports wall time and effective TFLOPS for the indexer +compute (top-k is comparison-only and counted as 0 FLOP). -Sweeps (block_t, block_s, num_warps, num_stages) for the triton kernel and -reports TFLOP/s, ms, and vs-reference speedup. FLOPs counted as the underlying -indexer compute (top-k itself is comparison-only, treated as 0 FLOP). 
- -Usage: +Run inside the container: docker exec zain-w2 sh -c 'cd /workspace && python benchmarks/profile_indexer_topk.py' """ import time -import functools import jax import jax.numpy as jnp -from transformer_engine.jax.indexer import _indexer_impl_reference, quantize_to_fp8 -from transformer_engine.jax.triton_extensions.indexer import ( - indexer_fused_topk_triton, - indexer_fused_triton, -) -try: - from transformer_engine.jax.pallas_kernels.indexer import indexer_fused as _pallas_indexer - _HAVE_PALLAS = True -except Exception: - _pallas_indexer = None - _HAVE_PALLAS = False - +from transformer_engine.jax.indexer import indexer, quantize_to_fp8 -_FP8_DTYPES = ( - jnp.dtype("float8_e4m3fn"), jnp.dtype("float8_e5m2"), - jnp.dtype("float8_e4m3fnuz"), jnp.dtype("float8_e5m2fnuz"), -) - - -def _is_fp8(dt): - return jnp.dtype(dt) in _FP8_DTYPES - - -def make_inputs(B, H, T_t, T_s, d, I, d_i, dtype, seed=0): - keys = jax.random.split(jax.random.PRNGKey(seed), 5) - Q = jax.random.normal(keys[0], (B, H, T_t, d), dtype=dtype) - K = jax.random.normal(keys[1], (B, H, T_s, d), dtype=dtype) - W_q = jax.random.normal(keys[2], (I, d, d_i), dtype=dtype) - W_k = jax.random.normal(keys[3], (d, d_i), dtype=dtype) - weights = jax.random.normal(keys[4], (B, H, T_t, I), dtype=dtype) - return Q, K, W_q, W_k, weights - - -def make_fp8_inputs(B, H, T_t, T_s, d, I, d_i, *, fp8_dtype, seed=0): - Q, K, W_q, W_k, weights = make_inputs( - B, H, T_t, T_s, d, I, d_i, jnp.bfloat16, seed=seed - ) - Q_q, sq = quantize_to_fp8(Q, dtype=fp8_dtype) - K_q, sk = quantize_to_fp8(K, dtype=fp8_dtype) - Wq_q, swq = quantize_to_fp8(W_q, dtype=fp8_dtype) - Wk_q, swk = quantize_to_fp8(W_k, dtype=fp8_dtype) - return Q_q, K_q, Wq_q, Wk_q, weights, dict( - scale_q=sq, scale_k=sk, scale_wq=swq, scale_wk=swk, +# Triton hybrid backend: einsum projections + Triton score-relu-reduce. 
+try: + from transformer_engine.jax.triton_extensions.indexer import score_reduce_triton # noqa: F401 + _HAVE_HYBRID = True +except Exception as _e: # noqa: BLE001 + _HAVE_HYBRID = False + _HYBRID_IMPORT_ERROR = _e + + +# --- Inputs / FLOP accounting --------------------------------------------------- +# Mirrors profile_indexer.py — keeping the two profilers in lockstep. + +def make_inputs(B, oH, T, S, d, d_c, H, d_i, dtype, seed=0): + keys = jax.random.split(jax.random.PRNGKey(seed), 6) + Q = jax.random.normal(keys[0], (B, oH, T, d), dtype=dtype) + K = jax.random.normal(keys[1], (B, oH, S, d), dtype=dtype) + W_uq = jax.random.normal(keys[2], (H, d_c, d_i), dtype=dtype) + W_dq = jax.random.normal(keys[3], (d, d_c), dtype=dtype) + W_k = jax.random.normal(keys[4], (d, d_i), dtype=dtype) + W_w = jax.random.normal(keys[5], (d, H), dtype=dtype) + return Q, K, W_uq, W_dq, W_k, W_w + + +def make_fp8_inputs(B, oH, T, S, d, d_c, H, d_i, *, + fp8_dtype=jnp.float8_e4m3fn, weights_dtype=jnp.bfloat16, + seed=0): + Q, K, W_uq, W_dq, W_k, W_w = make_inputs( + B, oH, T, S, d, d_c, H, d_i, jnp.bfloat16, seed=seed ) - - -def theoretical_flops(B, H, T_t, T_s, d, I, d_i): - n = B * H + Q_q, sq = quantize_to_fp8(Q, dtype=fp8_dtype) + K_q, sk = quantize_to_fp8(K, dtype=fp8_dtype) + Wuq_q, swq = quantize_to_fp8(W_uq, dtype=fp8_dtype) + Wdq_q, swd = quantize_to_fp8(W_dq, dtype=fp8_dtype) + Wk_q, swk = quantize_to_fp8(W_k, dtype=fp8_dtype) + W_w = W_w.astype(weights_dtype) + scales = dict(scale_q=sq, scale_k=sk, + scale_wq=swq, scale_wd=swd, scale_wk=swk) + return Q_q, K_q, Wuq_q, Wdq_q, Wk_q, W_w, scales + + +def theoretical_flops(B, oH, T, S, d, d_c, H, d_i): + # 2 flops per multiply-add. top-k is comparison-only, counted as 0 FLOP. 
+ n = B * oH return 2 * ( - n * T_t * I * d_i * d - + n * T_s * d_i * d - + n * T_t * I * T_s * d_i - + n * T_t * T_s * I + n * T * d_c * d + + n * T * H * d_i * d_c + + n * S * d_i * d + + n * T * H * S * d_i + + n * T * d * H + + n * T * S * H ) -def time_fn(fn, args, n_warmup=5, n_iter=50): +def time_fn(fn, args, n_warmup=15, n_iter=50): for _ in range(n_warmup): out = fn(*args) jax.tree_util.tree_map(lambda x: x.block_until_ready(), out) @@ -83,165 +80,104 @@ def time_fn(fn, args, n_warmup=5, n_iter=50): return (time.perf_counter() - t0) / n_iter -# Reference, pallas+topk, triton+topk: each accepts an optional `scales` dict -# (None for high-precision). Built fresh per-config since the scales are baked -# into the closure. -def _make_reference_topk(scales): - if scales is None: - @jax.jit - def fn(Q, K, W_q, W_k, weights): - scores = _indexer_impl_reference(Q, K, W_q, W_k, weights) - return jax.lax.top_k(scores, K_TOPK_GLOBAL) - else: - @jax.jit - def fn(Q, K, W_q, W_k, weights): - scores = _indexer_impl_reference(Q, K, W_q, W_k, weights, **scales) - return jax.lax.top_k(scores, K_TOPK_GLOBAL) - return fn +# --- Driver --------------------------------------------------------------------- +CONFIGS = [ + #(B, oH, T, S, d, d_c, H, d_i, dtype) + ( 2, 64, 1024, 1024, 512, 1024, 64, 128, jnp.bfloat16), +] -def _make_pallas_then_topk(scales): - if not _HAVE_PALLAS: - return None - if scales is None: - @jax.jit - def fn(Q, K, W_q, W_k, weights): - scores = _pallas_indexer(Q, K, W_q, W_k, weights) - return jax.lax.top_k(scores, K_TOPK_GLOBAL) - else: - @jax.jit - def fn(Q, K, W_q, W_k, weights): - scores = _pallas_indexer(Q, K, W_q, W_k, weights, **scales) - return jax.lax.top_k(scores, K_TOPK_GLOBAL) - return fn +K_TOPK = 64 -def _make_triton_then_topk(scales): - if scales is None: - @jax.jit - def fn(Q, K, W_q, W_k, weights): - scores = indexer_fused_triton(Q, K, W_q, W_k, weights) - return jax.lax.top_k(scores, K_TOPK_GLOBAL) - else: - @jax.jit - def fn(Q, K, W_q, 
W_k, weights): - scores = indexer_fused_triton(Q, K, W_q, W_k, weights, **scales) - return jax.lax.top_k(scores, K_TOPK_GLOBAL) - return fn +def _is_fp8(dt): + return jnp.dtype(dt) in ( + jnp.dtype("float8_e4m3fn"), jnp.dtype("float8_e5m2"), + jnp.dtype("float8_e4m3fnuz"), jnp.dtype("float8_e5m2fnuz"), + ) -# Standalone: just time jax.lax.top_k on a precomputed score matrix. -@jax.jit -def topk_only(scores): - return jax.lax.top_k(scores, K_TOPK_GLOBAL) +def _bind_topk(scales, *, backend, k): + """Build a jit'd indexer-then-topk closure for the given backend + scales.""" + extra = {"backend": backend} + if scales is not None: + merged = dict(extra, **scales) + else: + merged = extra + @jax.jit + def fn(Q, K, W_uq, W_dq, W_k, W_w): + scores = indexer(Q, K, W_uq, W_dq, W_k, W_w, **merged) + return jax.lax.top_k(scores, k) -def _make_triton(k, bt, bs, nw, ns): - fn = jax.jit(functools.partial( - indexer_fused_topk_triton, - k=k, block_t=bt, block_s=bs, num_warps=nw, num_stages=ns, - )) return fn -CONFIGS = [ - # (B, H, T_t, T_s, d, I, d_i, dtype) - ( 4, 16, 2048, 2048, 128, 4, 64, jnp.bfloat16), - ( 4, 16, 4096, 4096, 128, 4, 64, jnp.bfloat16), - # FP8 e4m3 — fused-topk Triton kernel doesn't accept FP8 yet; the row will - # report "(skipped: fp8 not supported)" for that impl. The other three - # paths (reference, pallas+topk, triton+topk) all run end-to-end in FP8. 
- ( 4, 16, 2048, 2048, 128, 4, 64, jnp.float8_e4m3fn), - ( 4, 16, 4096, 4096, 128, 4, 64, jnp.float8_e4m3fn), -] +def _build_impls(scales, k): + impls = [ + ("baseline+topk", _bind_topk(scales, backend="reference", k=k)), + ] + if _HAVE_HYBRID: + impls.append( + ("hybrid+topk", _bind_topk(scales, backend="hybrid", k=k)) + ) + return impls -K_TOPK_GLOBAL = 64 - -SWEEP = [ - # (block_t, block_s, num_warps, num_stages) - ( 64, 64, 4, 1), - ( 64, 64, 8, 1), - (128, 64, 4, 1), - (128, 64, 8, 1), - ( 32, 32, 4, 1), - ( 32, 64, 4, 1), - ( 32, 128, 4, 1), # k=64+128=192 not pow2; will be skipped - ( 64, 32, 4, 1), - (256, 64, 4, 1), - (256, 64, 8, 1), -] + +@jax.jit +def _topk_only(scores): + return jax.lax.top_k(scores, K_TOPK) + + +if not _HAVE_HYBRID: + print(f"[profile_indexer_topk] Hybrid backend unavailable: {_HYBRID_IMPORT_ERROR}") def main(): - print(f"jax devices: {jax.devices()}\nk = {K_TOPK_GLOBAL}\n") + print(f"jax devices: {jax.devices()}\nk = {K_TOPK}\n") for cfg in CONFIGS: - B, H, T_t, T_s, d, I, d_i, dtype = cfg + B, oH, T, S, d, d_c, H, d_i, dtype = cfg is_fp8 = _is_fp8(dtype) if is_fp8: - Q, K, W_q, W_k, weights, scales = make_fp8_inputs( - B, H, T_t, T_s, d, I, d_i, fp8_dtype=dtype + Q, K, W_uq, W_dq, W_k, W_w, scales = make_fp8_inputs( + B, oH, T, S, d, d_c, H, d_i, fp8_dtype=dtype ) - args = (Q, K, W_q, W_k, weights) else: - args = make_inputs(B, H, T_t, T_s, d, I, d_i, dtype) + Q, K, W_uq, W_dq, W_k, W_w = make_inputs( + B, oH, T, S, d, d_c, H, d_i, dtype + ) scales = None - flops = theoretical_flops(B, H, T_t, T_s, d, I, d_i) - print(f"--- B={B} H={H} T_t={T_t} T_s={T_s} d={d} I={I} d_i={d_i} {dtype.dtype.name} ---") - print(f" theoretical work = {flops/1e9:.2f} GFLOPs/call") - - impls = [ - ("ref(einsum+topk)", _make_reference_topk(scales)), - ("pallas+topk", _make_pallas_then_topk(scales)), - ("triton+topk", _make_triton_then_topk(scales)), - ] - - ref_ms = None + args = (Q, K, W_uq, W_dq, W_k, W_w) + impls = _build_impls(scales, K_TOPK) + flops = 
theoretical_flops(B, oH, T, S, d, d_c, H, d_i) + + print(f"--- B={B} oH={oH} T={T} S={S} d={d} d_c={d_c} H={H} d_i={d_i} " + f"{dtype.dtype.name} ---") + print(f" theoretical work = {flops/1e9:.2f} GFLOPs/call (top-k = 0 FLOP)") + baseline_ms = None for name, fn in impls: - if fn is None: - continue try: sec = time_fn(fn, args) ms = sec * 1e3 tflops = flops / sec / 1e12 - if name == "pallas+topk": - ref_ms = ms - print(f" {name:<22} {ms:8.3f} ms {tflops:6.2f} TFLOP/s") + if name == "baseline+topk": + baseline_ms = ms + speed = "" + else: + speed = f" ({baseline_ms/ms:.2f}x baseline)" + print(f" {name:<14} {ms:8.3f} ms {tflops:6.2f} TFLOP/s{speed}") except Exception as e: # noqa: BLE001 - print(f" {name:<22} FAILED: {type(e).__name__}: {str(e).splitlines()[0]}") + print(f" {name:<14} FAILED: {type(e).__name__}: {str(e).splitlines()[0]}") - # Time top_k alone (on pre-materialized scores). For FP8 inputs the - # reference dequantizes internally and returns a high-precision matrix. + # Time top_k alone on a precomputed (reference) score matrix to + # isolate the top-k cost from the indexer compute. try: - if scales is None: - scores_mat = _indexer_impl_reference(*args) - else: - scores_mat = _indexer_impl_reference(*args, **scales) - sec = time_fn(topk_only, (scores_mat,)) - print(f" {'(top_k alone)':<22} {sec*1e3:8.3f} ms") + kw = {"backend": "reference", **(scales or {})} + scores_mat = indexer(*args, **kw) + sec = time_fn(_topk_only, (scores_mat,)) + print(f" {'(top_k alone)':<14} {sec*1e3:8.3f} ms") except Exception as e: # noqa: BLE001 - print(f" (top_k alone) FAILED: {type(e).__name__}") - - # Fused-topk Triton kernel does not accept FP8 yet — skip the sweep - # for FP8 configs. 
- if is_fp8: - print(f" {'fused-topk triton':<22} (skipped: fp8 not supported by topk kernel)") - print() - continue - - # Triton fused-topk sweep (high-precision only) - for bt, bs, nw, ns in SWEEP: - if (K_TOPK_GLOBAL + bs) & (K_TOPK_GLOBAL + bs - 1) != 0: - continue # k+block_s must be pow2 - label = f"triton bt={bt} bs={bs} W={nw} S={ns}" - try: - fn = _make_triton(K_TOPK_GLOBAL, bt, bs, nw, ns) - sec = time_fn(fn, args) - ms = sec * 1e3 - tflops = flops / sec / 1e12 - speed = f" ({ref_ms/ms:.2f}x ref)" if ref_ms else "" - print(f" {label:<22} {ms:8.3f} ms {tflops:6.2f} TFLOP/s{speed}") - except Exception as e: # noqa: BLE001 - print(f" {label:<22} FAILED: {type(e).__name__}: {str(e).splitlines()[0]}") + print(f" (top_k alone) FAILED: {type(e).__name__}") print() diff --git a/benchmarks/run_indexer_kernel.py b/benchmarks/run_indexer_kernel.py deleted file mode 100644 index d61a1738c..000000000 --- a/benchmarks/run_indexer_kernel.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Minimal direct invocation of the low-rank indexer kernel for profiling. - -No baselines, no comparisons. Just: build inputs once, jit the kernel, -warm it up, then run a fixed number of iterations under whatever -profiler is wrapping this process. 
- -Run inside the container: - docker exec zain-w2 sh -c 'cd /workspace && python benchmarks/run_indexer_kernel.py' -""" - -import argparse -import time - -import jax -import jax.numpy as jnp - -from transformer_engine.jax.indexer import quantize_to_fp8 -from transformer_engine.jax.triton_extensions.indexer import indexer_fused_triton as _triton_indexer - -_BACKENDS = { - "triton": _triton_indexer, -} - -_DTYPE_MAP = { - "bf16": jnp.bfloat16, - "fp32": jnp.float32, - "fp8": jnp.float8_e4m3fn, -} - - -def make_inputs(B, oH, T, S, d, d_c, H, d_i, dtype, seed=0): - keys = jax.random.split(jax.random.PRNGKey(seed), 6) - Q = jax.random.normal(keys[0], (B, oH, T, d), dtype=dtype) - K = jax.random.normal(keys[1], (B, oH, S, d), dtype=dtype) - W_uq = jax.random.normal(keys[2], (H, d_c, d_i), dtype=dtype) - W_dq = jax.random.normal(keys[3], (d, d_c), dtype=dtype) - W_k = jax.random.normal(keys[4], (d, d_i), dtype=dtype) - weights = jax.random.normal(keys[5], (B, oH, H, T), dtype=dtype) - return Q, K, W_uq, W_dq, W_k, weights - - -def make_fp8_inputs(B, oH, T, S, d, d_c, H, d_i, *, fp8_dtype, seed=0): - """Quantize all five matrices to FP8; weights stay bf16.""" - Q, K, W_uq, W_dq, W_k, weights = make_inputs( - B, oH, T, S, d, d_c, H, d_i, jnp.bfloat16, seed=seed - ) - Q_q, sq = quantize_to_fp8(Q, dtype=fp8_dtype) - K_q, sk = quantize_to_fp8(K, dtype=fp8_dtype) - Wuq_q, swq = quantize_to_fp8(W_uq, dtype=fp8_dtype) - Wdq_q, swd = quantize_to_fp8(W_dq, dtype=fp8_dtype) - Wk_q, swk = quantize_to_fp8(W_k, dtype=fp8_dtype) - return Q_q, K_q, Wuq_q, Wdq_q, Wk_q, weights, dict( - scale_q=sq, scale_k=sk, scale_wq=swq, scale_wd=swd, scale_wk=swk, - ) - - -def main(): - p = argparse.ArgumentParser() - p.add_argument("--B", type=int, default=4) - p.add_argument("--oH", type=int, default=16, help="outer (multi-attn) heads") - p.add_argument("--T", type=int, default=2048) - p.add_argument("--S", type=int, default=2048) - p.add_argument("--d", type=int, default=512, help="hidden dim") - 
p.add_argument("--d_c", type=int, default=128, help="down-projection rank") - p.add_argument("--H", type=int, default=64, help="indexer-head count") - p.add_argument("--d_i", type=int, default=128, help="per-indexer-head dim") - p.add_argument("--dtype", choices=list(_DTYPE_MAP), default="bf16") - p.add_argument("--warmup", type=int, default=5) - p.add_argument("--iters", type=int, default=50) - p.add_argument("--backend", choices=list(_BACKENDS), default="triton") - args = p.parse_args() - - dtype = _DTYPE_MAP[args.dtype] - is_fp8 = args.dtype == "fp8" - print(f"jax devices: {jax.devices()}") - print(f"shape: B={args.B} oH={args.oH} T={args.T} S={args.S} " - f"d={args.d} d_c={args.d_c} H={args.H} d_i={args.d_i} " - f"dtype={args.dtype} backend={args.backend}") - - if is_fp8: - Q, K, W_uq, W_dq, W_k, weights, scales = make_fp8_inputs( - args.B, args.oH, args.T, args.S, - args.d, args.d_c, args.H, args.d_i, fp8_dtype=dtype, - ) - inputs = (Q, K, W_uq, W_dq, W_k, weights) - else: - inputs = make_inputs(args.B, args.oH, args.T, args.S, - args.d, args.d_c, args.H, args.d_i, dtype) - scales = None - - raw_fn = _BACKENDS[args.backend] - - if scales is None: - @jax.jit - def fn(Q, K, W_uq, W_dq, W_k, weights): - return raw_fn(Q, K, W_uq, W_dq, W_k, weights) - else: - @jax.jit - def fn(Q, K, W_uq, W_dq, W_k, weights): - return raw_fn(Q, K, W_uq, W_dq, W_k, weights, **scales) - - # Warmup: triggers JIT compile + first-launch overhead. - for _ in range(args.warmup): - out = fn(*inputs) - jax.block_until_ready(out) - - # Timed region: this is what the profiler should focus on. 
- t0 = time.perf_counter() - for _ in range(args.iters): - out = fn(*inputs) - jax.block_until_ready(out) - sec = (time.perf_counter() - t0) / args.iters - print(f"avg per call: {sec*1e3:.3f} ms ({args.iters} iters)") - - -if __name__ == "__main__": - main() From 358d3262b041c11a41e721e58ac07393936dff47 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 12 May 2026 21:48:11 +0000 Subject: [PATCH 05/17] Trimmed fp8 fragments --- benchmarks/profile_indexer.py | 102 ++++------------- benchmarks/profile_indexer_topk.py | 86 +++------------ transformer_engine/jax/indexer.py | 171 +++++------------------------ 3 files changed, 65 insertions(+), 294 deletions(-) diff --git a/benchmarks/profile_indexer.py b/benchmarks/profile_indexer.py index 8786be47c..e31d47a1d 100644 --- a/benchmarks/profile_indexer.py +++ b/benchmarks/profile_indexer.py @@ -1,4 +1,4 @@ -"""Profile the low-rank lightning-indexer at realistic shapes. +"""Profile the low-rank lightning-indexer at realistic shapes (bf16). Measures wall time and effective TFLOPS for the einsum baseline vs the fused Triton kernel. @@ -12,9 +12,8 @@ import jax import jax.numpy as jnp -from transformer_engine.jax.indexer import indexer, quantize_to_fp8 +from transformer_engine.jax.indexer import indexer -# Triton hybrid backend: einsum projections + Triton score-relu-reduce. try: from transformer_engine.jax.triton_extensions.indexer import score_reduce_triton # noqa: F401 _HAVE_HYBRID = True @@ -32,35 +31,10 @@ def make_inputs(B, oH, T, S, d, d_c, H, d_i, dtype, seed=0): W_uq = jax.random.normal(keys[2], (H, d_c, d_i), dtype=dtype) W_dq = jax.random.normal(keys[3], (d, d_c), dtype=dtype) W_k = jax.random.normal(keys[4], (d, d_i), dtype=dtype) - # Learnable per-(token, indexer-head) weight projection: W_o = Q @ W_w. 
W_w = jax.random.normal(keys[5], (d, H), dtype=dtype) return Q, K, W_uq, W_dq, W_k, W_w -def make_fp8_inputs(B, oH, T, S, d, d_c, H, d_i, *, - fp8_dtype=jnp.float8_e4m3fn, weights_dtype=jnp.bfloat16, - seed=0): - """Sample bf16 tensors then quantize Q/K/W_uq/W_dq/W_k to FP8. - - W_w stays in ``weights_dtype`` (bf16) — the reference impl does not - dequantize it. - - Returns (Q, K, W_uq, W_dq, W_k, W_w, scales_dict). - """ - Q, K, W_uq, W_dq, W_k, W_w = make_inputs( - B, oH, T, S, d, d_c, H, d_i, jnp.bfloat16, seed=seed - ) - Q_q, sq = quantize_to_fp8(Q, dtype=fp8_dtype) - K_q, sk = quantize_to_fp8(K, dtype=fp8_dtype) - Wuq_q, swq = quantize_to_fp8(W_uq, dtype=fp8_dtype) - Wdq_q, swd = quantize_to_fp8(W_dq, dtype=fp8_dtype) - Wk_q, swk = quantize_to_fp8(W_k, dtype=fp8_dtype) - W_w = W_w.astype(weights_dtype) - scales = dict(scale_q=sq, scale_k=sk, - scale_wq=swq, scale_wd=swd, scale_wk=swk) - return Q_q, K_q, Wuq_q, Wdq_q, Wk_q, W_w, scales - - def theoretical_flops(B, oH, T, S, d, d_c, H, d_i): # 2 flops per multiply-add. 
Counts the contractions in the low-rank # indexer with learnable output-weight projection: @@ -95,47 +69,16 @@ def time_fn(fn, args, n_warmup=15, n_iter=50): # --- Driver --------------------------------------------------------------------- CONFIGS = [ - #(B, oH, T, S, d, d_c, H, d_i, dtype) - ( 2, 64, 1024, 1024, 512, 1024, 64, 128, jnp.bfloat16), + #(B, oH, T, S, d, d_c, H, d_i) + ( 2, 64, 1024, 1024, 512, 1024, 64, 128), ] -def _is_fp8(dt): - return jnp.dtype(dt) in ( - jnp.dtype("float8_e4m3fn"), jnp.dtype("float8_e5m2"), - jnp.dtype("float8_e4m3fnuz"), jnp.dtype("float8_e5m2fnuz"), - ) - - -def _bind_scales(fn, scales, *, backend=None): - """Return a 6-arg jit-able function that internally adds scale kwargs.""" - extra = {} - if backend is not None: - extra["backend"] = backend - if scales is None and not extra: - return jax.jit(fn) +def _build_impl(backend): @jax.jit - def wrapped(Q, K, W_uq, W_dq, W_k, W_w): - kwargs = dict(extra) - if scales is not None: - kwargs.update(scales) - return fn(Q, K, W_uq, W_dq, W_k, W_w, **kwargs) - return wrapped - - -def _build_impls(scales): - impls = [ - ("baseline", _bind_scales(indexer, scales, backend="reference")), - ] - if _HAVE_HYBRID: - impls.append( - ("hybrid", _bind_scales(indexer, scales, backend="hybrid")) - ) - return impls - - -if not _HAVE_HYBRID: - print(f"[profile_indexer] Hybrid backend unavailable: {_HYBRID_IMPORT_ERROR}") + def fn(Q, K, W_uq, W_dq, W_k, W_w): + return indexer(Q, K, W_uq, W_dq, W_k, W_w, backend=backend) + return fn def _dump_autotuner_winner(): @@ -156,27 +99,26 @@ def _dump_autotuner_winner(): print(f" [autotune] key={key} -> {cfg}") +if not _HAVE_HYBRID: + print(f"[profile_indexer] Hybrid backend unavailable: {_HYBRID_IMPORT_ERROR}") + + def main(): print(f"jax devices: {jax.devices()}\n") - for cfg in CONFIGS: - B, oH, T, S, d, d_c, H, d_i, dtype = cfg - is_fp8 = _is_fp8(dtype) - if is_fp8: - Q, K, W_uq, W_dq, W_k, W_w, scales = make_fp8_inputs( - B, oH, T, S, d, d_c, H, d_i, 
fp8_dtype=dtype - ) - else: - Q, K, W_uq, W_dq, W_k, W_w = make_inputs( - B, oH, T, S, d, d_c, H, d_i, dtype - ) - scales = None + for B, oH, T, S, d, d_c, H, d_i in CONFIGS: + Q, K, W_uq, W_dq, W_k, W_w = make_inputs( + B, oH, T, S, d, d_c, H, d_i, jnp.bfloat16 + ) args = (Q, K, W_uq, W_dq, W_k, W_w) - impls = _build_impls(scales) flops = theoretical_flops(B, oH, T, S, d, d_c, H, d_i) - print(f"--- B={B} oH={oH} T={T} S={S} d={d} d_c={d_c} H={H} d_i={d_i} " - f"{dtype.dtype.name} ---") + print(f"--- B={B} oH={oH} T={T} S={S} d={d} d_c={d_c} H={H} d_i={d_i} bfloat16 ---") print(f" theoretical work = {flops/1e9:.2f} GFLOPs/call") + + impls = [("baseline", _build_impl("reference"))] + if _HAVE_HYBRID: + impls.append(("hybrid", _build_impl("hybrid"))) + baseline_ms = None for name, fn in impls: try: diff --git a/benchmarks/profile_indexer_topk.py b/benchmarks/profile_indexer_topk.py index 1b74d54b4..2040b86e4 100644 --- a/benchmarks/profile_indexer_topk.py +++ b/benchmarks/profile_indexer_topk.py @@ -1,4 +1,4 @@ -"""Profile indexer + per-row top-k along T_s. +"""Profile indexer + per-row top-k along T_s (bf16). Same canonical backends as ``profile_indexer.py`` (reference einsum vs hybrid einsum+Triton score-reduce), with ``jax.lax.top_k`` applied to the @@ -14,9 +14,8 @@ import jax import jax.numpy as jnp -from transformer_engine.jax.indexer import indexer, quantize_to_fp8 +from transformer_engine.jax.indexer import indexer -# Triton hybrid backend: einsum projections + Triton score-relu-reduce. 
try: from transformer_engine.jax.triton_extensions.indexer import score_reduce_triton # noqa: F401 _HAVE_HYBRID = True @@ -39,23 +38,6 @@ def make_inputs(B, oH, T, S, d, d_c, H, d_i, dtype, seed=0): return Q, K, W_uq, W_dq, W_k, W_w -def make_fp8_inputs(B, oH, T, S, d, d_c, H, d_i, *, - fp8_dtype=jnp.float8_e4m3fn, weights_dtype=jnp.bfloat16, - seed=0): - Q, K, W_uq, W_dq, W_k, W_w = make_inputs( - B, oH, T, S, d, d_c, H, d_i, jnp.bfloat16, seed=seed - ) - Q_q, sq = quantize_to_fp8(Q, dtype=fp8_dtype) - K_q, sk = quantize_to_fp8(K, dtype=fp8_dtype) - Wuq_q, swq = quantize_to_fp8(W_uq, dtype=fp8_dtype) - Wdq_q, swd = quantize_to_fp8(W_dq, dtype=fp8_dtype) - Wk_q, swk = quantize_to_fp8(W_k, dtype=fp8_dtype) - W_w = W_w.astype(weights_dtype) - scales = dict(scale_q=sq, scale_k=sk, - scale_wq=swq, scale_wd=swd, scale_wk=swk) - return Q_q, K_q, Wuq_q, Wdq_q, Wk_q, W_w, scales - - def theoretical_flops(B, oH, T, S, d, d_c, H, d_i): # 2 flops per multiply-add. top-k is comparison-only, counted as 0 FLOP. 
n = B * oH @@ -83,47 +65,21 @@ def time_fn(fn, args, n_warmup=15, n_iter=50): # --- Driver --------------------------------------------------------------------- CONFIGS = [ - #(B, oH, T, S, d, d_c, H, d_i, dtype) - ( 2, 64, 1024, 1024, 512, 1024, 64, 128, jnp.bfloat16), + #(B, oH, T, S, d, d_c, H, d_i) + ( 2, 64, 1024, 1024, 512, 1024, 64, 128), ] -K_TOPK = 64 - - -def _is_fp8(dt): - return jnp.dtype(dt) in ( - jnp.dtype("float8_e4m3fn"), jnp.dtype("float8_e5m2"), - jnp.dtype("float8_e4m3fnuz"), jnp.dtype("float8_e5m2fnuz"), - ) - +K_TOPK = 512 -def _bind_topk(scales, *, backend, k): - """Build a jit'd indexer-then-topk closure for the given backend + scales.""" - extra = {"backend": backend} - if scales is not None: - merged = dict(extra, **scales) - else: - merged = extra +def _build_topk(backend, k): @jax.jit def fn(Q, K, W_uq, W_dq, W_k, W_w): - scores = indexer(Q, K, W_uq, W_dq, W_k, W_w, **merged) + scores = indexer(Q, K, W_uq, W_dq, W_k, W_w, backend=backend) return jax.lax.top_k(scores, k) - return fn -def _build_impls(scales, k): - impls = [ - ("baseline+topk", _bind_topk(scales, backend="reference", k=k)), - ] - if _HAVE_HYBRID: - impls.append( - ("hybrid+topk", _bind_topk(scales, backend="hybrid", k=k)) - ) - return impls - - @jax.jit def _topk_only(scores): return jax.lax.top_k(scores, K_TOPK) @@ -135,25 +91,20 @@ def _topk_only(scores): def main(): print(f"jax devices: {jax.devices()}\nk = {K_TOPK}\n") - for cfg in CONFIGS: - B, oH, T, S, d, d_c, H, d_i, dtype = cfg - is_fp8 = _is_fp8(dtype) - if is_fp8: - Q, K, W_uq, W_dq, W_k, W_w, scales = make_fp8_inputs( - B, oH, T, S, d, d_c, H, d_i, fp8_dtype=dtype - ) - else: - Q, K, W_uq, W_dq, W_k, W_w = make_inputs( - B, oH, T, S, d, d_c, H, d_i, dtype - ) - scales = None + for B, oH, T, S, d, d_c, H, d_i in CONFIGS: + Q, K, W_uq, W_dq, W_k, W_w = make_inputs( + B, oH, T, S, d, d_c, H, d_i, jnp.bfloat16 + ) args = (Q, K, W_uq, W_dq, W_k, W_w) - impls = _build_impls(scales, K_TOPK) flops = 
theoretical_flops(B, oH, T, S, d, d_c, H, d_i) - print(f"--- B={B} oH={oH} T={T} S={S} d={d} d_c={d_c} H={H} d_i={d_i} " - f"{dtype.dtype.name} ---") + print(f"--- B={B} oH={oH} T={T} S={S} d={d} d_c={d_c} H={H} d_i={d_i} bfloat16 ---") print(f" theoretical work = {flops/1e9:.2f} GFLOPs/call (top-k = 0 FLOP)") + + impls = [("baseline+topk", _build_topk("reference", K_TOPK))] + if _HAVE_HYBRID: + impls.append(("hybrid+topk", _build_topk("hybrid", K_TOPK))) + baseline_ms = None for name, fn in impls: try: @@ -172,8 +123,7 @@ def main(): # Time top_k alone on a precomputed (reference) score matrix to # isolate the top-k cost from the indexer compute. try: - kw = {"backend": "reference", **(scales or {})} - scores_mat = indexer(*args, **kw) + scores_mat = indexer(*args, backend="reference") sec = time_fn(_topk_only, (scores_mat,)) print(f" {'(top_k alone)':<14} {sec*1e3:8.3f} ms") except Exception as e: # noqa: BLE001 diff --git a/transformer_engine/jax/indexer.py b/transformer_engine/jax/indexer.py index d6021f447..bdfea9fe0 100644 --- a/transformer_engine/jax/indexer.py +++ b/transformer_engine/jax/indexer.py @@ -1,4 +1,4 @@ -"""Indexer op (forward only). +"""Indexer op (forward only), bf16 inputs. Two canonical backends: * ``"reference"`` — pure ``jnp.einsum``. Materializes the @@ -20,11 +20,6 @@ W_o = Q @ W_w # (..., T, H) H = relu(einsum("...thi,...si->...ths", H_q, H_k)) # (..., T, H, S) O = einsum("...ths,...th->...ts", H, W_o) # (..., T, S) - -FP8 mode: any of Q / K / W_uq / W_dq / W_k may be FP8 (e4m3) tensors with -per-tensor fp32 scales. They are dequantized to bf16 inside both backends -before the projections — XLA on ROCm has no fp8 GEMM rewriter, and the -hybrid kernel itself runs in bf16. 
""" import functools @@ -33,41 +28,7 @@ import jax.numpy as jnp -_FP8_DTYPES = frozenset([ - jnp.dtype("float8_e4m3fn"), - jnp.dtype("float8_e5m2"), - jnp.dtype("float8_e4m3fnuz"), - jnp.dtype("float8_e5m2fnuz"), -]) - - -def _is_fp8(x): - return jnp.dtype(x.dtype) in _FP8_DTYPES - - -def quantize_to_fp8(x, *, dtype=None, axis=None): - """Per-tensor amax-based quantization helper (for tests/profiling). - - Returns (x_fp8, scale_fp32) where the dequantization is ``x_fp8 * scale``. - """ - if dtype is None: - dtype = jnp.float8_e4m3fn - fp8_max = jnp.finfo(dtype).max.astype(jnp.float32) - amax = jnp.max(jnp.abs(x.astype(jnp.float32))) if axis is None else \ - jnp.max(jnp.abs(x.astype(jnp.float32)), axis=axis, keepdims=True) - scale = (amax / fp8_max).astype(jnp.float32) - # avoid divide-by-zero on all-zero tensors - scale = jnp.where(scale == 0, jnp.float32(1.0), scale) - x_fp8 = (x.astype(jnp.float32) / scale).astype(dtype) - return x_fp8, scale - - -# --- Reference implementation --------------------------------------------------- - -def _indexer_impl_reference(Q, K, W_uq, W_dq, W_k, W_w, - scale_q=None, scale_k=None, - scale_wq=None, scale_wd=None, scale_wk=None, - out_dtype=None): +def _indexer_impl_reference(Q, K, W_uq, W_dq, W_k, W_w, out_dtype=None): """ Q [..., T, d] K [..., S, d] @@ -75,55 +36,19 @@ def _indexer_impl_reference(Q, K, W_uq, W_dq, W_k, W_w, W_uq [H, d_c, d_i] W_k [d, d_i] W_w [..., d, H] # leading dims must match Q's - - FP8 path: each fp8 operand is dequantized via cast-to-bf16-then-multiply - immediately before the matmul that consumes it. This is the pattern XLA's - GEMM rewriter recognizes and lowers to ``__cublas$lt$matmul$f8`` (native - fp8 hardware GEMM) for matmuls where both operands are originally fp8. - Upcasting to fp32 first would lose the fp8 type info and fall back to - plain fp32 GEMM — strictly worse. 
""" - if _is_fp8(Q): - if any(s is None for s in (scale_q, scale_k, scale_wq, scale_wk)): - raise ValueError( - "FP8 reference requires scale_q, scale_k, scale_wq, scale_wk." - ) - if _is_fp8(W_dq) and scale_wd is None: - raise ValueError("FP8 W_dq requires scale_wd.") - - wp = jnp.bfloat16 # working precision for non-fp8 intermediates - - def _dq(x, s): - # cast-then-scale pattern (in working precision, NOT fp32). XLA's - # GEMM rewriter pulls (cast, multiply, dot) into a fused fp8 GEMM - # when both operands of the dot follow this pattern. - if _is_fp8(x): - return x.astype(wp) * jnp.float32(s).astype(wp) - return x.astype(wp) - - Q_d = _dq(Q, scale_q) - K_d = _dq(K, scale_k) - W_uq_d = _dq(W_uq, scale_wq) - W_dq_d = _dq(W_dq, scale_wd) - W_k_d = _dq(W_k, scale_wk) - - C_q = jnp.einsum("...td,dc->...tc", Q_d, W_dq_d) # (..., T, d_c) - H_q = jnp.einsum("...tc,hci->...thi", C_q, W_uq_d) # (..., T, H, d_i) - H_k = jnp.einsum("...sd,di->...si", K_d, W_k_d) # (..., S, d_i) - H = jax.nn.relu(jnp.einsum("...thi,...si->...ths", H_q, H_k)) # (..., T, H, S) - W_o = jnp.einsum("...td,dh->...th", Q_d, W_w) - O = jnp.einsum("...ths,...th->...ts", H, W_o) # (..., T, S) + C_q = jnp.einsum("...td,dc->...tc", Q, W_dq) # (..., T, d_c) + H_q = jnp.einsum("...tc,hci->...thi", C_q, W_uq) # (..., T, H, d_i) + H_k = jnp.einsum("...sd,di->...si", K, W_k) # (..., S, d_i) + H = jax.nn.relu(jnp.einsum("...thi,...si->...ths", H_q, H_k)) # (..., T, H, S) + W_o = jnp.einsum("...td,dh->...th", Q, W_w) + O = jnp.einsum("...ths,...th->...ts", H, W_o) # (..., T, S) if out_dtype is not None: O = O.astype(out_dtype) return O -# --- Hybrid implementation (einsum projections + Triton score-reduce) --------- - -def _indexer_impl_hybrid(Q, K, W_uq, W_dq, W_k, W_w, - scale_q=None, scale_k=None, - scale_wq=None, scale_wd=None, scale_wk=None, - out_dtype=None): +def _indexer_impl_hybrid(Q, K, W_uq, W_dq, W_k, W_w, out_dtype=None): """Einsum projections + Triton score-relu-reduce. 
Mirrors ``_indexer_impl_reference`` for the four projections (which @@ -134,45 +59,18 @@ def _indexer_impl_hybrid(Q, K, W_uq, W_dq, W_k, W_w, """ from transformer_engine.jax.triton_extensions.indexer import score_reduce_triton - if _is_fp8(Q): - if any(s is None for s in (scale_q, scale_k, scale_wq, scale_wk)): - raise ValueError( - "FP8 hybrid requires scale_q, scale_k, scale_wq, scale_wk." - ) - if _is_fp8(W_dq) and scale_wd is None: - raise ValueError("FP8 W_dq requires scale_wd.") - - wp = jnp.bfloat16 - - def _dq(x, s): - if _is_fp8(x): - return x.astype(wp) * jnp.float32(s).astype(wp) - return x.astype(wp) - - Q_d = _dq(Q, scale_q) - K_d = _dq(K, scale_k) - W_uq_d = _dq(W_uq, scale_wq) - W_dq_d = _dq(W_dq, scale_wd) - W_k_d = _dq(W_k, scale_wk) - - C_q = jnp.einsum("...td,dc->...tc", Q_d, W_dq_d) # (..., T, d_c) - H_q = jnp.einsum("...tc,hci->...thi", C_q, W_uq_d) # (..., T, H, d_i) - H_k = jnp.einsum("...sd,di->...si", K_d, W_k_d) # (..., S, d_i) - W_o = jnp.einsum("...td,dh->...th", Q_d, W_w.astype(wp)) # (..., T, H) - - O = score_reduce_triton(H_q, H_k, W_o, - out_dtype=out_dtype if out_dtype else wp) - return O + C_q = jnp.einsum("...td,dc->...tc", Q, W_dq) # (..., T, d_c) + H_q = jnp.einsum("...tc,hci->...thi", C_q, W_uq) # (..., T, H, d_i) + H_k = jnp.einsum("...sd,di->...si", K, W_k) # (..., S, d_i) + W_o = jnp.einsum("...td,dh->...th", Q, W_w) # (..., T, H) + return score_reduce_triton(H_q, H_k, W_o, + out_dtype=out_dtype if out_dtype else Q.dtype) -# --- Top-level dispatch --------------------------------------------------------- @functools.partial(jax.jit, static_argnames=("backend", "out_dtype")) -def indexer(Q, K, W_uq, W_dq, W_k, weights, *, - scale_q=None, scale_k=None, - scale_wq=None, scale_wd=None, scale_wk=None, - out_dtype=None, backend="reference"): - """Low-rank lightning-indexer. +def indexer(Q, K, W_uq, W_dq, W_k, weights, *, out_dtype=None, backend="reference"): + """Low-rank lightning-indexer (bf16). 
Args: Q: (..., T, d) hidden state (per token) @@ -182,28 +80,17 @@ def indexer(Q, K, W_uq, W_dq, W_k, weights, *, W_k: (d, d_i) key projection weights: (d, H) learnable output-weight projection (W_o = Q @ weights inside the impl) - scale_q, scale_k, scale_wq, scale_wk: - per-tensor fp32 dequant scales. Required when Q is FP8. - scale_wd: - per-tensor fp32 dequant scale for W_dq. Required only when - W_dq itself is FP8. - out_dtype: output dtype override (defaults to Q.dtype, or bf16 for - the hybrid backend). + out_dtype: output dtype override (defaults to Q.dtype). backend: "reference" (pure einsum) or "hybrid" (einsum projections + Triton score-relu-reduce kernel). Returns: O of shape (..., T, S). """ - fp8_kwargs = dict( - scale_q=scale_q, scale_k=scale_k, - scale_wq=scale_wq, scale_wd=scale_wd, scale_wk=scale_wk, - out_dtype=out_dtype, - ) if backend == "reference": - return _indexer_impl_reference(Q, K, W_uq, W_dq, W_k, weights, **fp8_kwargs) + return _indexer_impl_reference(Q, K, W_uq, W_dq, W_k, weights, out_dtype=out_dtype) if backend == "hybrid": - return _indexer_impl_hybrid(Q, K, W_uq, W_dq, W_k, weights, **fp8_kwargs) + return _indexer_impl_hybrid(Q, K, W_uq, W_dq, W_k, weights, out_dtype=out_dtype) raise ValueError( f"unknown backend {backend!r}; expected 'reference' or 'hybrid'" ) @@ -212,8 +99,7 @@ def indexer(Q, K, W_uq, W_dq, W_k, weights, *, # --- Tests ---------------------------------------------------------------------- def _run_test(leading_shape, seed, backend): - # The hybrid backend's Triton kernel requires rank-4 BHSD inputs, so this - # smoke test only exercises that shape (and the reference-vs-hybrid agreement). + # The hybrid backend's Triton kernel requires rank-4 BHSD inputs. 
T_t, T_s, d, d_c, H, d_i = 64, 64, 32, 32, 8, 32 keys = jax.random.split(jax.random.PRNGKey(seed), 6) Q = jax.random.normal(keys[0], (*leading_shape, T_t, d), dtype=jnp.bfloat16) @@ -223,13 +109,8 @@ def _run_test(leading_shape, seed, backend): W_k = jax.random.normal(keys[4], (d, d_i), dtype=jnp.bfloat16) W_w = jax.random.normal(keys[5], (d, H), dtype=jnp.bfloat16) - try: - O_ref = indexer(Q, K, W_uq, W_dq, W_k, W_w, backend="reference") - O_b = indexer(Q, K, W_uq, W_dq, W_k, W_w, backend=backend) - except Exception as e: # noqa: BLE001 - print(f" backend={backend:<10s} leading={str(leading_shape):10s} " - f"SKIP: {type(e).__name__}: {str(e).splitlines()[0]}") - return + O_ref = indexer(Q, K, W_uq, W_dq, W_k, W_w, backend="reference") + O_b = indexer(Q, K, W_uq, W_dq, W_k, W_w, backend=backend) diff = (O_ref.astype(jnp.float32) - O_b.astype(jnp.float32)) rel_err = float(jnp.linalg.norm(diff) / @@ -241,9 +122,7 @@ def _run_test(leading_shape, seed, backend): if __name__ == "__main__": print("=== reference vs reference (sanity) ===") - for i, leading in enumerate([(2, 3),]): - _run_test(leading, seed=i, backend="reference") + _run_test((2, 3), seed=0, backend="reference") print("\n=== hybrid vs reference ===") - for i, leading in enumerate([(2, 3),]): - _run_test(leading, seed=100 + i, backend="hybrid") + _run_test((2, 3), seed=100, backend="hybrid") From 7c552550dd4a0b8fff9c32174a1ba5fb11c03dd9 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 14 May 2026 15:14:02 +0000 Subject: [PATCH 06/17] Added initial bwd kernels --- benchmarks/profile_indexer.py | 9 +- benchmarks/profile_indexer_bwd.py | 135 +++++ benchmarks/profile_indexer_topk.py | 27 +- transformer_engine/jax/indexer.py | 96 +++ .../jax/triton_extensions/__init__.py | 2 +- .../jax/triton_extensions/indexer.py | 561 +++++++++++++++++- 6 files changed, 813 insertions(+), 17 deletions(-) create mode 100644 benchmarks/profile_indexer_bwd.py diff --git a/benchmarks/profile_indexer.py 
b/benchmarks/profile_indexer.py index e31d47a1d..481987931 100644 --- a/benchmarks/profile_indexer.py +++ b/benchmarks/profile_indexer.py @@ -70,7 +70,7 @@ def time_fn(fn, args, n_warmup=15, n_iter=50): CONFIGS = [ #(B, oH, T, S, d, d_c, H, d_i) - ( 2, 64, 1024, 1024, 512, 1024, 64, 128), + ( 2, 64, 4096, 4096, 512, 1024, 64, 128), ] @@ -115,7 +115,8 @@ def main(): print(f"--- B={B} oH={oH} T={T} S={S} d={d} d_c={d_c} H={H} d_i={d_i} bfloat16 ---") print(f" theoretical work = {flops/1e9:.2f} GFLOPs/call") - impls = [("baseline", _build_impl("reference"))] + # impls = [("baseline", _build_impl("reference"))] + impls = [] if _HAVE_HYBRID: impls.append(("hybrid", _build_impl("hybrid"))) @@ -128,8 +129,10 @@ def main(): if name == "baseline": baseline_ms = ms speed = "" - else: + elif baseline_ms is not None: speed = f" ({baseline_ms/ms:.2f}x baseline)" + else: + speed = "" print(f" {name:<10} {ms:8.3f} ms {tflops:6.2f} TFLOP/s{speed}") except Exception as e: # noqa: BLE001 print(f" {name:<10} FAILED: {type(e).__name__}: {str(e).splitlines()[0]}") diff --git a/benchmarks/profile_indexer_bwd.py b/benchmarks/profile_indexer_bwd.py new file mode 100644 index 000000000..070c6a7aa --- /dev/null +++ b/benchmarks/profile_indexer_bwd.py @@ -0,0 +1,135 @@ +"""Profile lightning-indexer backward pass throughput (bf16). + +Measures wall time and effective TFLOPS for forward, backward, and +value_and_grad. Uses the standard "backward = 2x forward FLOPs" convention, +so value_and_grad total work = 3x forward FLOPs. 
+ +Run inside the container: + docker exec zain-w2 sh -c 'cd /workspace && python benchmarks/profile_indexer_bwd.py' +""" + +import time + +import jax +import jax.numpy as jnp + +from transformer_engine.jax.indexer import indexer + +try: + from transformer_engine.jax.triton_extensions.indexer import score_reduce_triton # noqa: F401 + _HAVE_HYBRID = True +except Exception as _e: # noqa: BLE001 + _HAVE_HYBRID = False + _HYBRID_IMPORT_ERROR = _e + + +def make_inputs(B, oH, T, S, d, d_c, H, d_i, dtype, seed=0): + keys = jax.random.split(jax.random.PRNGKey(seed), 6) + Q = jax.random.normal(keys[0], (B, oH, T, d), dtype=dtype) + K = jax.random.normal(keys[1], (B, oH, S, d), dtype=dtype) + W_uq = jax.random.normal(keys[2], (H, d_c, d_i), dtype=dtype) + W_dq = jax.random.normal(keys[3], (d, d_c), dtype=dtype) + W_k = jax.random.normal(keys[4], (d, d_i), dtype=dtype) + W_w = jax.random.normal(keys[5], (d, H), dtype=dtype) + return Q, K, W_uq, W_dq, W_k, W_w + + +def theoretical_fwd_flops(B, oH, T, S, d, d_c, H, d_i): + n = B * oH + return 2 * ( + n * T * d_c * d + + n * T * H * d_i * d_c + + n * S * d_i * d + + n * T * H * S * d_i + + n * T * d * H + + n * T * S * H + ) + + +def time_fn(fn, args, n_warmup=10, n_iter=30): + for _ in range(n_warmup): + out = fn(*args) + jax.tree_util.tree_map(lambda x: x.block_until_ready(), out) + t0 = time.perf_counter() + for _ in range(n_iter): + out = fn(*args) + jax.tree_util.tree_map(lambda x: x.block_until_ready(), out) + return (time.perf_counter() - t0) / n_iter + + +CONFIGS = [ + #(B, oH, T, S, d, d_c, H, d_i) + ( 2, 64, 1024, 1024, 512, 1024, 64, 128), +] + + +def _build_fwd(backend): + @jax.jit + def fn(Q, K, W_uq, W_dq, W_k, W_w): + O = indexer(Q, K, W_uq, W_dq, W_k, W_w, backend=backend) + return jnp.sum(O.astype(jnp.float32)) + return fn + + +def _build_bwd(backend): + """Backward only: returns gradients.""" + fwd = _build_fwd(backend) + return jax.jit(jax.grad(fwd, argnums=(0, 1, 2, 3, 4, 5))) + + +def 
_build_value_and_grad(backend): + fwd = _build_fwd(backend) + return jax.jit(jax.value_and_grad(fwd, argnums=(0, 1, 2, 3, 4, 5))) + + +def main(): + print(f"jax devices: {jax.devices()}\n") + for B, oH, T, S, d, d_c, H, d_i in CONFIGS: + Q, K, W_uq, W_dq, W_k, W_w = make_inputs( + B, oH, T, S, d, d_c, H, d_i, jnp.bfloat16 + ) + args = (Q, K, W_uq, W_dq, W_k, W_w) + fwd_flops = theoretical_fwd_flops(B, oH, T, S, d, d_c, H, d_i) + + print(f"--- B={B} oH={oH} T={T} S={S} d={d} d_c={d_c} H={H} d_i={d_i} bfloat16 ---") + print(f" forward GFLOPs/call: {fwd_flops/1e9:.2f}") + print(f" bwd GFLOPs/call (~2x): {2*fwd_flops/1e9:.2f}") + print(f" f+b GFLOPs/call (~3x): {3*fwd_flops/1e9:.2f}") + print() + + backends = ["reference"] + if _HAVE_HYBRID: + backends.append("hybrid") + + # Headers + print(f" {'backend':<10s} {'pass':<14s} {'ms':>8s} {'TFLOP/s':>8s}") + + for backend in backends: + try: + # Forward (loss only) + fwd = _build_fwd(backend) + sec = time_fn(fwd, args) + ms = sec * 1e3 + tflops = fwd_flops / sec / 1e12 + print(f" {backend:<10s} {'forward':<14s} {ms:8.3f} {tflops:8.2f}") + + # Backward only (jax.grad — XLA may re-trace forward inside) + bwd = _build_bwd(backend) + sec = time_fn(bwd, args) + ms = sec * 1e3 + tflops = 2 * fwd_flops / sec / 1e12 # bwd ~= 2x fwd + print(f" {backend:<10s} {'backward':<14s} {ms:8.3f} {tflops:8.2f}") + + # value_and_grad (forward + backward, single pass) + vag = _build_value_and_grad(backend) + sec = time_fn(vag, args) + ms = sec * 1e3 + tflops = 3 * fwd_flops / sec / 1e12 # f+b ~= 3x fwd + print(f" {backend:<10s} {'value_and_grad':<14s} {ms:8.3f} {tflops:8.2f}") + except Exception as e: # noqa: BLE001 + print(f" {backend:<10s} FAILED: {type(e).__name__}: {str(e).splitlines()[0]}") + print() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/profile_indexer_topk.py b/benchmarks/profile_indexer_topk.py index 2040b86e4..82f39b177 100644 --- a/benchmarks/profile_indexer_topk.py +++ b/benchmarks/profile_indexer_topk.py 
@@ -14,7 +14,7 @@ import jax import jax.numpy as jnp -from transformer_engine.jax.indexer import indexer +from transformer_engine.jax.indexer import indexer, indexer_topk try: from transformer_engine.jax.triton_extensions.indexer import score_reduce_triton # noqa: F401 @@ -66,7 +66,7 @@ def time_fn(fn, args, n_warmup=15, n_iter=50): CONFIGS = [ #(B, oH, T, S, d, d_c, H, d_i) - ( 2, 64, 1024, 1024, 512, 1024, 64, 128), + ( 2, 64, 4096, 4096, 512, 1024, 64, 128), ] K_TOPK = 512 @@ -80,6 +80,13 @@ def fn(Q, K, W_uq, W_dq, W_k, W_w): return fn +def _build_fused_topk(k): + @jax.jit + def fn(Q, K, W_uq, W_dq, W_k, W_w): + return indexer_topk(Q, K, W_uq, W_dq, W_k, W_w, k=k) + return fn + + @jax.jit def _topk_only(scores): return jax.lax.top_k(scores, K_TOPK) @@ -101,9 +108,11 @@ def main(): print(f"--- B={B} oH={oH} T={T} S={S} d={d} d_c={d_c} H={H} d_i={d_i} bfloat16 ---") print(f" theoretical work = {flops/1e9:.2f} GFLOPs/call (top-k = 0 FLOP)") - impls = [("baseline+topk", _build_topk("reference", K_TOPK))] + # impls = [("baseline+topk", _build_topk("reference", K_TOPK))] + impls = [] if _HAVE_HYBRID: impls.append(("hybrid+topk", _build_topk("hybrid", K_TOPK))) + impls.append(("hybrid_fused_topk", _build_fused_topk(K_TOPK))) baseline_ms = None for name, fn in impls: @@ -114,20 +123,22 @@ def main(): if name == "baseline+topk": baseline_ms = ms speed = "" - else: + elif baseline_ms is not None: speed = f" ({baseline_ms/ms:.2f}x baseline)" - print(f" {name:<14} {ms:8.3f} ms {tflops:6.2f} TFLOP/s{speed}") + else: + speed = "" + print(f" {name:<18} {ms:8.3f} ms {tflops:6.2f} TFLOP/s{speed}") except Exception as e: # noqa: BLE001 - print(f" {name:<14} FAILED: {type(e).__name__}: {str(e).splitlines()[0]}") + print(f" {name:<18} FAILED: {type(e).__name__}: {str(e).splitlines()[0]}") # Time top_k alone on a precomputed (reference) score matrix to # isolate the top-k cost from the indexer compute. 
try: scores_mat = indexer(*args, backend="reference") sec = time_fn(_topk_only, (scores_mat,)) - print(f" {'(top_k alone)':<14} {sec*1e3:8.3f} ms") + print(f" {'(top_k alone)':<18} {sec*1e3:8.3f} ms") except Exception as e: # noqa: BLE001 - print(f" (top_k alone) FAILED: {type(e).__name__}") + print(f" {'(top_k alone) FAILED':<18} {type(e).__name__}") print() diff --git a/transformer_engine/jax/indexer.py b/transformer_engine/jax/indexer.py index bdfea9fe0..347e5ead6 100644 --- a/transformer_engine/jax/indexer.py +++ b/transformer_engine/jax/indexer.py @@ -68,6 +68,32 @@ def _indexer_impl_hybrid(Q, K, W_uq, W_dq, W_k, W_w, out_dtype=None): out_dtype=out_dtype if out_dtype else Q.dtype) +@functools.partial(jax.jit, static_argnames=("k",)) +def indexer_topk(Q, K, W_uq, W_dq, W_k, weights, *, k): + """Lightning-indexer + top-k (fused). + + Same projections as ``indexer()`` (reference math), then a single Triton + kernel that computes the score row, ReLU, weighted H-reduction, and + streaming top-k all in one pass — the (B, oH, T_t, T_s) score matrix is + never materialized. + + Args: + Q, K, W_uq, W_dq, W_k, weights: same as ``indexer()``. + k: number of top scores to return per (B, oH, T_t) row. + Must be a power of 2 and <= S. + + Returns: + Topk_idx: (..., T_t, k) int32 — top-k indices into the S axis, + in descending score order. + """ + from transformer_engine.jax.triton_extensions.indexer import score_topk_triton + C_q = jnp.einsum("...td,dc->...tc", Q, W_dq) # (..., T, d_c) + H_q = jnp.einsum("...tc,hci->...thi", C_q, W_uq) # (..., T, H, d_i) + H_k = jnp.einsum("...sd,di->...si", K, W_k) # (..., S, d_i) + W_o = jnp.einsum("...td,dh->...th", Q, weights) # (..., T, H) + return score_topk_triton(H_q, H_k, W_o, k=k) + + @functools.partial(jax.jit, static_argnames=("backend", "out_dtype")) def indexer(Q, K, W_uq, W_dq, W_k, weights, *, out_dtype=None, backend="reference"): """Low-rank lightning-indexer (bf16). 
@@ -120,9 +146,79 @@ def _run_test(leading_shape, seed, backend): f"O.shape={O_b.shape} rel.err={rel_err:.2e} [{tag}]") +def _run_topk_test(leading_shape, seed, k): + # H=16 to keep the matmul in [BLOCK_S, H] friendly to MFMA tile sizes. + T_t, T_s, d, d_c, H, d_i = 64, 128, 32, 32, 16, 32 + keys = jax.random.split(jax.random.PRNGKey(seed), 6) + Q = jax.random.normal(keys[0], (*leading_shape, T_t, d), dtype=jnp.bfloat16) + K = jax.random.normal(keys[1], (*leading_shape, T_s, d), dtype=jnp.bfloat16) + W_uq = jax.random.normal(keys[2], (H, d_c, d_i), dtype=jnp.bfloat16) + W_dq = jax.random.normal(keys[3], (d, d_c), dtype=jnp.bfloat16) + W_k = jax.random.normal(keys[4], (d, d_i), dtype=jnp.bfloat16) + W_w = jax.random.normal(keys[5], (d, H), dtype=jnp.bfloat16) + + O_ref = indexer(Q, K, W_uq, W_dq, W_k, W_w, backend="reference") + topk_fused = indexer_topk(Q, K, W_uq, W_dq, W_k, W_w, k=k) + + # Correctness check: the scores at the fused-picked indices should equal the + # top-k scores from the reference (within bf16 noise). Set-equality of indices + # is too strict — different backends break ties differently. 
+ O_ref32 = O_ref.astype(jnp.float32) + ref_topk_vals = jax.lax.top_k(O_ref32, k=k)[0] # [..., T_t, k] sorted desc + fused_picked_vals = jnp.take_along_axis(O_ref32, topk_fused, axis=-1) + fused_picked_sorted = jnp.sort(fused_picked_vals, axis=-1)[..., ::-1] + rel_diff = jnp.abs(ref_topk_vals - fused_picked_sorted) / (jnp.abs(ref_topk_vals) + 1e-6) + max_rel = float(rel_diff.max()) + tag = "OK" if max_rel < 1e-2 else f"FAIL (max_rel={max_rel:.2e})" + print(f" topk leading={str(leading_shape):10s} k={k:<4d} " + f"out.shape={topk_fused.shape} max_rel={max_rel:.2e} [{tag}]") + + +def _run_bwd_test(leading_shape, seed): + """Compare hybrid backward against jax.grad on the reference impl.""" + T_t, T_s, d, d_c, H, d_i = 32, 32, 32, 32, 8, 32 + keys = jax.random.split(jax.random.PRNGKey(seed), 6) + Q = jax.random.normal(keys[0], (*leading_shape, T_t, d), dtype=jnp.bfloat16) + K = jax.random.normal(keys[1], (*leading_shape, T_s, d), dtype=jnp.bfloat16) + W_uq = jax.random.normal(keys[2], (H, d_c, d_i), dtype=jnp.bfloat16) + W_dq = jax.random.normal(keys[3], (d, d_c), dtype=jnp.bfloat16) + W_k = jax.random.normal(keys[4], (d, d_i), dtype=jnp.bfloat16) + W_w = jax.random.normal(keys[5], (d, H), dtype=jnp.bfloat16) + + def loss_ref(Q, K, W_uq, W_dq, W_k, W_w): + O = indexer(Q, K, W_uq, W_dq, W_k, W_w, backend="reference") + return jnp.sum(O.astype(jnp.float32)) + + def loss_hyb(Q, K, W_uq, W_dq, W_k, W_w): + O = indexer(Q, K, W_uq, W_dq, W_k, W_w, backend="hybrid") + return jnp.sum(O.astype(jnp.float32)) + + grads_ref = jax.grad(loss_ref, argnums=(0, 1, 2, 3, 4, 5))(Q, K, W_uq, W_dq, W_k, W_w) + grads_hyb = jax.grad(loss_hyb, argnums=(0, 1, 2, 3, 4, 5))(Q, K, W_uq, W_dq, W_k, W_w) + + names = ("dQ", "dK", "dW_uq", "dW_dq", "dW_k", "dW_w") + all_ok = True + for name, gr, gh in zip(names, grads_ref, grads_hyb): + diff = (gr.astype(jnp.float32) - gh.astype(jnp.float32)) + rel = float(jnp.linalg.norm(diff) / + (jnp.linalg.norm(gr.astype(jnp.float32)) + 1e-30)) + ok = rel < 5e-2 
+ all_ok = all_ok and ok + tag = "OK" if ok else "FAIL" + print(f" {name:<6} shape={str(gh.shape):<22s} rel.err={rel:.2e} [{tag}]") + overall = "OK" if all_ok else "FAIL" + print(f" bwd leading={str(leading_shape):10s} overall=[{overall}]") + + if __name__ == "__main__": print("=== reference vs reference (sanity) ===") _run_test((2, 3), seed=0, backend="reference") print("\n=== hybrid vs reference ===") _run_test((2, 3), seed=100, backend="hybrid") + + print("\n=== indexer_topk vs reference + jax.lax.top_k ===") + _run_topk_test((2, 3), seed=200, k=32) + + print("\n=== backward: hybrid vs jax.grad(reference) ===") + _run_bwd_test((2, 3), seed=300) diff --git a/transformer_engine/jax/triton_extensions/__init__.py b/transformer_engine/jax/triton_extensions/__init__.py index 1a9c517a2..b153b6f18 100644 --- a/transformer_engine/jax/triton_extensions/__init__.py +++ b/transformer_engine/jax/triton_extensions/__init__.py @@ -58,4 +58,4 @@ def lowering(ctx, x, **kwargs): from .utils import * from .permutation import * -from .indexer import score_reduce_triton +from .indexer import score_reduce_triton, score_topk_triton diff --git a/transformer_engine/jax/triton_extensions/indexer.py b/transformer_engine/jax/triton_extensions/indexer.py index e16838cd1..6ddf38298 100644 --- a/transformer_engine/jax/triton_extensions/indexer.py +++ b/transformer_engine/jax/triton_extensions/indexer.py @@ -19,6 +19,7 @@ import functools +import jax import jax.numpy as jnp import triton import triton.language as tl @@ -85,8 +86,10 @@ def _score_reduce_kernel( pid_t = tl.program_id(1) pid_bh = tl.program_id(2) - b = pid_bh // oH - h_outer = pid_bh % oH + # int64 indexing — Hq alone has B*oH*T*H*d_i = 4.3 B elements at T=S=4096, + # exceeds int32 range. 
+ b = (pid_bh // oH).to(tl.int64) + h_outer = (pid_bh % oH).to(tl.int64) rt = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) rs = pid_s * BLOCK_S + tl.arange(0, BLOCK_S) @@ -176,6 +179,327 @@ def grid_fn(merged_kwargs): mlir.register_lowering(_score_reduce_p, _score_reduce_lowering, platform="cuda") +# --- Backward: dHq + dW_o kernel ---------------------------------------------- +# +# FlashAttention-style: residuals saved from forward = (Hq, Hk, W_o). The +# (T, H, S) score tensor is recomputed inside this kernel from Hq @ Hk^T, +# so we never store H = relu(scores) -- which is 549 GB at the production +# 4096^2 shape. +# +# Math: +# scores[t, h, s] = sum_i Hq[t, h, i] * Hk[s, i] +# H_relu[t, h, s] = max(scores, 0) +# O[t, s] = sum_h H_relu[t, h, s] * W_o[t, h] +# +# Cotangents: +# dW_o[t, h] = sum_s dO[t, s] * H_relu[t, h, s] +# dH[t,h,s] = dO[t, s] * W_o[t, h] +# dscores = dH * (scores > 0) # ReLU mask +# dHq[t,h,i] = sum_s dscores[t,h,s] * Hk[s, i] +# dHk[s,i] = sum_t sum_h dscores[t,h,s] * Hq[t, h, i] +# +# Kernel A (this one): computes dHq and dW_o. +# Grid: (cdiv(T_t, BLOCK_T), B * oH). Each CTA owns BLOCK_T rows of dHq +# and dW_o exclusively -- no atomics needed since the full S range is +# reduced inside one CTA. +# +# Kernel B (next section): computes dHk. Grid (cdiv(T_s, BLOCK_S), B * oH); +# each CTA owns BLOCK_S rows of dHk and reduces over all T inside. + + +@triton.jit +def _score_reduce_dHq_dWo_kernel( + Hq_ptr, # (B, oH, T_t, H, d_i) bf16 + Hk_ptr, # (B, oH, T_s, d_i) bf16 + W_o_ptr, # (B, oH, T_t, H) bf16 + dO_ptr, # (B, oH, T_t, T_s) fp32 (caller upcasts) + dHq_ptr, # (B, oH, T_t, H, d_i) bf16 OUTPUT + dWo_ptr, # (B, oH, T_t, H) bf16 OUTPUT + B: tl.constexpr, + oH: tl.constexpr, + T_t: tl.constexpr, + T_s: tl.constexpr, + H: tl.constexpr, + d_i: tl.constexpr, + BLOCK_T: tl.constexpr, + BLOCK_S: tl.constexpr, +): + """Per-CTA: produces dHq[BLOCK_T, H, d_i] and dW_o[BLOCK_T, H]. + + Outer-h loop / inner-s loop. 
For each h, we accumulate dHq and dW_o + contributions over the full S range, then store and move on. + """ + pid_t = tl.program_id(0) + pid_bh = tl.program_id(1) + + # int64 indexing — production tensors exceed int32 range + b = (pid_bh // oH).to(tl.int64) + h_outer = (pid_bh % oH).to(tl.int64) + + rt = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) + rdi = tl.arange(0, d_i) + rt_mask = rt < T_t + + hq_base = b * (oH * T_t * H * d_i) + h_outer * (T_t * H * d_i) + hk_base = b * (oH * T_s * d_i) + h_outer * (T_s * d_i) + wo_base = b * (oH * T_t * H) + h_outer * (T_t * H) + do_base = b * (oH * T_t * T_s) + h_outer * (T_t * T_s) + + for h in range(H): + # Load Hq[..., rt, h, :] -> [BLOCK_T, d_i] bf16 + hq_ptrs = Hq_ptr + hq_base + rt[:, None] * (H * d_i) + h * d_i + rdi[None, :] + Hq_h = tl.load(hq_ptrs, mask=rt_mask[:, None], other=0.0) + + # Load W_o[..., rt, h] -> [BLOCK_T] fp32 + wo_ptrs = W_o_ptr + wo_base + rt * H + h + w_h = tl.load(wo_ptrs, mask=rt_mask, other=0.0).to(tl.float32) + + dHq_acc = tl.zeros((BLOCK_T, d_i), dtype=tl.float32) + dWo_acc = tl.zeros((BLOCK_T,), dtype=tl.float32) + + for s_start in range(0, T_s, BLOCK_S): + rs = s_start + tl.arange(0, BLOCK_S) + rs_mask = rs < T_s + + # Load Hk[..., rs, :] -> [BLOCK_S, d_i] bf16 + hk_ptrs = Hk_ptr + hk_base + rs[:, None] * d_i + rdi[None, :] + Hk_chunk = tl.load(hk_ptrs, mask=rs_mask[:, None], other=0.0) + + # Load dO[..., rt, rs] -> [BLOCK_T, BLOCK_S] (caller upcast to fp32) + do_ptrs = dO_ptr + do_base + rt[:, None] * T_s + rs[None, :] + dO_chunk = tl.load( + do_ptrs, + mask=rt_mask[:, None] & rs_mask[None, :], + other=0.0, + ) + + # Recompute scores[BLOCK_T, BLOCK_S] = Hq_h @ Hk_chunk^T, in fp32 + scores = tl.dot(Hq_h, tl.trans(Hk_chunk)) # [BLOCK_T, BLOCK_S] + relu_mask = scores > 0 + h_relu = tl.where(relu_mask, scores, 0.0) + + # dW_o accumulator: sum_s dO * H_relu + dWo_acc += tl.sum(dO_chunk * h_relu, axis=1) + + # dscores = (dO * w_h) * relu_mask + dH = dO_chunk * w_h[:, None] + dscores = 
tl.where(relu_mask, dH, 0.0) + + # dHq_acc += dscores @ Hk_chunk -> [BLOCK_T, d_i] + dHq_acc += tl.dot(dscores.to(Hk_chunk.dtype), Hk_chunk) + + # Store dHq[..., rt, h, :] + dhq_ptrs = dHq_ptr + hq_base + rt[:, None] * (H * d_i) + h * d_i + rdi[None, :] + tl.store( + dhq_ptrs, + dHq_acc.to(dHq_ptr.dtype.element_ty), + mask=rt_mask[:, None], + ) + + # Store dW_o[..., rt, h] + dwo_ptrs = dWo_ptr + wo_base + rt * H + h + tl.store(dwo_ptrs, dWo_acc.to(dWo_ptr.dtype.element_ty), mask=rt_mask) + + +_score_reduce_dHq_dWo_p = extend_core.Primitive("te_indexer_score_reduce_dHq_dWo") +_score_reduce_dHq_dWo_p.multiple_results = True + + +@_score_reduce_dHq_dWo_p.def_abstract_eval +def _score_reduce_dHq_dWo_abstract(Hq, Hk, W_o, dO): + del Hk, dO + return [ + core.ShapedArray(Hq.shape, Hq.dtype), # dHq + core.ShapedArray(W_o.shape, W_o.dtype), # dW_o + ] + + +_score_reduce_dHq_dWo_p.def_impl( + functools.partial(xla.apply_primitive, _score_reduce_dHq_dWo_p) +) + + +def _score_reduce_dHq_dWo_lowering(ctx, Hq, Hk, W_o, dO): + Hq_aval = ctx.avals_in[0] + Hk_aval = ctx.avals_in[1] + B, oH, T_t, H, d_i = Hq_aval.shape + T_s = Hk_aval.shape[2] + BLOCK_T = 32 if T_t >= 32 else T_t + BLOCK_S = 32 if T_s >= 32 else T_s + + return triton_call_lowering( + ctx, + _score_reduce_dHq_dWo_kernel, + Hq, Hk, W_o, dO, + grid=(triton.cdiv(T_t, BLOCK_T), B * oH), + num_warps=4, + num_stages=2, + constexprs={ + "B": B, "oH": oH, "T_t": T_t, "T_s": T_s, + "H": H, "d_i": d_i, + "BLOCK_T": BLOCK_T, "BLOCK_S": BLOCK_S, + }, + ) + + +mlir.register_lowering(_score_reduce_dHq_dWo_p, _score_reduce_dHq_dWo_lowering, platform="rocm") +mlir.register_lowering(_score_reduce_dHq_dWo_p, _score_reduce_dHq_dWo_lowering, platform="cuda") + + +# --- Backward: dHk kernel ----------------------------------------------------- + + +@triton.jit +def _score_reduce_dHk_kernel( + Hq_ptr, # (B, oH, T_t, H, d_i) bf16 + Hk_ptr, # (B, oH, T_s, d_i) bf16 + W_o_ptr, # (B, oH, T_t, H) bf16 + dO_ptr, # (B, oH, T_t, T_s) fp32 + 
dHk_ptr, # (B, oH, T_s, d_i) bf16 OUTPUT + B: tl.constexpr, + oH: tl.constexpr, + T_t: tl.constexpr, + T_s: tl.constexpr, + H: tl.constexpr, + d_i: tl.constexpr, + BLOCK_T: tl.constexpr, + BLOCK_S: tl.constexpr, +): + """Per-CTA: produces dHk[BLOCK_S, d_i]. + + Outer-h loop / inner-t loop, accumulating dHk over all T and H. + """ + pid_s = tl.program_id(0) + pid_bh = tl.program_id(1) + + b = (pid_bh // oH).to(tl.int64) + h_outer = (pid_bh % oH).to(tl.int64) + + rs = pid_s * BLOCK_S + tl.arange(0, BLOCK_S) + rdi = tl.arange(0, d_i) + rs_mask = rs < T_s + + hq_base = b * (oH * T_t * H * d_i) + h_outer * (T_t * H * d_i) + hk_base = b * (oH * T_s * d_i) + h_outer * (T_s * d_i) + wo_base = b * (oH * T_t * H) + h_outer * (T_t * H) + do_base = b * (oH * T_t * T_s) + h_outer * (T_t * T_s) + + # Load Hk[..., rs, :] once -- needed for score recompute every iteration + hk_ptrs = Hk_ptr + hk_base + rs[:, None] * d_i + rdi[None, :] + Hk_tile = tl.load(hk_ptrs, mask=rs_mask[:, None], other=0.0) + Hk_T = tl.trans(Hk_tile) # [d_i, BLOCK_S] + + dHk_acc = tl.zeros((BLOCK_S, d_i), dtype=tl.float32) + + for h in range(H): + for t_start in range(0, T_t, BLOCK_T): + rt = t_start + tl.arange(0, BLOCK_T) + rt_mask = rt < T_t + + # Load Hq[..., rt, h, :] -> [BLOCK_T, d_i] + hq_ptrs = Hq_ptr + hq_base + rt[:, None] * (H * d_i) + h * d_i + rdi[None, :] + Hq_h = tl.load(hq_ptrs, mask=rt_mask[:, None], other=0.0) + + # Load W_o[..., rt, h] -> [BLOCK_T] + wo_ptrs = W_o_ptr + wo_base + rt * H + h + w_h = tl.load(wo_ptrs, mask=rt_mask, other=0.0).to(tl.float32) + + # Load dO[..., rt, rs] -> [BLOCK_T, BLOCK_S] + do_ptrs = dO_ptr + do_base + rt[:, None] * T_s + rs[None, :] + dO_chunk = tl.load( + do_ptrs, + mask=rt_mask[:, None] & rs_mask[None, :], + other=0.0, + ) + + # Recompute scores[BLOCK_T, BLOCK_S] + scores = tl.dot(Hq_h, Hk_T) + relu_mask = scores > 0 + + # dscores = (dO * w_h[:, None]) * relu_mask + dH = dO_chunk * w_h[:, None] + dscores = tl.where(relu_mask, dH, 0.0) + + # dHk[s, i] += 
sum_t dscores[t, s] * Hq_h[t, i] + # = (dscores^T @ Hq_h)[s, i] + dHk_acc += tl.dot(tl.trans(dscores).to(Hq_h.dtype), Hq_h) + + dhk_ptrs = dHk_ptr + hk_base + rs[:, None] * d_i + rdi[None, :] + tl.store( + dhk_ptrs, + dHk_acc.to(dHk_ptr.dtype.element_ty), + mask=rs_mask[:, None], + ) + + +_score_reduce_dHk_p = extend_core.Primitive("te_indexer_score_reduce_dHk") +_score_reduce_dHk_p.multiple_results = True + + +@_score_reduce_dHk_p.def_abstract_eval +def _score_reduce_dHk_abstract(Hq, Hk, W_o, dO): + del Hq, W_o, dO + return [core.ShapedArray(Hk.shape, Hk.dtype)] + + +_score_reduce_dHk_p.def_impl( + functools.partial(xla.apply_primitive, _score_reduce_dHk_p) +) + + +def _score_reduce_dHk_lowering(ctx, Hq, Hk, W_o, dO): + Hq_aval = ctx.avals_in[0] + Hk_aval = ctx.avals_in[1] + B, oH, T_t, H, d_i = Hq_aval.shape + T_s = Hk_aval.shape[2] + BLOCK_T = 32 if T_t >= 32 else T_t + BLOCK_S = 32 if T_s >= 32 else T_s + + return triton_call_lowering( + ctx, + _score_reduce_dHk_kernel, + Hq, Hk, W_o, dO, + grid=(triton.cdiv(T_s, BLOCK_S), B * oH), + num_warps=4, + num_stages=2, + constexprs={ + "B": B, "oH": oH, "T_t": T_t, "T_s": T_s, + "H": H, "d_i": d_i, + "BLOCK_T": BLOCK_T, "BLOCK_S": BLOCK_S, + }, + ) + + +mlir.register_lowering(_score_reduce_dHk_p, _score_reduce_dHk_lowering, platform="rocm") +mlir.register_lowering(_score_reduce_dHk_p, _score_reduce_dHk_lowering, platform="cuda") + + +# --- Public score_reduce_triton with custom_vjp ------------------------------ + + +@functools.partial(jax.custom_vjp, nondiff_argnums=(3,)) +def _score_reduce_with_vjp(Hq, Hk, W_o, out_dtype): + return _score_reduce_p.bind(Hq, Hk, W_o, out_dtype=out_dtype)[0] + + +def _score_reduce_fwd(Hq, Hk, W_o, out_dtype): + out = _score_reduce_p.bind(Hq, Hk, W_o, out_dtype=out_dtype)[0] + return out, (Hq, Hk, W_o) + + +def _score_reduce_bwd(out_dtype, residuals, dO): + del out_dtype + Hq, Hk, W_o = residuals + # Backward kernels work in fp32; upcast dO once. 
+ dO_f32 = dO.astype(jnp.float32) + dHq, dW_o = _score_reduce_dHq_dWo_p.bind(Hq, Hk, W_o, dO_f32) + dHk, = _score_reduce_dHk_p.bind(Hq, Hk, W_o, dO_f32) + return dHq, dHk, dW_o + + +_score_reduce_with_vjp.defvjp(_score_reduce_fwd, _score_reduce_bwd) + + def score_reduce_triton(Hq, Hk, W_o, *, out_dtype=None): """Triton fused score-matmul + relu + per-(t, h) weighted H-reduction. @@ -187,6 +511,10 @@ def score_reduce_triton(Hq, Hk, W_o, *, out_dtype=None): with a single kernel that holds the per-head score tile in registers, avoiding the (B, oH, T, H, S) HBM round-trip an einsum+XLA chain pays. + Differentiable via two backward kernels (FlashAttention-style: residuals + are just (Hq, Hk, W_o); the (T, H, S) score tensor is recomputed inside + backward, never materialized). + Args: Hq: (B, oH, T_t, H, d_i) Hk: (B, oH, T_s, d_i) @@ -227,6 +555,229 @@ def score_reduce_triton(Hq, Hk, W_o, *, out_dtype=None): if out_dtype is None: out_dtype = Hq.dtype - return _score_reduce_p.bind( - Hq, Hk, W_o, out_dtype=jnp.dtype(out_dtype) - )[0] + return _score_reduce_with_vjp(Hq, Hk, W_o, jnp.dtype(out_dtype)) + + +# --- Streaming top-k variant ---------------------------------------------------- +# +# Same einsum-projected (Hq, Hk, W_o) inputs, but fuses top-k indices into the +# kernel: one CTA per (B, oH, T_t) query token, score row never materialized. +# +# Algorithm (mirrors TileLang dsa_sparse_finetune/indexer_topk_reducesum): +# - Maintain a 2K-sized buffer of (score_bits, index) packed uint64 +# - Stream over T_s in BLOCK_S chunks; each chunk computes BLOCK_S new scores +# - Place chunk into buffer[K:K+BLOCK_S], zero buffer[K+BLOCK_S:2K] +# - tl.sort descending; top half is the running top-K +# - After all chunks: buffer[:K] is the answer +# +# tl.sort returns values only, so we pack (score_bits << 32) | index into uint64. +# Post-ReLU scores are >= 0, so fp32 bit pattern is monotone in value. 
+ + +@triton.jit +def _score_topk_kernel( + Hq_ptr, # (B, oH, T_t, H, d_i) bf16 + Hk_ptr, # (B, oH, T_s, d_i) bf16 + W_o_ptr, # (B, oH, T_t, H) bf16 + Topk_idx_ptr, # (B, oH, T_t, K) int32 OUTPUT + B: tl.constexpr, + oH: tl.constexpr, + T_t: tl.constexpr, + T_s: tl.constexpr, + H: tl.constexpr, + d_i: tl.constexpr, + K: tl.constexpr, + S_PAD: tl.constexpr, + BLOCK_S: tl.constexpr, +): + """Per-CTA: one query token's full top-K via streaming bitonic merge. + + Grid: (T_t, B * oH). + """ + pid_t = tl.program_id(0) + pid_bh = tl.program_id(1) + # int64 indexing — Hq alone has B*oH*T*H*d_i = 4.3 B elements at T=S=4096. + b = (pid_bh // oH).to(tl.int64) + h_outer = (pid_bh % oH).to(tl.int64) + pid_t_64 = pid_t.to(tl.int64) + + rh = tl.arange(0, H) + rdi = tl.arange(0, d_i) + + # Pre-load Hq[b, h_outer, pid_t, :, :] -> [H, d_i] once + hq_base = b * (oH * T_t * H * d_i) + h_outer * (T_t * H * d_i) + pid_t_64 * (H * d_i) + Hq_token = tl.load(Hq_ptr + hq_base + rh[:, None] * d_i + rdi[None, :]) + + # Pre-load w_o[b, h_outer, pid_t, :] -> [H] once + wo_base = b * (oH * T_t * H) + h_outer * (T_t * H) + pid_t_64 * H + w_o = tl.load(W_o_ptr + wo_base + rh).to(tl.float32) + + hk_base = b * (oH * T_s * d_i) + h_outer * (T_s * d_i) + + TOP_BUF: tl.constexpr = 2 * K + INNER: tl.constexpr = K // BLOCK_S # chunks per sort + N_OUTER: tl.constexpr = S_PAD // K # number of sorts per CTA + top_packed = tl.zeros((TOP_BUF,), dtype=tl.uint64) + + rs_buf = tl.arange(0, TOP_BUF) + rs_chunk = tl.arange(0, BLOCK_S) + Hq_T = tl.trans(Hq_token) # [d_i, H] + + # Two-level loop: fill the bottom K slots over INNER chunks, then sort. + # Net: N_OUTER sorts instead of S_PAD/BLOCK_S sorts (4x fewer at production + # shape). The previous round's "losers" (bottom-K after each sort) are + # naturally overwritten by the next INNER chunks; correctness holds because + # those losers are by definition below the running top-K threshold. 
+ for o in tl.static_range(N_OUTER): + for i in tl.static_range(INNER): + c = o * INNER + i + s_start = c * BLOCK_S + rs = s_start + rs_chunk + rs_mask = rs < T_s + + # Load Hk_chunk[BLOCK_S, d_i] + hk_ptrs = Hk_ptr + hk_base + rs[:, None] * d_i + rdi[None, :] + Hk_chunk = tl.load(hk_ptrs, mask=rs_mask[:, None], other=0.0) + + # Score matmul: [BLOCK_S, d_i] @ [d_i, H] -> [BLOCK_S, H] + logits = tl.dot(Hk_chunk, Hq_T) + logits = tl.maximum(logits, 0.0) + + # Weighted H-reduce: sum(logits * w_o[None, :], axis=1) -> [BLOCK_S] + # Note: w_o can be negative, so chunk_scores can be negative even after ReLU. + chunk_scores = tl.sum(logits * w_o[None, :], axis=1) + + # Convert fp32 to "sortable uint32" so uint comparison matches fp32 + # comparison across the full sign range: + # positive: flip sign bit + # negative: flip all bits + # See https://stereopsis.com/radix.html + bits = chunk_scores.to(tl.uint32, bitcast=True) + sign = bits >> 31 + flip_mask = (0 - sign.to(tl.int32)).to(tl.uint32) | 0x80000000 + sortable = bits ^ flip_mask + # OOR positions get sortable=0 (smallest possible, sorts to bottom) + sortable = tl.where(rs_mask, sortable, 0) + + # Pack (sortable_score_bits << 32) | index into uint64 + chunk_packed = (sortable.to(tl.uint64) << 32) | rs.to(tl.uint64) + + # Scatter chunk_packed into top_packed[K + i*BLOCK_S : K + (i+1)*BLOCK_S] + chunk_offset = K + i * BLOCK_S + in_chunk_slot = (rs_buf >= chunk_offset) & (rs_buf < chunk_offset + BLOCK_S) + chunk_gather_idx = tl.where(in_chunk_slot, rs_buf - chunk_offset, 0).to(tl.int32) + gathered = tl.gather(chunk_packed, chunk_gather_idx, axis=0) + top_packed = tl.where(in_chunk_slot, gathered, top_packed) + + # All INNER chunks placed -> sort once + top_packed = tl.sort(top_packed, descending=True) + + # Extract top K indices: gather positions [0, K) from the sorted buffer, + # take low 32 bits. 
+ rk = tl.arange(0, K) + top_k_packed = tl.gather(top_packed, rk, axis=0) + top_k_idx = (top_k_packed & 0xFFFFFFFF).to(tl.int32) + + out_base = b * (oH * T_t * K) + h_outer * (T_t * K) + pid_t_64 * K + tl.store(Topk_idx_ptr + out_base + rk, top_k_idx) + + +_score_topk_p = extend_core.Primitive("te_indexer_score_topk_triton") +_score_topk_p.multiple_results = True + + +def _next_pow2(n): + p = 1 + while p < n: + p *= 2 + return p + + +@_score_topk_p.def_abstract_eval +def _score_topk_abstract(Hq, Hk, W_o, *, k): + del Hk, W_o + B, oH, T_t, _H, _d_i = Hq.shape + return [core.ShapedArray((B, oH, T_t, k), jnp.int32)] + + +_score_topk_p.def_impl(functools.partial(xla.apply_primitive, _score_topk_p)) + + +def _score_topk_lowering(ctx, Hq, Hk, W_o, *, k): + Hq_aval = ctx.avals_in[0] + Hk_aval = ctx.avals_in[1] + B, oH, T_t, H, d_i = Hq_aval.shape + T_s = Hk_aval.shape[2] + S_PAD = _next_pow2(T_s) + # BLOCK_S must be <= K (so chunk fits in TOP_BUF[K:K+BLOCK_S]) and + # divide S_PAD evenly. Cap at 128 so the [BLOCK_S, H] fp32 logits + # intermediate stays in registers. + BLOCK_S = min(128, k, S_PAD) + + return triton_call_lowering( + ctx, + _score_topk_kernel, + Hq, Hk, W_o, + grid=(T_t, B * oH), + num_warps=4, + num_stages=2, + constexprs={ + "B": B, "oH": oH, "T_t": T_t, "T_s": T_s, + "H": H, "d_i": d_i, + "K": k, "S_PAD": S_PAD, + "BLOCK_S": BLOCK_S, + }, + ) + + +mlir.register_lowering(_score_topk_p, _score_topk_lowering, platform="rocm") +mlir.register_lowering(_score_topk_p, _score_topk_lowering, platform="cuda") + + +def score_topk_triton(Hq, Hk, W_o, *, k): + """Fused score-relu-reduce + streaming top-k. + + Computes the same scores as ``score_reduce_triton`` but never materializes the + (B, oH, T_t, T_s) score matrix — instead, returns the top-k indices into the + T_s axis directly. + + Args: + Hq: (B, oH, T_t, H, d_i) + Hk: (B, oH, T_s, d_i) + W_o: (B, oH, T_t, H) + k: number of top scores to return per (b, oH, T_t) row. Must be a + power of 2 and <= T_s. 
+ + Returns: + Topk_idx: (B, oH, T_t, k) int32 — top-k indices into T_s axis, in + descending score order. + + Notes: + Streaming: maintains a 2K candidate buffer and bitonic-sorts on each + chunk. For k >> S/8 (e.g., k=S/2), this is algorithmically slower than a + single full-row sort but matches the TileLang reference structure and + generalizes to large S without per-CTA registers scaling with S. + """ + if Hq.ndim != 5: + raise ValueError(f"Hq must be rank-5; got shape {Hq.shape}") + if Hk.ndim != 4: + raise ValueError(f"Hk must be rank-4; got shape {Hk.shape}") + if W_o.ndim != 4: + raise ValueError(f"W_o must be rank-4; got shape {W_o.shape}") + + B, oH, T_t, H, d_i = Hq.shape + Bk, oHk, T_s, d_i_k = Hk.shape + Bw, oHw, T_t_w, H_w = W_o.shape + if (Bk, oHk) != (B, oH): + raise ValueError(f"(B, oH) mismatch: Hq has {(B, oH)}, Hk has {(Bk, oHk)}") + if d_i != d_i_k: + raise ValueError(f"d_i mismatch: Hq has {d_i}, Hk has {d_i_k}") + if (Bw, oHw, T_t_w, H_w) != (B, oH, T_t, H): + raise ValueError(f"W_o shape {W_o.shape} != expected (B, oH, T_t, H)") + + if k <= 0 or (k & (k - 1)) != 0: + raise ValueError(f"k must be a positive power of 2; got {k}") + if k > T_s: + raise ValueError(f"k={k} must be <= T_s={T_s}") + + return _score_topk_p.bind(Hq, Hk, W_o, k=k)[0] From 949aee6873911a627ea513469c8593f764836963 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 20 May 2026 19:00:40 +0000 Subject: [PATCH 07/17] Updated bwd pass to chunked hybrid kernel for mem consideration --- .../jax/triton_extensions/indexer.py | 414 +++++++----------- 1 file changed, 168 insertions(+), 246 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/indexer.py b/transformer_engine/jax/triton_extensions/indexer.py index 6ddf38298..ed6c9277f 100644 --- a/transformer_engine/jax/triton_extensions/indexer.py +++ b/transformer_engine/jax/triton_extensions/indexer.py @@ -179,299 +179,170 @@ def grid_fn(merged_kwargs): mlir.register_lowering(_score_reduce_p, 
_score_reduce_lowering, platform="cuda") -# --- Backward: dHq + dW_o kernel ---------------------------------------------- +# --- Chunked score-tile kernel for hybrid bwd -------------------------------- # -# FlashAttention-style: residuals saved from forward = (Hq, Hk, W_o). The -# (T, H, S) score tensor is recomputed inside this kernel from Hq @ Hk^T, -# so we never store H = relu(scores) -- which is 549 GB at the production -# 4096^2 shape. +# Produces dscores_chunk[B, oH, T, H_CHUNK, T_s] and dW_o_chunk[B, oH, T, H_CHUNK] +# for ONE h-chunk. Caller loops over H/H_CHUNK chunks and feeds dscores_chunk +# to hipBLASLt einsums for dHq/dHk reductions. Bounds peak materialization to +# H/H_CHUNK fraction of the full (B, oH, T, H, T_s) score tensor. # -# Math: -# scores[t, h, s] = sum_i Hq[t, h, i] * Hk[s, i] -# H_relu[t, h, s] = max(scores, 0) -# O[t, s] = sum_h H_relu[t, h, s] * W_o[t, h] -# -# Cotangents: -# dW_o[t, h] = sum_s dO[t, s] * H_relu[t, h, s] -# dH[t,h,s] = dO[t, s] * W_o[t, h] -# dscores = dH * (scores > 0) # ReLU mask -# dHq[t,h,i] = sum_s dscores[t,h,s] * Hk[s, i] -# dHk[s,i] = sum_t sum_h dscores[t,h,s] * Hq[t, h, i] -# -# Kernel A (this one): computes dHq and dW_o. -# Grid: (cdiv(T_t, BLOCK_T), B * oH). Each CTA owns BLOCK_T rows of dHq -# and dW_o exclusively -- no atomics needed since the full S range is -# reduced inside one CTA. -# -# Kernel B (next section): computes dHk. Grid (cdiv(T_s, BLOCK_S), B * oH); -# each CTA owns BLOCK_S rows of dHk and reduces over all T inside. +# Fuses score recompute + relu + mask + dO*W_o broadcast in registers -- +# nothing of size (B, oH, T, H, T_s) ever lands in HBM at full size. dW_o is +# reduced inline (sum_s of h_relu * dO) so h_relu also never materializes. 
+ + +_HBWD_BLOCK_T = 64 +_HBWD_BLOCK_S = 64 @triton.jit -def _score_reduce_dHq_dWo_kernel( - Hq_ptr, # (B, oH, T_t, H, d_i) bf16 - Hk_ptr, # (B, oH, T_s, d_i) bf16 - W_o_ptr, # (B, oH, T_t, H) bf16 - dO_ptr, # (B, oH, T_t, T_s) fp32 (caller upcasts) - dHq_ptr, # (B, oH, T_t, H, d_i) bf16 OUTPUT - dWo_ptr, # (B, oH, T_t, H) bf16 OUTPUT +def _score_dscores_chunk_kernel( + Hq_chunk_ptr, # input (B, oH, T, H_CHUNK, d_i) bf16 + Hk_ptr, # input (B, oH, T_s, d_i) bf16 + W_o_chunk_ptr, # input (B, oH, T, H_CHUNK) bf16 + dO_ptr, # input (B, oH, T, T_s) fp32 + dscores_chunk_ptr, # output (B, oH, T, H_CHUNK, T_s) bf16 + dWo_chunk_ptr, # output (B, oH, T, H_CHUNK) bf16 B: tl.constexpr, oH: tl.constexpr, - T_t: tl.constexpr, + T: tl.constexpr, T_s: tl.constexpr, - H: tl.constexpr, + H_CHUNK: tl.constexpr, d_i: tl.constexpr, BLOCK_T: tl.constexpr, BLOCK_S: tl.constexpr, ): - """Per-CTA: produces dHq[BLOCK_T, H, d_i] and dW_o[BLOCK_T, H]. + """One CTA handles (T_tile, h_in) for one (b, h_outer). Loops over s_chunks. - Outer-h loop / inner-s loop. For each h, we accumulate dHq and dW_o - contributions over the full S range, then store and move on. + Each CTA writes its T_tile rows of (dscores_chunk[..., h_in, :], + dW_o_chunk[..., h_in]). dW_o is reduced in registers (sum over s) so + h_relu never lands in HBM -- we compute it on-the-fly and consume it. 
""" pid_t = tl.program_id(0) - pid_bh = tl.program_id(1) - - # int64 indexing — production tensors exceed int32 range + pid_h_bh = tl.program_id(1) + h_in = pid_h_bh % H_CHUNK + pid_bh = pid_h_bh // H_CHUNK b = (pid_bh // oH).to(tl.int64) h_outer = (pid_bh % oH).to(tl.int64) + h_in_64 = h_in.to(tl.int64) rt = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) rdi = tl.arange(0, d_i) - rt_mask = rt < T_t + rt_mask = rt < T - hq_base = b * (oH * T_t * H * d_i) + h_outer * (T_t * H * d_i) + hq_base = b * (oH * T * H_CHUNK * d_i) + h_outer * (T * H_CHUNK * d_i) hk_base = b * (oH * T_s * d_i) + h_outer * (T_s * d_i) - wo_base = b * (oH * T_t * H) + h_outer * (T_t * H) - do_base = b * (oH * T_t * T_s) + h_outer * (T_t * T_s) - - for h in range(H): - # Load Hq[..., rt, h, :] -> [BLOCK_T, d_i] bf16 - hq_ptrs = Hq_ptr + hq_base + rt[:, None] * (H * d_i) + h * d_i + rdi[None, :] - Hq_h = tl.load(hq_ptrs, mask=rt_mask[:, None], other=0.0) - - # Load W_o[..., rt, h] -> [BLOCK_T] fp32 - wo_ptrs = W_o_ptr + wo_base + rt * H + h - w_h = tl.load(wo_ptrs, mask=rt_mask, other=0.0).to(tl.float32) - - dHq_acc = tl.zeros((BLOCK_T, d_i), dtype=tl.float32) - dWo_acc = tl.zeros((BLOCK_T,), dtype=tl.float32) - - for s_start in range(0, T_s, BLOCK_S): - rs = s_start + tl.arange(0, BLOCK_S) - rs_mask = rs < T_s - - # Load Hk[..., rs, :] -> [BLOCK_S, d_i] bf16 - hk_ptrs = Hk_ptr + hk_base + rs[:, None] * d_i + rdi[None, :] - Hk_chunk = tl.load(hk_ptrs, mask=rs_mask[:, None], other=0.0) - - # Load dO[..., rt, rs] -> [BLOCK_T, BLOCK_S] (caller upcast to fp32) - do_ptrs = dO_ptr + do_base + rt[:, None] * T_s + rs[None, :] - dO_chunk = tl.load( - do_ptrs, - mask=rt_mask[:, None] & rs_mask[None, :], - other=0.0, - ) - - # Recompute scores[BLOCK_T, BLOCK_S] = Hq_h @ Hk_chunk^T, in fp32 - scores = tl.dot(Hq_h, tl.trans(Hk_chunk)) # [BLOCK_T, BLOCK_S] - relu_mask = scores > 0 - h_relu = tl.where(relu_mask, scores, 0.0) + wo_base = b * (oH * T * H_CHUNK) + h_outer * (T * H_CHUNK) + do_base = b * (oH * T * T_s) + 
h_outer * (T * T_s) + ds_base = b * (oH * T * H_CHUNK * T_s) + h_outer * (T * H_CHUNK * T_s) + + # Load Hq[..., t_tile, h_in, :] -> [BLOCK_T, d_i] once per CTA + hq_ptrs = (Hq_chunk_ptr + hq_base + + rt[:, None] * (H_CHUNK * d_i) + + h_in_64 * d_i + + rdi[None, :]) + Hq_h = tl.load(hq_ptrs, mask=rt_mask[:, None], other=0.0) + + # Load W_o[..., t_tile, h_in] -> [BLOCK_T] once per CTA + wo_ptrs = W_o_chunk_ptr + wo_base + rt * H_CHUNK + h_in_64 + w_h = tl.load(wo_ptrs, mask=rt_mask, other=0.0).to(tl.float32) + + # dW_o accumulator: sum_s (h_relu * dO) -- reduced in regs + dWo_acc = tl.zeros((BLOCK_T,), dtype=tl.float32) + + for s_start in range(0, T_s, BLOCK_S): + rs = s_start + tl.arange(0, BLOCK_S) + rs_mask = rs < T_s + + # Load Hk[..., s_chunk, :] and dO[..., t_tile, s_chunk] + hk_ptrs = Hk_ptr + hk_base + rs[:, None] * d_i + rdi[None, :] + Hk_chunk = tl.load(hk_ptrs, mask=rs_mask[:, None], other=0.0) + + do_ptrs = dO_ptr + do_base + rt[:, None] * T_s + rs[None, :] + dO_chunk = tl.load( + do_ptrs, + mask=rt_mask[:, None] & rs_mask[None, :], + other=0.0, + ) - # dW_o accumulator: sum_s dO * H_relu - dWo_acc += tl.sum(dO_chunk * h_relu, axis=1) + # scores tile in registers (never lands in HBM at full size) + scores = tl.dot(Hq_h, tl.trans(Hk_chunk)) + relu_mask = scores > 0 + h_relu = tl.where(relu_mask, scores, 0.0) - # dscores = (dO * w_h) * relu_mask - dH = dO_chunk * w_h[:, None] - dscores = tl.where(relu_mask, dH, 0.0) + # dW_o contribution: sum_s (h_relu * dO) + dWo_acc += tl.sum(h_relu * dO_chunk, axis=1) - # dHq_acc += dscores @ Hk_chunk -> [BLOCK_T, d_i] - dHq_acc += tl.dot(dscores.to(Hk_chunk.dtype), Hk_chunk) + # dscores tile = relu_mask * (dO * W_o) + dscores = tl.where(relu_mask, dO_chunk * w_h[:, None], 0.0) - # Store dHq[..., rt, h, :] - dhq_ptrs = dHq_ptr + hq_base + rt[:, None] * (H * d_i) + h * d_i + rdi[None, :] + # Store dscores tile to HBM (bf16). Total dscores_chunk size is + # H_CHUNK x smaller than the full (B,oH,T,H,T_s) tensor. 
+ ds_ptrs = (dscores_chunk_ptr + ds_base + + rt[:, None] * (H_CHUNK * T_s) + + h_in_64 * T_s + + rs[None, :]) tl.store( - dhq_ptrs, - dHq_acc.to(dHq_ptr.dtype.element_ty), - mask=rt_mask[:, None], + ds_ptrs, + dscores.to(dscores_chunk_ptr.dtype.element_ty), + mask=rt_mask[:, None] & rs_mask[None, :], ) - # Store dW_o[..., rt, h] - dwo_ptrs = dWo_ptr + wo_base + rt * H + h - tl.store(dwo_ptrs, dWo_acc.to(dWo_ptr.dtype.element_ty), mask=rt_mask) - - -_score_reduce_dHq_dWo_p = extend_core.Primitive("te_indexer_score_reduce_dHq_dWo") -_score_reduce_dHq_dWo_p.multiple_results = True - - -@_score_reduce_dHq_dWo_p.def_abstract_eval -def _score_reduce_dHq_dWo_abstract(Hq, Hk, W_o, dO): - del Hk, dO - return [ - core.ShapedArray(Hq.shape, Hq.dtype), # dHq - core.ShapedArray(W_o.shape, W_o.dtype), # dW_o - ] - - -_score_reduce_dHq_dWo_p.def_impl( - functools.partial(xla.apply_primitive, _score_reduce_dHq_dWo_p) -) - - -def _score_reduce_dHq_dWo_lowering(ctx, Hq, Hk, W_o, dO): - Hq_aval = ctx.avals_in[0] - Hk_aval = ctx.avals_in[1] - B, oH, T_t, H, d_i = Hq_aval.shape - T_s = Hk_aval.shape[2] - BLOCK_T = 32 if T_t >= 32 else T_t - BLOCK_S = 32 if T_s >= 32 else T_s - - return triton_call_lowering( - ctx, - _score_reduce_dHq_dWo_kernel, - Hq, Hk, W_o, dO, - grid=(triton.cdiv(T_t, BLOCK_T), B * oH), - num_warps=4, - num_stages=2, - constexprs={ - "B": B, "oH": oH, "T_t": T_t, "T_s": T_s, - "H": H, "d_i": d_i, - "BLOCK_T": BLOCK_T, "BLOCK_S": BLOCK_S, - }, - ) - - -mlir.register_lowering(_score_reduce_dHq_dWo_p, _score_reduce_dHq_dWo_lowering, platform="rocm") -mlir.register_lowering(_score_reduce_dHq_dWo_p, _score_reduce_dHq_dWo_lowering, platform="cuda") - - -# --- Backward: dHk kernel ----------------------------------------------------- - - -@triton.jit -def _score_reduce_dHk_kernel( - Hq_ptr, # (B, oH, T_t, H, d_i) bf16 - Hk_ptr, # (B, oH, T_s, d_i) bf16 - W_o_ptr, # (B, oH, T_t, H) bf16 - dO_ptr, # (B, oH, T_t, T_s) fp32 - dHk_ptr, # (B, oH, T_s, d_i) bf16 OUTPUT - B: 
tl.constexpr, - oH: tl.constexpr, - T_t: tl.constexpr, - T_s: tl.constexpr, - H: tl.constexpr, - d_i: tl.constexpr, - BLOCK_T: tl.constexpr, - BLOCK_S: tl.constexpr, -): - """Per-CTA: produces dHk[BLOCK_S, d_i]. - - Outer-h loop / inner-t loop, accumulating dHk over all T and H. - """ - pid_s = tl.program_id(0) - pid_bh = tl.program_id(1) - - b = (pid_bh // oH).to(tl.int64) - h_outer = (pid_bh % oH).to(tl.int64) - - rs = pid_s * BLOCK_S + tl.arange(0, BLOCK_S) - rdi = tl.arange(0, d_i) - rs_mask = rs < T_s - - hq_base = b * (oH * T_t * H * d_i) + h_outer * (T_t * H * d_i) - hk_base = b * (oH * T_s * d_i) + h_outer * (T_s * d_i) - wo_base = b * (oH * T_t * H) + h_outer * (T_t * H) - do_base = b * (oH * T_t * T_s) + h_outer * (T_t * T_s) - - # Load Hk[..., rs, :] once -- needed for score recompute every iteration - hk_ptrs = Hk_ptr + hk_base + rs[:, None] * d_i + rdi[None, :] - Hk_tile = tl.load(hk_ptrs, mask=rs_mask[:, None], other=0.0) - Hk_T = tl.trans(Hk_tile) # [d_i, BLOCK_S] - - dHk_acc = tl.zeros((BLOCK_S, d_i), dtype=tl.float32) - - for h in range(H): - for t_start in range(0, T_t, BLOCK_T): - rt = t_start + tl.arange(0, BLOCK_T) - rt_mask = rt < T_t - - # Load Hq[..., rt, h, :] -> [BLOCK_T, d_i] - hq_ptrs = Hq_ptr + hq_base + rt[:, None] * (H * d_i) + h * d_i + rdi[None, :] - Hq_h = tl.load(hq_ptrs, mask=rt_mask[:, None], other=0.0) - - # Load W_o[..., rt, h] -> [BLOCK_T] - wo_ptrs = W_o_ptr + wo_base + rt * H + h - w_h = tl.load(wo_ptrs, mask=rt_mask, other=0.0).to(tl.float32) - - # Load dO[..., rt, rs] -> [BLOCK_T, BLOCK_S] - do_ptrs = dO_ptr + do_base + rt[:, None] * T_s + rs[None, :] - dO_chunk = tl.load( - do_ptrs, - mask=rt_mask[:, None] & rs_mask[None, :], - other=0.0, - ) - - # Recompute scores[BLOCK_T, BLOCK_S] - scores = tl.dot(Hq_h, Hk_T) - relu_mask = scores > 0 - - # dscores = (dO * w_h[:, None]) * relu_mask - dH = dO_chunk * w_h[:, None] - dscores = tl.where(relu_mask, dH, 0.0) - - # dHk[s, i] += sum_t dscores[t, s] * Hq_h[t, i] - # = 
(dscores^T @ Hq_h)[s, i] - dHk_acc += tl.dot(tl.trans(dscores).to(Hq_h.dtype), Hq_h) - - dhk_ptrs = dHk_ptr + hk_base + rs[:, None] * d_i + rdi[None, :] + # Store dW_o[..., t_tile, h_in] + dwo_out_ptrs = dWo_chunk_ptr + wo_base + rt * H_CHUNK + h_in_64 tl.store( - dhk_ptrs, - dHk_acc.to(dHk_ptr.dtype.element_ty), - mask=rs_mask[:, None], + dwo_out_ptrs, + dWo_acc.to(dWo_chunk_ptr.dtype.element_ty), + mask=rt_mask, ) -_score_reduce_dHk_p = extend_core.Primitive("te_indexer_score_reduce_dHk") -_score_reduce_dHk_p.multiple_results = True +_score_dscores_chunk_p = extend_core.Primitive("te_indexer_score_dscores_chunk") +_score_dscores_chunk_p.multiple_results = True -@_score_reduce_dHk_p.def_abstract_eval -def _score_reduce_dHk_abstract(Hq, Hk, W_o, dO): - del Hq, W_o, dO - return [core.ShapedArray(Hk.shape, Hk.dtype)] +@_score_dscores_chunk_p.def_abstract_eval +def _score_dscores_chunk_abstract(Hq_chunk, Hk, W_o_chunk, dO): + del Hk, W_o_chunk + B, oH, T, H_CHUNK, _ = Hq_chunk.shape + T_s = dO.shape[-1] + return [ + core.ShapedArray((B, oH, T, H_CHUNK, T_s), Hq_chunk.dtype), # dscores + core.ShapedArray((B, oH, T, H_CHUNK), Hq_chunk.dtype), # dW_o + ] -_score_reduce_dHk_p.def_impl( - functools.partial(xla.apply_primitive, _score_reduce_dHk_p) +_score_dscores_chunk_p.def_impl( + functools.partial(xla.apply_primitive, _score_dscores_chunk_p) ) -def _score_reduce_dHk_lowering(ctx, Hq, Hk, W_o, dO): +def _score_dscores_chunk_lowering(ctx, Hq_chunk, Hk, W_o_chunk, dO): Hq_aval = ctx.avals_in[0] - Hk_aval = ctx.avals_in[1] - B, oH, T_t, H, d_i = Hq_aval.shape - T_s = Hk_aval.shape[2] - BLOCK_T = 32 if T_t >= 32 else T_t - BLOCK_S = 32 if T_s >= 32 else T_s + dO_aval = ctx.avals_in[3] + B, oH, T, H_CHUNK, d_i = Hq_aval.shape + T_s = dO_aval.shape[-1] + BLOCK_T = _HBWD_BLOCK_T if T >= _HBWD_BLOCK_T else T + BLOCK_S = _HBWD_BLOCK_S if T_s >= _HBWD_BLOCK_S else T_s + n_t_tiles = (T + BLOCK_T - 1) // BLOCK_T return triton_call_lowering( ctx, - _score_reduce_dHk_kernel, - Hq, Hk, 
W_o, dO, - grid=(triton.cdiv(T_s, BLOCK_S), B * oH), + _score_dscores_chunk_kernel, + Hq_chunk, Hk, W_o_chunk, dO, + grid=(n_t_tiles, B * oH * H_CHUNK), num_warps=4, num_stages=2, constexprs={ - "B": B, "oH": oH, "T_t": T_t, "T_s": T_s, - "H": H, "d_i": d_i, + "B": B, "oH": oH, "T": T, "T_s": T_s, + "H_CHUNK": H_CHUNK, "d_i": d_i, "BLOCK_T": BLOCK_T, "BLOCK_S": BLOCK_S, }, ) -mlir.register_lowering(_score_reduce_dHk_p, _score_reduce_dHk_lowering, platform="rocm") -mlir.register_lowering(_score_reduce_dHk_p, _score_reduce_dHk_lowering, platform="cuda") +mlir.register_lowering(_score_dscores_chunk_p, _score_dscores_chunk_lowering, platform="rocm") +mlir.register_lowering(_score_dscores_chunk_p, _score_dscores_chunk_lowering, platform="cuda") # --- Public score_reduce_triton with custom_vjp ------------------------------ @@ -487,14 +358,65 @@ def _score_reduce_fwd(Hq, Hk, W_o, out_dtype): return out, (Hq, Hk, W_o) +_BWD_H_CHUNK = 8 # peak (B, oH, T, H_CHUNK, T_s) tile -- bounds materialization + + def _score_reduce_bwd(out_dtype, residuals, dO): del out_dtype Hq, Hk, W_o = residuals - # Backward kernels work in fp32; upcast dO once. - dO_f32 = dO.astype(jnp.float32) - dHq, dW_o = _score_reduce_dHq_dWo_p.bind(Hq, Hk, W_o, dO_f32) - dHk, = _score_reduce_dHk_p.bind(Hq, Hk, W_o, dO_f32) - return dHq, dHk, dW_o + B, oH, T, H, d_i = Hq.shape + + # Hybrid scheme with bounded materialization: + # For each h-chunk of size H_CHUNK (driven by lax.scan, NOT Python + # unroll, so intermediates are freed between iterations): + # 1. Triton kernel fuses (score recompute + relu + mask + dO*W_o + # broadcast) and writes dscores_chunk[B,oH,T,H_CHUNK,T_s] to HBM. + # h_relu is consumed in-register to also produce dWo_chunk + # without ever materializing the (B,oH,T,H,T_s) h_relu tensor. + # 2. hipBLASLt einsums on dscores_chunk give dHq_chunk and a partial + # dHk contribution. + # Peak HBM intermediate stays at H_CHUNK/H fraction of the full score. 
+ # + # The fully-fused Triton bwd variants (v2/v3/v4) remain in this file for + # reference -- they don't materialize the score tensor either but are + # slower than the hipBLASLt-based reductions used here (~2x at 4096^2). + if H % _BWD_H_CHUNK == 0: + H_CHUNK = _BWD_H_CHUNK + else: + H_CHUNK = 1 + for c in (8, 4, 2): + if H % c == 0: + H_CHUNK = c + break + n_chunks = H // H_CHUNK + + Hq_r = Hq.reshape(B, oH, T, n_chunks, H_CHUNK, d_i) + Wo_r = W_o.reshape(B, oH, T, n_chunks, H_CHUNK) + # Move chunk axis to leading for scan over axis 0. + Hq_s = jnp.moveaxis(Hq_r, -3, 0) # (n_chunks, B, oH, T, H_CHUNK, d_i) + Wo_s = jnp.moveaxis(Wo_r, -2, 0) # (n_chunks, B, oH, T, H_CHUNK) + + def step(dHk_acc, chunk): + Hq_c, Wo_c = chunk + # Triton: dscores_chunk + dWo_chunk; no full (B,oH,T,H,T_s) tensor + # ever exists in HBM. + dscores_c, dWo_c = _score_dscores_chunk_p.bind(Hq_c, Hk, Wo_c, dO) + dHq_c = jnp.einsum("...ths,...si->...thi", dscores_c, Hk) + dHk_c = jnp.einsum("...ths,...thi->...si", dscores_c, Hq_c) + new_dHk_acc = dHk_acc + dHk_c.astype(jnp.float32) + return new_dHk_acc, (dHq_c, dWo_c) + + init = jnp.zeros(Hk.shape, dtype=jnp.float32) + dHk_acc, (dHq_chunks, dWo_chunks) = jax.lax.scan( + step, init, (Hq_s, Wo_s), + ) + # dHq_chunks: (n_chunks, B, oH, T, H_CHUNK, d_i) + # dWo_chunks: (n_chunks, B, oH, T, H_CHUNK) + dHq = jnp.moveaxis(dHq_chunks, 0, -3).reshape(B, oH, T, H, d_i) + dWo = jnp.moveaxis(dWo_chunks, 0, -2).reshape(B, oH, T, H) + dHk = dHk_acc.astype(Hk.dtype) + + return dHq.astype(Hq.dtype), dHk, dWo.astype(W_o.dtype) _score_reduce_with_vjp.defvjp(_score_reduce_fwd, _score_reduce_bwd) From 37a856341b2bbc0c986b90e952332708889de06b Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 27 May 2026 16:01:46 +0000 Subject: [PATCH 08/17] Added T-tiling to fused score top-k op --- .../jax/triton_extensions/indexer.py | 244 +++++++++++++----- 1 file changed, 184 insertions(+), 60 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/indexer.py 
b/transformer_engine/jax/triton_extensions/indexer.py index ed6c9277f..2c012b685 100644 --- a/transformer_engine/jax/triton_extensions/indexer.py +++ b/transformer_engine/jax/triton_extensions/indexer.py @@ -496,6 +496,39 @@ def score_reduce_triton(Hq, Hk, W_o, *, out_dtype=None): # Post-ReLU scores are >= 0, so fp32 bit pattern is monotone in value. +# Autotune sweep for _score_topk_kernel. +# +# BLOCK_T: number of query tokens per CTA. BLOCK_T>1 amortizes the Hk_chunk +# load across BLOCK_T queries — the single biggest lever at large T_s. At +# BLOCK_T=1 (original), each CTA reloads all of Hk for its (b, oH) slab, +# causing L2 thrash. BLOCK_T=2 halves Hk HBM traffic; BLOCK_T=4 quarters it, +# but grows per-CTA register pressure (Hq_token, top_packed, logits all +# scale with BLOCK_T). +# +# BLOCK_S knobs the inner-chunk size; bigger BLOCK_S = better matmul +# arithmetic intensity, but bigger per-CTA transient footprint +# (logits[BLOCK_S, BLOCK_T*H] fp32 + Hk_chunk[BLOCK_S, d_i] bf16). +# +# Constraint: BLOCK_S must divide K (so INNER = K // BLOCK_S is an integer +# >= 1). Configs whose BLOCK_S exceeds K or doesn't divide K are filtered +# out at lowering time — otherwise jaxlib's autotuner would time them as +# zero-work (fast) and pick a bogus winner that returns all-zero indices. +_SCORE_TOPK_CONFIGS = [ + triton.Config({"BLOCK_S": bs, "BLOCK_T": bt}, num_warps=nw, num_stages=ns) + for bt in (1, 2) + for bs in (32, 64, 128, 256) + for nw in (4, 8) + for ns in (1, 2) +] + [ + # BLOCK_T=4 only at smaller BLOCK_S — at BLOCK_S=256 the logits + # intermediate [256, 4*H=256] fp32 = 256 KB overflows reliably. 
+ triton.Config({"BLOCK_S": bs, "BLOCK_T": 4}, num_warps=nw, num_stages=ns) + for bs in (32, 64, 128) + for nw in (4, 8) + for ns in (1, 2) +] + + @triton.jit def _score_topk_kernel( Hq_ptr, # (B, oH, T_t, H, d_i) bf16 @@ -511,97 +544,166 @@ def _score_topk_kernel( K: tl.constexpr, S_PAD: tl.constexpr, BLOCK_S: tl.constexpr, + BLOCK_T: tl.constexpr, ): - """Per-CTA: one query token's full top-K via streaming bitonic merge. - - Grid: (T_t, B * oH). + """Per-CTA: BLOCK_T consecutive query tokens, all sharing Hk loads. + + Grid: (cdiv(T_t, BLOCK_T), B * oH). Each CTA does: + - Pre-load Hq[..., rt, :, :] for BLOCK_T contiguous query tokens + - For each S chunk: load Hk_chunk ONCE, do one [BLOCK_S, d_i] @ + [d_i, BLOCK_T*H] matmul, weighted-H-reduce per T + - Maintain a single 1D top buffer of size BLOCK_T*2K, with T encoded + in the top 8 bits of each packed entry. After global sort desc, + per-T entries stay grouped together so per-T top-K can be sliced + from fixed offsets. + + Note on layout (1D vs 2D top buffer): + A 2D [BLOCK_T, 2K] top buffer with per-row sort is the natural + design, but `tl.gather + tl.sort(dim=1)` on uint64 2D tensors trips + `TritonGPUOptimizeThreadLocality` on the AMD backend (gfx950, Triton + 3.4.0). The 1D-with-encoded-T workaround sidesteps this — it pays a + ~1.5x sort-cost penalty (one sort of BLOCK_T*2K vs BLOCK_T sorts of + 2K) for BLOCK_T=2, but unblocks Hk-load amortization across queries. """ pid_t = tl.program_id(0) pid_bh = tl.program_id(1) - # int64 indexing — Hq alone has B*oH*T*H*d_i = 4.3 B elements at T=S=4096. + # int64 indexing — Hq has B*oH*T*H*d_i = 4.3 B elements at T=S=4096. 
b = (pid_bh // oH).to(tl.int64) h_outer = (pid_bh % oH).to(tl.int64) - pid_t_64 = pid_t.to(tl.int64) rh = tl.arange(0, H) rdi = tl.arange(0, d_i) + rs_chunk = tl.arange(0, BLOCK_S) + rk = tl.arange(0, K) + rt_local = tl.arange(0, BLOCK_T) + + rt = pid_t * BLOCK_T + rt_local + rt_64 = rt.to(tl.int64) + rt_mask = rt < T_t + + # Load Hq[b, h_outer, rt, :, :] -> [BLOCK_T, H, d_i]. + hq_base = b * (oH * T_t * H * d_i) + h_outer * (T_t * H * d_i) + Hq_token = tl.load( + Hq_ptr + hq_base + + rt_64[:, None, None] * (H * d_i) + + rh[None, :, None] * d_i + + rdi[None, None, :], + mask=rt_mask[:, None, None], + other=0.0, + ) - # Pre-load Hq[b, h_outer, pid_t, :, :] -> [H, d_i] once - hq_base = b * (oH * T_t * H * d_i) + h_outer * (T_t * H * d_i) + pid_t_64 * (H * d_i) - Hq_token = tl.load(Hq_ptr + hq_base + rh[:, None] * d_i + rdi[None, :]) + # Load w_o[b, h_outer, rt, :] -> [BLOCK_T, H] + wo_base = b * (oH * T_t * H) + h_outer * (T_t * H) + w_o = tl.load( + W_o_ptr + wo_base + rt_64[:, None] * H + rh[None, :], + mask=rt_mask[:, None], + other=0.0, + ).to(tl.float32) - # Pre-load w_o[b, h_outer, pid_t, :] -> [H] once - wo_base = b * (oH * T_t * H) + h_outer * (T_t * H) + pid_t_64 * H - w_o = tl.load(W_o_ptr + wo_base + rh).to(tl.float32) + # Flatten Hq for one big matmul per Hk_chunk: [BLOCK_T * H, d_i] -> trans + Hq_flat = tl.reshape(Hq_token, (BLOCK_T * H, d_i)) + Hq_T = tl.trans(Hq_flat) # [d_i, BLOCK_T * H] + w_o_flat = tl.reshape(w_o, (BLOCK_T * H,)) hk_base = b * (oH * T_s * d_i) + h_outer * (T_s * d_i) TOP_BUF: tl.constexpr = 2 * K INNER: tl.constexpr = K // BLOCK_S # chunks per sort N_OUTER: tl.constexpr = S_PAD // K # number of sorts per CTA - top_packed = tl.zeros((TOP_BUF,), dtype=tl.uint64) + BIG_BUF: tl.constexpr = BLOCK_T * TOP_BUF + + # Initialize 1D top buffer with t-encoding pre-applied so per-T regions + # stay grouped after global sort. 
Each slot at position rb gets: + # t_pos = rb // TOP_BUF -> which T this slot belongs to + # t_enc = BLOCK_T - t_pos -> 1..BLOCK_T (never 0 → never collides with + # reserved init pattern) + # packed = (t_enc << 56) | 0 -> score=0 (sortable=0), index=0 + # Real candidates also get tagged with their t_enc; after global sort + # desc, all entries with t_enc=BLOCK_T (i.e. t=0) come first, then + # t_enc=BLOCK_T-1, etc. Within each t group, ordered by score then index. + rb = tl.arange(0, BIG_BUF) + rb_t = rb // TOP_BUF # [BIG_BUF] in [0, BLOCK_T) + rb_pos = rb % TOP_BUF # [BIG_BUF] in [0, TOP_BUF) + t_enc_per_slot = (BLOCK_T - rb_t).to(tl.uint64) + top_packed = t_enc_per_slot << 56 + + # Pre-compute the per-slot (t, pos)-to-flat-chunk-index map used in + # scatter: for each rb, identify the (t, j) in chunk_packed_flat to pull + # from. j depends on `chunk_offset` (varies per inner iter), so the + # gather index is recomputed each iter. - rs_buf = tl.arange(0, TOP_BUF) - rs_chunk = tl.arange(0, BLOCK_S) - Hq_T = tl.trans(Hq_token) # [d_i, H] - - # Two-level loop: fill the bottom K slots over INNER chunks, then sort. - # Net: N_OUTER sorts instead of S_PAD/BLOCK_S sorts (4x fewer at production - # shape). The previous round's "losers" (bottom-K after each sort) are - # naturally overwritten by the next INNER chunks; correctness holds because - # those losers are by definition below the running top-K threshold. for o in tl.static_range(N_OUTER): for i in tl.static_range(INNER): c = o * INNER + i s_start = c * BLOCK_S - rs = s_start + rs_chunk + rs = s_start + rs_chunk # [BLOCK_S] rs_mask = rs < T_s - # Load Hk_chunk[BLOCK_S, d_i] + # Load Hk_chunk[BLOCK_S, d_i] ONCE — shared across BLOCK_T queries. 
hk_ptrs = Hk_ptr + hk_base + rs[:, None] * d_i + rdi[None, :] Hk_chunk = tl.load(hk_ptrs, mask=rs_mask[:, None], other=0.0) - # Score matmul: [BLOCK_S, d_i] @ [d_i, H] -> [BLOCK_S, H] + # One big matmul: [BLOCK_S, d_i] @ [d_i, BLOCK_T*H] -> [BLOCK_S, BLOCK_T*H] logits = tl.dot(Hk_chunk, Hq_T) logits = tl.maximum(logits, 0.0) - # Weighted H-reduce: sum(logits * w_o[None, :], axis=1) -> [BLOCK_S] - # Note: w_o can be negative, so chunk_scores can be negative even after ReLU. - chunk_scores = tl.sum(logits * w_o[None, :], axis=1) + # Weighted reduce over H per (s, t): + # chunk_scores[s, t] = sum_h logits[s, t*H + h] * w_o[t, h] + weighted = logits * w_o_flat[None, :] + weighted_3d = tl.reshape(weighted, (BLOCK_S, BLOCK_T, H)) + chunk_scores = tl.sum(weighted_3d, axis=2) # [BLOCK_S, BLOCK_T] + chunk_scores_T = tl.trans(chunk_scores) # [BLOCK_T, BLOCK_S] - # Convert fp32 to "sortable uint32" so uint comparison matches fp32 - # comparison across the full sign range: - # positive: flip sign bit - # negative: flip all bits + # Radix-flip: fp32 bit pattern -> sortable uint32 across full sign + # range (positives: flip sign bit; negatives: flip all bits). # See https://stereopsis.com/radix.html - bits = chunk_scores.to(tl.uint32, bitcast=True) + bits = chunk_scores_T.to(tl.uint32, bitcast=True) sign = bits >> 31 flip_mask = (0 - sign.to(tl.int32)).to(tl.uint32) | 0x80000000 sortable = bits ^ flip_mask - # OOR positions get sortable=0 (smallest possible, sorts to bottom) - sortable = tl.where(rs_mask, sortable, 0) - - # Pack (sortable_score_bits << 32) | index into uint64 - chunk_packed = (sortable.to(tl.uint64) << 32) | rs.to(tl.uint64) - - # Scatter chunk_packed into top_packed[K + i*BLOCK_S : K + (i+1)*BLOCK_S] + sortable = tl.where(rs_mask[None, :], sortable, 0) + + # Pack: (t_enc<<56) | (sortable<<24) | (index in low 24 bits). + # 24-bit index supports T_s up to 16M, far above our regime. 
+ t_enc_chunk = (BLOCK_T - rt_local).to(tl.uint64) # [BLOCK_T] + rs_2d = tl.broadcast_to(rs[None, :], (BLOCK_T, BLOCK_S)) + chunk_packed_2d = ( + (t_enc_chunk[:, None] << 56) + | (sortable.to(tl.uint64) << 24) + | rs_2d.to(tl.uint64) + ) # [BLOCK_T, BLOCK_S] + # Flatten to 1D for the scatter (1D gather + 1D sort sidesteps + # the AMD-backend bug with 2D gather+sort combos). + chunk_packed_flat = tl.reshape(chunk_packed_2d, (BLOCK_T * BLOCK_S,)) + + # Scatter into top_packed[t*TOP_BUF + K+i*BLOCK_S : ...] for each t. + # For each rb in [0, BIG_BUF): + # t = rb // TOP_BUF + # pos = rb % TOP_BUF + # in_slot = (pos >= K + i*BLOCK_S) & (pos < K + (i+1)*BLOCK_S) + # flat_idx = t * BLOCK_S + (pos - (K + i*BLOCK_S)) chunk_offset = K + i * BLOCK_S - in_chunk_slot = (rs_buf >= chunk_offset) & (rs_buf < chunk_offset + BLOCK_S) - chunk_gather_idx = tl.where(in_chunk_slot, rs_buf - chunk_offset, 0).to(tl.int32) - gathered = tl.gather(chunk_packed, chunk_gather_idx, axis=0) - top_packed = tl.where(in_chunk_slot, gathered, top_packed) + in_slot = (rb_pos >= chunk_offset) & (rb_pos < chunk_offset + BLOCK_S) + j = rb_pos - chunk_offset + flat_idx = tl.where(in_slot, rb_t * BLOCK_S + j, 0).to(tl.int32) + gathered = tl.gather(chunk_packed_flat, flat_idx, axis=0) + top_packed = tl.where(in_slot, gathered, top_packed) - # All INNER chunks placed -> sort once + # 1D sort of the entire buffer. Per-T regions stay grouped via t_enc. top_packed = tl.sort(top_packed, descending=True) - # Extract top K indices: gather positions [0, K) from the sorted buffer, - # take low 32 bits. - rk = tl.arange(0, K) - top_k_packed = tl.gather(top_packed, rk, axis=0) - top_k_idx = (top_k_packed & 0xFFFFFFFF).to(tl.int32) + # Extract per-T top K. After sort desc, t=0's top K is at positions + # [0, K), t=1's at [TOP_BUF, TOP_BUF+K), etc. — i.e. base = t*TOP_BUF. 
+ out_idx = rt_local[:, None] * TOP_BUF + rk[None, :] # [BLOCK_T, K] + out_idx_flat = tl.reshape(out_idx, (BLOCK_T * K,)).to(tl.int32) + top_k_packed_flat = tl.gather(top_packed, out_idx_flat, axis=0) + top_k_packed = tl.reshape(top_k_packed_flat, (BLOCK_T, K)) + # Strip the t_enc and sortable bits, keep low 24 bits (index). + top_k_idx = (top_k_packed & 0xFFFFFF).to(tl.int32) - out_base = b * (oH * T_t * K) + h_outer * (T_t * K) + pid_t_64 * K - tl.store(Topk_idx_ptr + out_base + rk, top_k_idx) + out_base = b * (oH * T_t * K) + h_outer * (T_t * K) + out_ptrs = Topk_idx_ptr + out_base + rt_64[:, None] * K + rk[None, :] + tl.store(out_ptrs, top_k_idx, mask=rt_mask[:, None]) _score_topk_p = extend_core.Primitive("te_indexer_score_topk_triton") @@ -631,23 +733,45 @@ def _score_topk_lowering(ctx, Hq, Hk, W_o, *, k): B, oH, T_t, H, d_i = Hq_aval.shape T_s = Hk_aval.shape[2] S_PAD = _next_pow2(T_s) - # BLOCK_S must be <= K (so chunk fits in TOP_BUF[K:K+BLOCK_S]) and - # divide S_PAD evenly. Cap at 128 so the [BLOCK_S, H] fp32 logits - # intermediate stays in registers. - BLOCK_S = min(128, k, S_PAD) + + # Build a K-filtered autotuner around the plain JIT kernel. We do this at + # lowering time (rather than decorating the kernel at definition) because + # configs with BLOCK_S > k or BLOCK_S that doesn't divide k would compile + # to a kernel where INNER = k // BLOCK_S = 0 — i.e. a no-op that's fastest + # in the autotune timing race. Filtering ensures the runtime picker only + # sees configs that actually do the work. + # + # Also filter BLOCK_T configs that don't evenly divide T_t — we mask the + # tail but unnecessary padding hurts L1/L2 efficiency. 
+ valid_configs = [ + c for c in _SCORE_TOPK_CONFIGS + if c.kwargs["BLOCK_S"] <= k + and k % c.kwargs["BLOCK_S"] == 0 + and T_t % c.kwargs["BLOCK_T"] == 0 + ] + if not valid_configs: + raise ValueError( + f"No valid BLOCK_S/BLOCK_T config for k={k}, T_t={T_t}" + ) + + autotuned_kernel = triton.autotune( + configs=valid_configs, + key=["H", "d_i", "T_s", "K"], + )(_score_topk_kernel) + + def grid_fn(merged_kwargs): + bt = merged_kwargs.get("BLOCK_T", 1) + return (triton.cdiv(T_t, bt), B * oH) return triton_call_lowering( ctx, - _score_topk_kernel, + autotuned_kernel, Hq, Hk, W_o, - grid=(T_t, B * oH), - num_warps=4, - num_stages=2, + grid=grid_fn, constexprs={ "B": B, "oH": oH, "T_t": T_t, "T_s": T_s, "H": H, "d_i": d_i, "K": k, "S_PAD": S_PAD, - "BLOCK_S": BLOCK_S, }, ) From 22b168f51c8110667abc49e0b71f7df6a287e6ee Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 27 May 2026 21:25:25 +0000 Subject: [PATCH 09/17] Added initial API and tests --- tests/jax/test_sparse_attention.py | 326 +++++++++++ .../jax/compressed_attention.py | 128 +++++ transformer_engine/jax/sparse_attention.py | 542 ++++++++++++++++++ .../jax/triton_extensions/__init__.py | 1 + .../jax/triton_extensions/sparse_attention.py | 108 ++++ 5 files changed, 1105 insertions(+) create mode 100644 tests/jax/test_sparse_attention.py create mode 100644 transformer_engine/jax/compressed_attention.py create mode 100644 transformer_engine/jax/sparse_attention.py create mode 100644 transformer_engine/jax/triton_extensions/sparse_attention.py diff --git a/tests/jax/test_sparse_attention.py b/tests/jax/test_sparse_attention.py new file mode 100644 index 000000000..38b65430b --- /dev/null +++ b/tests/jax/test_sparse_attention.py @@ -0,0 +1,326 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. 
+"""Tests for Deep Sparse Attention (DSA) composition + HCA / fused scaffold contracts.""" + +import jax +import jax.numpy as jnp +import pytest +from flax import linen as nn + +from transformer_engine.jax.sparse_attention import ( + DeepSparseAttention, + deep_sparse_attention_core, + _causal_keep_mask, + _topk_indices_to_attn_mask, + _ref_dsa_jax, +) +from transformer_engine.jax.compressed_attention import ( + HeavilyCompressedAttention, + heavily_compressed_attention, +) +from transformer_engine.jax.triton_extensions import fused_sparse_attention_triton + + +@pytest.fixture(autouse=True) +def _force_unfused_attn(monkeypatch): + """Override conftest's enable_fused_attn_after_hopper for this module. + + The DSA composition path uses an arbitrary topk-derived attention mask. The + fused-attention backends on some platforms restrict mask semantics (padding- + style only). Force the unfused softmax path so reference comparisons hold. + Production callers can still set NVTE_FUSED_ATTN=1 — these tests are only + asserting the composition math, not the fused-path's mask handling. 
+ """ + monkeypatch.setenv("NVTE_FUSED_ATTN", "0") + yield + + +# ----------------------------------------------------------------------------- +# Helpers +# ----------------------------------------------------------------------------- + + +def _make_dsa_module(*, oH=4, D=8, iH=2, idc=16, idi=16, k=4, + backend="composition", indexer_backend="hybrid"): + return DeepSparseAttention( + head_dim=D, + num_attention_heads=oH, + indexer_num_heads=iH, + indexer_d_c=idc, + indexer_d_i=idi, + topk=k, + backend=backend, + indexer_backend=indexer_backend, + dtype=jnp.bfloat16, + ) + + +def _make_inputs(B=1, oH=4, T=16, hidden=32, dtype=jnp.bfloat16, seed=0): + """Rank-4 inputs [B, oH, T, hidden].""" + return jax.random.normal(jax.random.PRNGKey(seed), (B, oH, T, hidden), dtype=dtype) + + +# ----------------------------------------------------------------------------- +# Mask helpers +# ----------------------------------------------------------------------------- + + +def test_causal_keep_mask_self_attention(): + """T_t == T_s: standard lower-triangular keep mask.""" + m = _causal_keep_mask(4, 4) + expected = jnp.tril(jnp.ones((4, 4), dtype=jnp.bool_)) + assert jnp.array_equal(m, expected) + + +def test_causal_keep_mask_cross_attention_with_prefix(): + """T_t < T_s: causal cutoff aligned to bottom-right (prefix context allowed).""" + m = _causal_keep_mask(2, 5) # T_t=2, T_s=5 → prefix of 3 always visible + expected = jnp.array( + [[True, True, True, True, False], + [True, True, True, True, True]], + dtype=jnp.bool_, + ) + assert jnp.array_equal(m, expected) + + +def test_topk_indices_to_attn_mask_basic(): + # B=1, oH=1, T_t=2, k=2 + indices = jnp.array([[[[0, 2], [1, 3]]]], dtype=jnp.int32) # [1, 1, 2, 2] + mask_out = _topk_indices_to_attn_mask(indices, T_s=4, causal=False) + expected = jnp.array( + [[[[0, 1, 0, 1], + [1, 0, 1, 0]]]], + dtype=jnp.uint8, + ) + assert mask_out.shape == (1, 1, 2, 4) + assert mask_out.dtype == jnp.uint8 + assert jnp.array_equal(mask_out, expected) 
+ + +def test_topk_indices_to_attn_mask_per_head_diverges(): + """Different oH heads pick different topk → different per-head masks.""" + # B=1, oH=2, T_t=1, k=2 + indices = jnp.array([[[[0, 1]], [[2, 3]]]], dtype=jnp.int32) + mask_out = _topk_indices_to_attn_mask(indices, T_s=4, causal=False) + # Head 0 keeps {0,1} → mask [0,0,1,1]; head 1 keeps {2,3} → mask [1,1,0,0]. + expected = jnp.array( + [[[[0, 0, 1, 1]], + [[1, 1, 0, 0]]]], + dtype=jnp.uint8, + ) + assert mask_out.shape == (1, 2, 1, 4) + assert jnp.array_equal(mask_out, expected) + + +def test_topk_indices_to_attn_mask_causal_intersect(): + """Causal AND topk in self-attention: query t cannot keep positions > t.""" + # B=1, oH=1, T_t=T_s=4, k=2 + indices = jnp.array([[[[2, 3], [0, 1], [1, 2], [2, 3]]]], dtype=jnp.int32) + mask_out = _topk_indices_to_attn_mask(indices, T_s=4, causal=True) + # q=0: picks {2,3}, causal {0} → intersect {} → all-1 row. + assert bool((mask_out[0, 0, 0, :] == 1).all()) + # q=2: picks {1,2}, causal {0,1,2} → intersect {1,2}. 
+ assert mask_out[0, 0, 2, 0] == 1 + assert mask_out[0, 0, 2, 1] == 0 + assert mask_out[0, 0, 2, 2] == 0 + assert mask_out[0, 0, 2, 3] == 1 + + +# ----------------------------------------------------------------------------- +# DSA composition correctness +# ----------------------------------------------------------------------------- + + +@pytest.mark.parametrize("B,oH,T,hidden,D,iH,idc,idi,k", [ + (1, 4, 16, 32, 8, 2, 16, 16, 4), + (2, 4, 32, 64, 16, 2, 32, 32, 8), + (1, 2, 8, 16, 8, 1, 8, 8, 2), +]) +def test_dsa_composition_vs_pure_jax_reference(B, oH, T, hidden, D, iH, idc, idi, k): + """DSA module output (composition + hybrid indexer) matches pure-JAX reference.""" + inputs = _make_inputs(B=B, oH=oH, T=T, hidden=hidden) + keys = jax.random.split(jax.random.PRNGKey(123), 2) + module = _make_dsa_module(oH=oH, D=D, iH=iH, idc=idc, idi=idi, k=k) + params = module.init(keys[0], inputs, inputs, deterministic=True) + out = module.apply(params, inputs, inputs, deterministic=True) + assert out.shape == (B, oH, T, D) + + p = nn.meta.unbox(params)["params"] + out_ref = _ref_dsa_jax( + inputs, inputs, + p["query"]["kernel"], p["key"]["kernel"], p["value"]["kernel"], + p["indexer_W_uq"], p["indexer_W_dq"], p["indexer_W_k"], p["indexer_W_w"], + head_dim=D, k=k, causal=True, + ) + + diff = (out.astype(jnp.float32) - out_ref.astype(jnp.float32)) + rel = float( + jnp.linalg.norm(diff) + / (jnp.linalg.norm(out_ref.astype(jnp.float32)) + 1e-30) + ) + assert rel < 5e-2, f"DSA output diverges from reference: rel.err={rel:.3e}" + + +def test_dsa_composition_reference_indexer_matches_hybrid(): + """Same correctness check using indexer_backend='reference' (pure einsum).""" + B, oH, T, hidden, D, iH, idc, idi, k = 1, 2, 8, 16, 8, 1, 8, 8, 2 + inputs = _make_inputs(B=B, oH=oH, T=T, hidden=hidden) + keys = jax.random.split(jax.random.PRNGKey(7), 2) + module = _make_dsa_module(oH=oH, D=D, iH=iH, idc=idc, idi=idi, k=k, + indexer_backend="reference") + params = module.init(keys[0], 
inputs, inputs, deterministic=True) + out = module.apply(params, inputs, inputs, deterministic=True) + assert out.shape == (B, oH, T, D) + + +@pytest.mark.parametrize("T_t,T_s,k", [(8, 8, 4), (8, 8, 2), (16, 16, 8)]) +def test_dsa_topk_count_equals_kept_count_under_causal(T_t, T_s, k): + """For each query t, the number of unmasked key positions equals min(k, t+1).""" + B, oH, hidden = 1, 2, 16 + inputs = _make_inputs(B=B, oH=oH, T=T_t, hidden=hidden, seed=7) + keys = jax.random.split(jax.random.PRNGKey(7), 2) + module = _make_dsa_module(oH=oH, D=8, iH=1, idc=8, idi=8, k=k) + params = module.init(keys[0], inputs, inputs, deterministic=True) + + from transformer_engine.jax.indexer import indexer as _idx + p = nn.meta.unbox(params)["params"] + scores = _idx( + inputs, inputs, + p["indexer_W_uq"], p["indexer_W_dq"], p["indexer_W_k"], p["indexer_W_w"], + backend="reference", out_dtype=jnp.float32, + ) # [B, oH, T_t, T_s] + ckeep = _causal_keep_mask(T_t, T_s)[None, None, :, :] + scores_masked = jnp.where(ckeep, scores, -jnp.inf) + _, topk_idx = jax.lax.top_k(scores_masked, min(k, T_s)) + mask_out = _topk_indices_to_attn_mask(topk_idx, T_s, causal=True) + # Each (b, h, t) row should have exactly min(k, t+1) zeros. 
+ for h in range(oH): + kept_per_q = (mask_out[0, h] == 0).sum(axis=-1) # [T_t] + for t in range(T_t): + expected = min(k, t + 1) + assert int(kept_per_q[t]) == expected, ( + f"oH={h}, t={t}: kept {int(kept_per_q[t])} keys, expected {expected}" + ) + + +# ----------------------------------------------------------------------------- +# Backward shape sanity +# ----------------------------------------------------------------------------- + + +def test_dsa_backward_runs_without_shape_errors(): + inputs = _make_inputs(B=1, oH=2, T=8, hidden=16) + keys = jax.random.split(jax.random.PRNGKey(5), 2) + module = _make_dsa_module(oH=2, D=8, iH=1, idc=8, idi=8, k=2) + params = module.init(keys[0], inputs, inputs, deterministic=True) + + def loss(p, x): + out = module.apply(p, x, x, deterministic=True) + return jnp.sum(out.astype(jnp.float32)) + + grads = jax.grad(loss)(params, inputs) + leaves = jax.tree_util.tree_leaves(grads) + assert all(bool(jnp.isfinite(leaf).all()) for leaf in leaves), \ + "DSA backward produced NaN/Inf gradients" + + +# ----------------------------------------------------------------------------- +# Scaffold contracts +# ----------------------------------------------------------------------------- + + +def test_dsa_fused_backend_raises_not_implemented(): + inputs = _make_inputs(B=1, oH=2, T=8, hidden=16) + keys = jax.random.split(jax.random.PRNGKey(0), 2) + module = _make_dsa_module(oH=2, D=8, iH=1, idc=8, idi=8, k=2, backend="fused") + # Flax materializes the call during init, so NotImplementedError fires there. 
+ with pytest.raises(NotImplementedError, match="phase-2 scaffold"): + module.init(keys[0], inputs, inputs, deterministic=True) + + +def test_fused_sparse_attention_triton_direct_raises(): + """Calling the primitive directly also raises (locked contract).""" + q = jnp.zeros((1, 2, 4, 8), dtype=jnp.bfloat16) # [B, T, H, D] + kk = jnp.zeros((1, 2, 4, 8), dtype=jnp.bfloat16) + v = jnp.zeros((1, 2, 4, 8), dtype=jnp.bfloat16) + iq = jnp.zeros((1, 2, 4, 8), dtype=jnp.bfloat16) + ik = jnp.zeros((1, 2, 8), dtype=jnp.bfloat16) + iw = jnp.zeros((1, 2, 2), dtype=jnp.bfloat16) + with pytest.raises(NotImplementedError, match="phase-2 scaffold"): + jax.jit( + lambda *args: fused_sparse_attention_triton(*args, k=2) + )(q, kk, v, iq, ik, iw) + + +def test_hca_module_raises_not_implemented(): + module = HeavilyCompressedAttention( + head_dim=8, num_attention_heads=4, + q_lora_rank=16, kv_lora_rank=16, + qk_nope_head_dim=4, qk_rope_head_dim=4, v_head_dim=8, + ) + inputs = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 32), dtype=jnp.bfloat16) + keys = jax.random.split(jax.random.PRNGKey(0), 2) + with pytest.raises(NotImplementedError, match="design.*deferred|DESIGN DEFERRED|scaffold"): + module.init(keys[0], inputs, inputs, deterministic=True) + + +def test_hca_functional_raises_not_implemented(): + inputs = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 32), dtype=jnp.bfloat16) + with pytest.raises(NotImplementedError): + heavily_compressed_attention( + inputs, inputs, + head_dim=8, num_attention_heads=4, + q_lora_rank=16, kv_lora_rank=16, + qk_nope_head_dim=4, qk_rope_head_dim=4, v_head_dim=8, + ) + + +# ----------------------------------------------------------------------------- +# Functional API surface +# ----------------------------------------------------------------------------- + + +def test_deep_sparse_attention_core_invalid_backend_raises(): + q = jnp.zeros((1, 2, 4, 8)) # rank-4 + iq = jnp.zeros((1, 2, 4, 16)) + W = jnp.zeros((16, 8)) + Wuq = jnp.zeros((1, 8, 8)) + 
with pytest.raises(ValueError, match="unknown backend"): + deep_sparse_attention_core( + q, q, q, iq, iq, Wuq, W, W[:, :8], W[:, :1], + k=2, backend="bogus", + ) + + +def test_deep_sparse_attention_core_unsupported_mask_type_raises(): + q = jnp.zeros((1, 2, 4, 8)) + iq = jnp.zeros((1, 2, 4, 16)) + W = jnp.zeros((16, 8)) + Wuq = jnp.zeros((1, 8, 8)) + with pytest.raises(NotImplementedError, match="attn_mask_type"): + deep_sparse_attention_core( + q, q, q, iq, iq, Wuq, W, W[:, :8], W[:, :1], + k=2, attn_mask_type="padding", + ) + + +def test_deep_sparse_attention_core_rejects_rank3_inputs(): + """Rank-3 inputs (missing oH) should be rejected with a clear error.""" + q3 = jnp.zeros((1, 4, 8)) # rank-3 + iq4 = jnp.zeros((1, 2, 4, 16)) + W = jnp.zeros((16, 8)) + Wuq = jnp.zeros((1, 8, 8)) + with pytest.raises(ValueError, match="rank-4"): + deep_sparse_attention_core( + q3, q3, q3, iq4, iq4, Wuq, W, W[:, :8], W[:, :1], + k=2, + ) + + +def test_dsa_module_rejects_oh_mismatch(): + """Module asserts num_attention_heads matches inputs.shape[1].""" + inputs = _make_inputs(B=1, oH=3, T=8, hidden=16) # oH=3 in input + module = _make_dsa_module(oH=4, D=8, iH=1, idc=8, idi=8, k=2) # oH=4 in module + with pytest.raises(ValueError, match="must equal num_attention_heads"): + module.init(jax.random.PRNGKey(0), inputs, inputs, deterministic=True) diff --git a/transformer_engine/jax/compressed_attention.py b/transformer_engine/jax/compressed_attention.py new file mode 100644 index 000000000..16de17457 --- /dev/null +++ b/transformer_engine/jax/compressed_attention.py @@ -0,0 +1,128 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +"""Heavily Compressed Attention (HCA) — design-deferred scaffold. + +This module stakes out the API surface for a future MLA-style (DeepSeek-V2/V3 +Multi-head Latent Attention) implementation. 
The Flax module and functional +entry point both raise :class:`NotImplementedError` so downstream code can +write against the eventual signature today while the design is finalized. + +The intended math (for context, not yet implemented):: + + C_q = LayerNorm(X) @ W_dq # (..., T, q_lora_rank) + Q = C_q @ W_uq # (..., T, H, qk_nope_head_dim + qk_rope_head_dim) + C_kv = X @ W_dkv # (..., S, kv_lora_rank) <-- KV cache stores this + K = C_kv @ W_uk # (..., S, H, qk_nope_head_dim + qk_rope_head_dim) + V = C_kv @ W_uv # (..., S, H, v_head_dim) + K_rope, K_nope = split(K, ...) + apply RoPE to (Q_rope, K_rope) + O = softmax(Q @ K^T / sqrt(d)) @ V + +See ``transformer_engine.jax.sparse_attention`` for the sibling DSA module +that is implemented today. +""" + +from typing import Optional + +from flax import linen as nn + +from . import indexer as _indexer # noqa: F401 — surface to assert package layout + + +_HCA_DEFER_MESSAGE = ( + "HeavilyCompressedAttention is a phase-1 scaffold (design deferred).\n" + "Open design questions to resolve before implementing:\n" + " 1. RoPE applied on compressed (C_q/C_kv) or decompressed (Q/K) tensors?\n" + " - DeepSeek-V2 applies RoPE on a separate sub-head; we should match.\n" + " 2. KV cache layout: latent-only (memory-optimal) vs latent+RoPE-sub-head?\n" + " 3. Backward through decompression: recompute (memory) vs store (bandwidth)?\n" + " 4. Should this share projection plumbing with MultiHeadAttention's " + "LayerNormDenseGeneral, or use bespoke low-rank projections?\n" + " 5. Interaction with TE's existing fused-attn backends — does any of " + "CK/AITER/cuDNN support split (RoPE/no-RoPE) head dims natively?\n" + "Pin these before filling in. See " + "transformer_engine.jax.sparse_attention for the working DSA module." +) + + +class HeavilyCompressedAttention(nn.Module): # pylint: disable=too-few-public-methods + """MLA-style heavily compressed attention — **DESIGN DEFERRED**. 
+ + Parameters + ---------- + head_dim : int + Per-head dimension of the dense (decompressed) attention. + num_attention_heads : int + Number of attention heads. + q_lora_rank : int + Rank of the query low-rank compression (``d_c`` in indexer notation). + kv_lora_rank : int + Rank of the key/value low-rank compression. The KV cache stores + only this latent (``kv_lora_rank``-dimensional) representation. + qk_nope_head_dim : int + Per-head dimension for the non-RoPE component of Q/K. + qk_rope_head_dim : int + Per-head dimension for the RoPE component of Q/K. Total Q/K head + dim is ``qk_nope_head_dim + qk_rope_head_dim``. + v_head_dim : int + Per-head dimension of V (may differ from Q/K head dim). + attn_mask_type : str, default = ``"causal"`` + Mask type. Plumbed to the eventual dense attention call. + attention_dropout : float, default = ``0.0`` + qkv_layout : str, default = ``"bshd_bshd_bshd"`` + scale_factor : Optional[float], default = ``None`` + Defaults to ``1/sqrt(qk_nope_head_dim + qk_rope_head_dim)`` when implemented. + """ + + head_dim: int + num_attention_heads: int + q_lora_rank: int + kv_lora_rank: int + qk_nope_head_dim: int + qk_rope_head_dim: int + v_head_dim: int + attn_mask_type: str = "causal" + attention_dropout: float = 0.0 + qkv_layout: str = "bshd_bshd_bshd" + scale_factor: Optional[float] = None + + @nn.compact + def __call__(self, inputs_q, inputs_kv, *, deterministic: bool = False): # noqa: D401 + del inputs_q, inputs_kv, deterministic + raise NotImplementedError(_HCA_DEFER_MESSAGE) + + +def heavily_compressed_attention( + inputs_q, + inputs_kv, + *, + head_dim: int, + num_attention_heads: int, + q_lora_rank: int, + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + attn_mask_type: str = "causal", + scale_factor: Optional[float] = None, +): + """Functional HCA — **DESIGN DEFERRED** (raises NotImplementedError). 
+ + Mirrors the planned :class:`HeavilyCompressedAttention` surface as a + stateless function for callers that prefer functional composition. + """ + del ( + inputs_q, + inputs_kv, + head_dim, + num_attention_heads, + q_lora_rank, + kv_lora_rank, + qk_nope_head_dim, + qk_rope_head_dim, + v_head_dim, + attn_mask_type, + scale_factor, + ) + raise NotImplementedError(_HCA_DEFER_MESSAGE) diff --git a/transformer_engine/jax/sparse_attention.py b/transformer_engine/jax/sparse_attention.py new file mode 100644 index 000000000..7c3337556 --- /dev/null +++ b/transformer_engine/jax/sparse_attention.py @@ -0,0 +1,542 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +"""Deep Sparse Attention (DSA) — composes the lightning indexer with dense attention. + +Phase 1 (this file, working) composes the existing pieces: + + 1. Per-attention-head Q/K/V projection (DenseGeneral) + 2. Lightning-indexer scoring via the hybrid Triton backend + 3. Causal mask + jax.lax.top_k on each per-head score row + 4. Scatter top-k indices into a per-head sparse attention mask + 5. Call transformer_engine.jax.flax.DotProductAttention with that mask + +Phase 2 will dispatch the entire stack to a single fused Triton kernel +``transformer_engine.jax.triton_extensions.fused_sparse_attention_triton``; +the dispatch site lives in :func:`deep_sparse_attention_core` under +``backend="fused"`` and currently raises NotImplementedError (the scaffold +holds the signature stable so the kernel can land without API churn). + +**Shape contract — all DSA tensors are rank-4 with the outer-head dim +explicit:** + + inputs_q : [B, oH, T_t, hidden] + inputs_kv : [B, oH, T_s, hidden] + output : [B, oH, T_t, head_dim] + +``oH ≡ num_attention_heads``. Each attention head has its own indexer +score row, its own top-k pattern, and its own attention output. 
The +indexer projection *weights* are shared across attention heads — the +per-head divergence comes from the per-head input slice (the caller is +expected to have already produced per-head hidden states upstream). + +This shape contract aligns with the lightning-indexer benchmark's +``[B, oH, T, d]`` convention (see ``benchmarks/profile_indexer_topk.py``) +and lets us call the Triton hybrid backend directly without rank +adjustment. + +Zero modifications are made to upstream-tracked TE files; DSA composes +:class:`DotProductAttention` from the outside via its public ``mask=`` +argument. +""" + +from typing import Literal, Optional + +import jax +import jax.numpy as jnp +from flax import linen as nn + +from transformer_engine.jax.flax.module import DenseGeneral +from transformer_engine.jax.flax.transformer import DotProductAttention +from transformer_engine.jax.indexer import indexer as _indexer_fn + + +# Backends supported by deep_sparse_attention_core. +_BACKENDS = ("composition", "fused") + + +# ----------------------------------------------------------------------------- +# Mask construction helpers +# ----------------------------------------------------------------------------- + + +def _causal_keep_mask(T_t: int, T_s: int, dtype=jnp.bool_): + """Lower-triangular keep mask aligned to the bottom-right corner. + + For self-attention (T_t == T_s) this is the standard ``jnp.tril(ones)``. + For cross-attention with T_t < T_s, query position ``t`` attends to key + positions ``[0, T_s - T_t + t]``. This matches the convention used by + causal cross-attention with prefix context. + """ + q_pos = jnp.arange(T_t)[:, None] # [T_t, 1] + k_pos = jnp.arange(T_s)[None, :] # [1, T_s] + keep = k_pos <= (q_pos + (T_s - T_t)) # [T_t, T_s] + return keep.astype(dtype) + + +def _topk_indices_to_attn_mask( + indices: jax.Array, + T_s: int, + *, + causal: bool, +) -> jax.Array: + """Convert per-(B, oH, T_t) top-k indices into a DPA-style mask. 
+ + Args: + indices: ``[B, oH, T_t, k]`` int32 — top-k key positions per (B, oH, T_t). + T_s: number of key positions. + causal: if True, AND the keep-mask with a causal keep-mask before + inverting. + + Returns: + ``[B, oH, T_t, T_s]`` uint8 — ``1`` means *mask out*. The caller + reshapes to ``[B*oH, 1, T_t, T_s]`` for DPA dispatch. + """ + B, oH, T_t, _k = indices.shape + + # Scatter True at every (b, h, t, indices[b, h, t, :]) position. + keep = jnp.zeros((B, oH, T_t, T_s), dtype=jnp.bool_) + b_idx = jnp.arange(B)[:, None, None, None] # [B, 1, 1, 1] + h_idx = jnp.arange(oH)[None, :, None, None] # [1, oH, 1, 1] + t_idx = jnp.arange(T_t)[None, None, :, None] # [1, 1, T_t, 1] + # Duplicates from .at[].set(True) are idempotent — safe when k > finite scores. + keep = keep.at[b_idx, h_idx, t_idx, indices].set(True) # [B, oH, T_t, T_s] + + if causal: + keep = keep & _causal_keep_mask(T_t, T_s)[None, None, :, :] + + mask_out = jnp.logical_not(keep) + # TE's ScaledMaskedSoftmax expects uint8 mask (cpp_extensions/softmax.py:483). + return mask_out.astype(jnp.uint8) # [B, oH, T_t, T_s] + + +# ----------------------------------------------------------------------------- +# Functional API +# ----------------------------------------------------------------------------- + + +def deep_sparse_attention_core( + query: jax.Array, + key: jax.Array, + value: jax.Array, + indexer_inputs_q: jax.Array, + indexer_inputs_kv: jax.Array, + indexer_W_uq: jax.Array, + indexer_W_dq: jax.Array, + indexer_W_k: jax.Array, + indexer_W_w: jax.Array, + *, + k: int, + attn_mask_type: str = "causal", + scale_factor: Optional[float] = None, + attention_dropout: float = 0.0, + deterministic: bool = True, + backend: Literal["composition", "fused"] = "composition", + dropout_rng_name: str = "dropout", + indexer_backend: str = "hybrid", +) -> jax.Array: + """Functional DSA: indexer-top-k + per-head sparse attention. 
+ + Args: + query, key, value: ``[B, oH, T, head_dim]`` — post-projection per-head + attention tensors. ``oH ≡ num_attention_heads``; each outer-head + slice owns a single attention head of dimension ``head_dim``. + indexer_inputs_q: ``[B, oH, T_t, hidden]`` — per-head hidden states + fed to the indexer's query side. + indexer_inputs_kv: ``[B, oH, T_s, hidden]`` — per-head hidden states + fed to the indexer's key side. + indexer_W_uq: ``[H_idx, d_c, d_i]`` indexer up-projection (shared). + indexer_W_dq: ``[hidden, d_c]`` indexer down-projection (shared). + indexer_W_k: ``[hidden, d_i]`` indexer key projection (shared). + indexer_W_w: ``[hidden, H_idx]`` indexer output-weight projection (shared). + k: number of top key positions to retain per (B, oH, T_t). + attn_mask_type: ``"causal"`` or ``"no_mask"`` (phase 1 only). + scale_factor: passed through to DPA. ``None`` → ``1/sqrt(head_dim)``. + attention_dropout, deterministic, dropout_rng_name: passed through to DPA. + backend: ``"composition"`` (working) or ``"fused"`` (phase-2 scaffold). + indexer_backend: which indexer implementation to use when + ``backend == "composition"``. ``"hybrid"`` (default, fast Triton) or + ``"reference"`` (pure einsum). + + Returns: + Attention output of the same shape as ``query``: ``[B, oH, T_t, head_dim]``. + """ + if backend not in _BACKENDS: + raise ValueError(f"unknown backend {backend!r}; expected one of {_BACKENDS}") + + if attn_mask_type not in ("causal", "no_mask"): + raise NotImplementedError( + f"deep_sparse_attention_core: attn_mask_type={attn_mask_type!r} " + "not supported in phase 1. Supported: 'causal', 'no_mask'. 
" + "(Padding / segment-id mask types are tracked as a follow-up.)" + ) + + if query.ndim != 4 or key.ndim != 4 or value.ndim != 4: + raise ValueError( + f"DSA expects rank-4 query/key/value [B, oH, T, head_dim]; got " + f"shapes query={query.shape} key={key.shape} value={value.shape}" + ) + if indexer_inputs_q.ndim != 4 or indexer_inputs_kv.ndim != 4: + raise ValueError( + f"DSA expects rank-4 indexer inputs [B, oH, T, hidden]; got " + f"shapes indexer_inputs_q={indexer_inputs_q.shape} " + f"indexer_inputs_kv={indexer_inputs_kv.shape}" + ) + + if backend == "fused": + from transformer_engine.jax.triton_extensions import ( + fused_sparse_attention_triton, + ) + # Project the indexer side so the fused primitive sees Hq/Hk/W_o tensors. + # (Scaffold lowering raises; the projections are computed for shape only.) + C_q = jnp.einsum("...td,dc->...tc", indexer_inputs_q, indexer_W_dq) + Hq = jnp.einsum("...tc,hci->...thi", C_q, indexer_W_uq) + Hk = jnp.einsum("...sd,di->...si", indexer_inputs_kv, indexer_W_k) + W_o = jnp.einsum("...td,dh->...th", indexer_inputs_q, indexer_W_w) + return fused_sparse_attention_triton( + query, key, value, Hq, Hk, W_o, k=k, + ) + + # ---- composition backend ---- + B, oH, T_t, head_dim = query.shape + T_s = key.shape[2] + if key.shape != (B, oH, T_s, head_dim) or value.shape != (B, oH, T_s, head_dim): + raise ValueError( + f"DSA shape mismatch: query={query.shape} key={key.shape} value={value.shape}" + ) + + # 1. Indexer produces a per-head score row [B, oH, T_t, T_s]. + scores = _indexer_fn( + indexer_inputs_q, + indexer_inputs_kv, + indexer_W_uq, + indexer_W_dq, + indexer_W_k, + indexer_W_w, + backend=indexer_backend, + out_dtype=jnp.float32, + ) # [B, oH, T_t, T_s] fp32 + + # 2. Causal mask BEFORE top-k so non-causal positions are excluded. 
+ causal = (attn_mask_type == "causal") + if causal: + ckeep = _causal_keep_mask(T_t, T_s)[None, None, :, :] # [1, 1, T_t, T_s] + scores = jnp.where(ckeep, scores, jnp.asarray(-jnp.inf, dtype=scores.dtype)) + + # 3. Per-(B, oH, T_t) top-k. + k_eff = min(k, T_s) + _, topk_idx = jax.lax.top_k(scores, k_eff) # [B, oH, T_t, k_eff] + + # 4. Scatter into [B, oH, T_t, T_s] uint8 DPA mask (1 = mask out). + sparse_mask = _topk_indices_to_attn_mask( + topk_idx, T_s, causal=causal, + ) # [B, oH, T_t, T_s] uint8 + + # 5. Dense attention with the sparse mask. We collapse (B, oH) into the + # batch dim of DPA so each attention head gets its own mask. attn_mask_type + # 'padding' tells DPA to honor the provided mask as-is (causal is baked in). + BH = B * oH + q_r = query.reshape(BH, T_t, 1, head_dim) # [BH, T_t, 1, D] + k_r = key.reshape(BH, T_s, 1, head_dim) + v_r = value.reshape(BH, T_s, 1, head_dim) + mask_r = sparse_mask.reshape(BH, 1, T_t, T_s) # [BH, 1, T_t, T_s] + + dpa = DotProductAttention( + head_dim=head_dim, + num_attention_heads=1, + num_gqa_groups=1, # one head per oH slice; must be int (probe rejects None) + attention_dropout=attention_dropout, + attn_mask_type="padding", + qkv_layout="bshd_bshd_bshd", + scale_factor=scale_factor, + dropout_rng_name=dropout_rng_name, + ) + out = dpa( + q_r, k_r, v_r, + sequence_descriptor=mask_r, + deterministic=deterministic, + ) # [BH, T_t, head_dim] (flattened H=1) + # DPA flattens the H=1 axis on output. Reshape back to [B, oH, T_t, head_dim]. + return out.reshape(B, oH, T_t, head_dim) + + +# ----------------------------------------------------------------------------- +# Flax module +# ----------------------------------------------------------------------------- + + +class DeepSparseAttention(nn.Module): # pylint: disable=too-few-public-methods + """Deep Sparse Attention (DSA) Flax module — rank-4, per-attention-head. + + Composes the lightning indexer with TE's :class:`DotProductAttention`. 
+ Each attention head (``oH``) has its own indexer score row, top-k + pattern, and dense-attention output. Indexer projection weights are + shared across heads. + + Parameters + ---------- + head_dim : int + Per-attention-head dimension. + num_attention_heads : int + Number of attention heads (``oH``). + indexer_num_heads : int + Number of indexer-internal heads (``H`` in the indexer notation). + indexer_d_c : int + Indexer down-projection rank (``d_c``). + indexer_d_i : int + Indexer inner head dimension (``d_i``). + topk : int + Number of top key positions to retain per query. + attn_mask_type : str, default ``"causal"`` + ``"causal"`` or ``"no_mask"`` (phase 1). + attention_dropout : float, default ``0.0`` + scale_factor : Optional[float] + Defaults to ``1/sqrt(head_dim)`` inside DPA. + backend : str, default ``"composition"`` + ``"composition"`` (working) or ``"fused"`` (phase-2 scaffold). + indexer_backend : str, default ``"hybrid"`` + ``"hybrid"`` (fast Triton) or ``"reference"`` (pure einsum). Only used + when ``backend == "composition"``. + dtype : Optional[jnp.dtype] + Parameter dtype. Defaults to the input dtype. + """ + + head_dim: int + num_attention_heads: int + indexer_num_heads: int + indexer_d_c: int + indexer_d_i: int + topk: int + attn_mask_type: str = "causal" + attention_dropout: float = 0.0 + scale_factor: Optional[float] = None + backend: str = "composition" + indexer_backend: str = "hybrid" + dtype: Optional[jnp.dtype] = None + + @nn.compact + def __call__( + self, + inputs_q: jax.Array, + inputs_kv: jax.Array, + *, + deterministic: bool = True, + ) -> jax.Array: + """Run DSA on rank-4 per-head inputs. + + Args: + inputs_q: ``[B, oH, T_t, hidden]`` — per-head query-side hidden state. + inputs_kv: ``[B, oH, T_s, hidden]`` — per-head key-side hidden state. + deterministic: forwarded to DPA. + + Returns: + ``[B, oH, T_t, head_dim]`` — per-head attention output. 
+ """ + if inputs_q.ndim != 4 or inputs_kv.ndim != 4: + raise ValueError( + f"DeepSparseAttention expects rank-4 inputs [B, oH, T, hidden]; " + f"got inputs_q.shape={inputs_q.shape}, inputs_kv.shape={inputs_kv.shape}" + ) + B, oH, T_t, hidden = inputs_q.shape + if oH != self.num_attention_heads: + raise ValueError( + f"DeepSparseAttention: inputs_q.shape[1]={oH} must equal " + f"num_attention_heads={self.num_attention_heads}" + ) + if inputs_kv.shape[0] != B or inputs_kv.shape[1] != oH or inputs_kv.shape[3] != hidden: + raise ValueError( + f"DeepSparseAttention: inputs_kv.shape={inputs_kv.shape} must match " + f"(B={B}, oH={oH}, T_s, hidden={hidden})" + ) + + param_dtype = self.dtype if self.dtype is not None else inputs_q.dtype + + # ---- per-head Q/K/V projections ---- + # DenseGeneral with features=head_dim and axis=-1 maps [..., hidden] → + # [..., head_dim], preserving the (B, oH, T) leading dims. Each attention + # head (oH slice) shares the projection kernel — divergence comes from the + # per-head input slice the caller provides. + query = DenseGeneral( + features=self.head_dim, + use_bias=False, + dtype=param_dtype, + name="query", + )(inputs_q) # [B, oH, T_t, head_dim] + key = DenseGeneral( + features=self.head_dim, + use_bias=False, + dtype=param_dtype, + name="key", + )(inputs_kv) # [B, oH, T_s, head_dim] + value = DenseGeneral( + features=self.head_dim, + use_bias=False, + dtype=param_dtype, + name="value", + )(inputs_kv) # [B, oH, T_s, head_dim] + + # ---- indexer projections (shared across oH) ---- + # Shapes mirror transformer_engine.jax.indexer:31-48. 
+ W_dq = self.param( + "indexer_W_dq", + nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal"), + (hidden, self.indexer_d_c), + param_dtype, + ) + W_uq = self.param( + "indexer_W_uq", + nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal"), + (self.indexer_num_heads, self.indexer_d_c, self.indexer_d_i), + param_dtype, + ) + W_k_idx = self.param( + "indexer_W_k", + nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal"), + (hidden, self.indexer_d_i), + param_dtype, + ) + W_w = self.param( + "indexer_W_w", + nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal"), + (hidden, self.indexer_num_heads), + param_dtype, + ) + + return deep_sparse_attention_core( + query, key, value, + inputs_q, inputs_kv, + W_uq, W_dq, W_k_idx, W_w, + k=self.topk, + attn_mask_type=self.attn_mask_type, + scale_factor=self.scale_factor, + attention_dropout=self.attention_dropout, + deterministic=deterministic, + backend=self.backend, + indexer_backend=self.indexer_backend, + ) # [B, oH, T_t, head_dim] + + +# ----------------------------------------------------------------------------- +# Reference / smoke test +# ----------------------------------------------------------------------------- + + +def _ref_dense_softmax_per_head( + query, key, value, mask_out, scale, +): + """Reference: per-head dense softmax attention with arbitrary mask (no DPA). 
+ + Args: + query, key, value: ``[B, oH, T, head_dim]`` + mask_out: ``[B, oH, T_t, T_s]`` uint8 (1 = mask out) + scale: scalar + Returns: ``[B, oH, T_t, head_dim]`` + """ + logits = jnp.einsum("bhtd,bhsd->bhts", query, key) * scale # [B, oH, T_t, T_s] + logits = logits.astype(jnp.float32) + logits = jnp.where( + mask_out.astype(jnp.bool_), + jnp.asarray(-jnp.inf, jnp.float32), + logits, + ) + weights = jax.nn.softmax(logits, axis=-1) + out = jnp.einsum("bhts,bhsd->bhtd", weights.astype(value.dtype), value) + return out + + +def _ref_dsa_jax( + inputs_q, inputs_kv, + W_q_kernel, W_k_kernel, W_v_kernel, + W_uq, W_dq, W_k_idx, W_w, + *, + head_dim, k, causal, +): + """Pure-JAX reference matching ``deep_sparse_attention_core``.""" + # inputs_q: [B, oH, T_t, hidden]; inputs_kv: [B, oH, T_s, hidden] + B, oH, T_t, hidden = inputs_q.shape + T_s = inputs_kv.shape[2] + + # Per-head Q/K/V projections (kernel shape [hidden, head_dim], shared across oH). + q = jnp.einsum("bhtd,dk->bhtk", inputs_q, W_q_kernel) # [B, oH, T_t, D] + kk = jnp.einsum("bhsd,dk->bhsk", inputs_kv, W_k_kernel) + v = jnp.einsum("bhsd,dk->bhsk", inputs_kv, W_v_kernel) + + scores = _indexer_fn( + inputs_q, inputs_kv, W_uq, W_dq, W_k_idx, W_w, + backend="reference", out_dtype=jnp.float32, + ) # [B, oH, T_t, T_s] + if causal: + ckeep = _causal_keep_mask(T_t, T_s)[None, None, :, :] + scores = jnp.where(ckeep, scores, jnp.asarray(-jnp.inf, jnp.float32)) + _, topk_idx = jax.lax.top_k(scores, min(k, T_s)) + mask_out = _topk_indices_to_attn_mask(topk_idx, T_s, causal=causal) + out = _ref_dense_softmax_per_head( + q, kk, v, mask_out, scale=1.0 / jnp.sqrt(head_dim).astype(q.dtype), + ) + return out + + +def _smoke_test(seed=0): + """Self-attention smoke test: DSA composition vs hand-rolled JAX reference.""" + B, oH, T, hidden = 1, 4, 16, 32 + head_dim = 8 + iH, idc, idi = 2, 16, 16 + k = 4 + keys = jax.random.split(jax.random.PRNGKey(seed), 2) + inputs = jax.random.normal(keys[0], (B, oH, T, hidden), 
dtype=jnp.bfloat16) + + module = DeepSparseAttention( + head_dim=head_dim, + num_attention_heads=oH, + indexer_num_heads=iH, + indexer_d_c=idc, + indexer_d_i=idi, + topk=k, + dtype=jnp.bfloat16, + ) + params = module.init(keys[1], inputs, inputs, deterministic=True) + out = module.apply(params, inputs, inputs, deterministic=True) + print(f" DSA composition out.shape = {out.shape} dtype = {out.dtype} [OK]") + assert out.shape == (B, oH, T, head_dim), f"Unexpected shape {out.shape}" + + params_unboxed = nn.meta.unbox(params) + p = params_unboxed["params"] + out_ref = _ref_dsa_jax( + inputs, inputs, + p["query"]["kernel"], p["key"]["kernel"], p["value"]["kernel"], + p["indexer_W_uq"], p["indexer_W_dq"], p["indexer_W_k"], p["indexer_W_w"], + head_dim=head_dim, k=k, causal=True, + ) + + diff = (out.astype(jnp.float32) - out_ref.astype(jnp.float32)) + rel = float(jnp.linalg.norm(diff) / + (jnp.linalg.norm(out_ref.astype(jnp.float32)) + 1e-30)) + tag = "OK" if rel < 5e-2 else "FAIL" + print(f" DSA vs hand-rolled JAX reference: rel.err = {rel:.2e} [{tag}]") + + +def _scaffold_test(): + """Confirm backend='fused' raises NotImplementedError (scaffold contract).""" + B, oH, T, hidden = 1, 2, 8, 16 + head_dim = 8 + iH, idc, idi = 2, 8, 8 + keys = jax.random.split(jax.random.PRNGKey(42), 2) + inputs = jax.random.normal(keys[0], (B, oH, T, hidden), dtype=jnp.bfloat16) + module = DeepSparseAttention( + head_dim=head_dim, num_attention_heads=oH, + indexer_num_heads=iH, indexer_d_c=idc, indexer_d_i=idi, + topk=4, backend="fused", + ) + try: + module.init(keys[1], inputs, inputs, deterministic=True) + print(" FAIL: fused backend should have raised NotImplementedError") + except NotImplementedError as e: + msg = str(e).splitlines()[0] + print(f" OK: fused backend raises NotImplementedError ({msg!r})") + + +if __name__ == "__main__": + print("=== DSA composition smoke test (hybrid indexer) ===") + _smoke_test(seed=0) + print("\n=== DSA fused-backend scaffold contract ===") + 
_scaffold_test() diff --git a/transformer_engine/jax/triton_extensions/__init__.py b/transformer_engine/jax/triton_extensions/__init__.py index b153b6f18..79ccc0f73 100644 --- a/transformer_engine/jax/triton_extensions/__init__.py +++ b/transformer_engine/jax/triton_extensions/__init__.py @@ -59,3 +59,4 @@ def lowering(ctx, x, **kwargs): from .utils import * from .permutation import * from .indexer import score_reduce_triton, score_topk_triton +from .sparse_attention import fused_sparse_attention_triton diff --git a/transformer_engine/jax/triton_extensions/sparse_attention.py b/transformer_engine/jax/triton_extensions/sparse_attention.py new file mode 100644 index 000000000..28b784ac1 --- /dev/null +++ b/transformer_engine/jax/triton_extensions/sparse_attention.py @@ -0,0 +1,108 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +"""Phase-2 scaffold for the fused sparse-attention Triton primitive. + +The functional API `fused_sparse_attention_triton(query, key, value, +indexer_query, indexer_key, indexer_weights, *, k, ...)` is declared and +registered as a JAX primitive with abstract evaluation, but the kernel +body and MLIR lowering both raise NotImplementedError. + +Purpose: lock the call signature so the DSA Flax module can dispatch to +this primitive via ``backend="fused"`` today, and the real kernel can +land later without any caller-side changes. + +The composition path in ``transformer_engine.jax.sparse_attention`` +(indexer + sparse mask + DotProductAttention) is the supported phase-1 +implementation. 
+""" + +import functools + +import jax.numpy as jnp + +from jax import core +from jax.extend import core as extend_core +from jax.interpreters import mlir, xla + + +_fused_sparse_attention_p = extend_core.Primitive("te_fused_sparse_attention_triton") +_fused_sparse_attention_p.multiple_results = False + + +@_fused_sparse_attention_p.def_abstract_eval +def _fused_sparse_attention_abstract( + query, key, value, indexer_query, indexer_key, indexer_weights, *, k +): + """Output has the same shape/dtype as ``query`` (BSHD layout assumed).""" + del key, value, indexer_query, indexer_key, indexer_weights, k + return core.ShapedArray(query.shape, query.dtype) + + +_fused_sparse_attention_p.def_impl( + functools.partial(xla.apply_primitive, _fused_sparse_attention_p) +) + + +def _fused_sparse_attention_lowering_unavailable(ctx, *args, **kwargs): + raise NotImplementedError( + "fused_sparse_attention_triton is a phase-2 scaffold: the Triton kernel " + "has not been implemented yet. Use backend='composition' in " + "transformer_engine.jax.sparse_attention.deep_sparse_attention_core(...) " + "for the working composition path." + ) + + +mlir.register_lowering( + _fused_sparse_attention_p, + _fused_sparse_attention_lowering_unavailable, + platform="rocm", +) +mlir.register_lowering( + _fused_sparse_attention_p, + _fused_sparse_attention_lowering_unavailable, + platform="cuda", +) + + +def fused_sparse_attention_triton( + query, + key, + value, + indexer_query, + indexer_key, + indexer_weights, + *, + k: int, +): + """Fused indexer + sparse attention (phase-2 scaffold — raises NotImplementedError). 
+ + Intended contract for the future fused kernel: + + Args: + query: (B, T_t, H, D) attention queries (BSHD) + key: (B, T_s, H_kv, D) attention keys + value: (B, T_s, H_kv, D) attention values + indexer_query: (B, T_t, H_idx, d_i) post-projection indexer Hq + indexer_key: (B, T_s, d_i) post-projection indexer Hk + indexer_weights: (B, T_t, H_idx) post-projection indexer W_o + k: number of top-k key positions per query token + + Returns: + Output of shape (B, T_t, H, D) — sparse attention output where each + query attends only to its indexer-selected top-k key positions + (intersected with the causal mask). + + The signature is intentionally minimal so phase-2 has room to grow it + (e.g. window_size, attn_bias). Add kwargs only via the function + signature — abstract_eval and lowering both already accept ``**kwargs``. + """ + return _fused_sparse_attention_p.bind( + query, + key, + value, + indexer_query, + indexer_key, + indexer_weights, + k=k, + ) From 7c01b2bfb098b38ce65ef54c9a4985f0427cff33 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 29 May 2026 19:27:57 +0000 Subject: [PATCH 10/17] Trimmed and streamlined --- tests/jax/test_indexer.py | 87 +++++++++++ tests/jax/test_sparse_attention.py | 51 ++++++- transformer_engine/jax/__init__.py | 8 ++ transformer_engine/jax/indexer.py | 136 +++--------------- transformer_engine/jax/sparse_attention.py | 136 +----------------- .../jax/triton_extensions/indexer.py | 3 +- .../jax/triton_extensions/utils.py | 3 +- 7 files changed, 170 insertions(+), 254 deletions(-) create mode 100644 tests/jax/test_indexer.py diff --git a/tests/jax/test_indexer.py b/tests/jax/test_indexer.py new file mode 100644 index 000000000..c04d92a2d --- /dev/null +++ b/tests/jax/test_indexer.py @@ -0,0 +1,87 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +"""Correctness tests for the lightning-indexer JAX ops. 
+ +Ported from the in-module ``__main__`` smoke tests of +``transformer_engine.jax.indexer``. The hybrid and top-k backends require +rank-4 ``(B, oH, T, d)`` inputs, so every leading shape here is length-2. +""" + +import jax +import jax.numpy as jnp +import pytest + +from transformer_engine.jax.indexer import indexer, indexer_topk + + +def _indexer_inputs(B, oH, T_t, T_s, d, d_c, H, d_i, seed): + keys = jax.random.split(jax.random.PRNGKey(seed), 6) + Q = jax.random.normal(keys[0], (B, oH, T_t, d), dtype=jnp.bfloat16) + K = jax.random.normal(keys[1], (B, oH, T_s, d), dtype=jnp.bfloat16) + W_uq = jax.random.normal(keys[2], (H, d_c, d_i), dtype=jnp.bfloat16) + W_dq = jax.random.normal(keys[3], (d, d_c), dtype=jnp.bfloat16) + W_k = jax.random.normal(keys[4], (d, d_i), dtype=jnp.bfloat16) + W_w = jax.random.normal(keys[5], (d, H), dtype=jnp.bfloat16) + return Q, K, W_uq, W_dq, W_k, W_w + + +def _rel_err(actual, ref): + actual = actual.astype(jnp.float32) + ref = ref.astype(jnp.float32) + return float(jnp.linalg.norm(actual - ref) / (jnp.linalg.norm(ref) + 1e-30)) + + +@pytest.mark.parametrize("B,oH", [(2, 3), (1, 1), (1, 4)]) +def test_hybrid_matches_reference(B, oH): + """Hybrid Triton score-reduce matches the pure-einsum reference forward.""" + args = _indexer_inputs(B, oH, T_t=64, T_s=64, d=32, d_c=32, H=8, d_i=32, seed=100) + o_ref = indexer(*args, backend="reference") + o_hyb = indexer(*args, backend="hybrid") + assert o_hyb.shape == o_ref.shape + assert _rel_err(o_hyb, o_ref) < 5e-3 + + +@pytest.mark.parametrize("k", [32]) +def test_topk_matches_reference(k): + """Fused top-k selects the same scores as reference + ``jax.lax.top_k``. + + Index set-equality is too strict (backends break ties differently), so the + check is on the *scores* at the fused-selected indices. 
``k`` is kept in the + top quartile of ``T_s``: a cutoff in the dense middle of the distribution + makes boundary scores closely spaced, so the kernel's fp32 ranking and the + bf16-rounded reference grid resolve near-ties differently (a test-grid + sensitivity, not a kernel error). + """ + args = _indexer_inputs(2, 3, T_t=64, T_s=128, d=32, d_c=32, H=16, d_i=32, seed=200) + o_ref = indexer(*args, backend="reference").astype(jnp.float32) + topk_idx = indexer_topk(*args, k=k) + assert topk_idx.shape == (2, 3, 64, k) + + ref_vals = jax.lax.top_k(o_ref, k=k)[0] + picked = jnp.take_along_axis(o_ref, topk_idx, axis=-1) + picked_sorted = jnp.sort(picked, axis=-1)[..., ::-1] + max_rel = float((jnp.abs(ref_vals - picked_sorted) / (jnp.abs(ref_vals) + 1e-6)).max()) + assert max_rel < 1e-2 + + +@pytest.mark.parametrize("B,oH", [(2, 3), (1, 2)]) +def test_hybrid_backward_matches_reference_grad(B, oH): + """``jax.grad`` through the hybrid backend matches grad through reference. + + Tolerance is 5e-2 (bf16 projections + Triton score recompute) — looser than + the 5e-3 forward tolerance; tighten once per-grad error is characterized + on-device. 
+ """ + args = _indexer_inputs(B, oH, T_t=32, T_s=32, d=32, d_c=32, H=8, d_i=32, seed=300) + + def _loss(backend): + def inner(*a): + return jnp.sum(indexer(*a, backend=backend).astype(jnp.float32)) + return inner + + argnums = (0, 1, 2, 3, 4, 5) + grads_ref = jax.grad(_loss("reference"), argnums=argnums)(*args) + grads_hyb = jax.grad(_loss("hybrid"), argnums=argnums)(*args) + for gr, gh in zip(grads_ref, grads_hyb): + assert _rel_err(gh, gr) < 5e-2 diff --git a/tests/jax/test_sparse_attention.py b/tests/jax/test_sparse_attention.py index 38b65430b..d486f24c8 100644 --- a/tests/jax/test_sparse_attention.py +++ b/tests/jax/test_sparse_attention.py @@ -13,12 +13,12 @@ deep_sparse_attention_core, _causal_keep_mask, _topk_indices_to_attn_mask, - _ref_dsa_jax, ) from transformer_engine.jax.compressed_attention import ( HeavilyCompressedAttention, heavily_compressed_attention, ) +from transformer_engine.jax.indexer import indexer from transformer_engine.jax.triton_extensions import fused_sparse_attention_triton @@ -61,6 +61,52 @@ def _make_inputs(B=1, oH=4, T=16, hidden=32, dtype=jnp.bfloat16, seed=0): return jax.random.normal(jax.random.PRNGKey(seed), (B, oH, T, hidden), dtype=dtype) +def _ref_dense_softmax_per_head(query, key, value, mask_out, scale): + """Per-head dense softmax attention with an arbitrary mask (no DPA). + + query/key/value: [B, oH, T, head_dim]; mask_out: [B, oH, T_t, T_s] uint8 + (1 = mask out). Returns [B, oH, T_t, head_dim]. 
+ """ + logits = jnp.einsum("bhtd,bhsd->bhts", query, key) * scale + logits = logits.astype(jnp.float32) + logits = jnp.where( + mask_out.astype(jnp.bool_), + jnp.asarray(-jnp.inf, jnp.float32), + logits, + ) + weights = jax.nn.softmax(logits, axis=-1) + return jnp.einsum("bhts,bhsd->bhtd", weights.astype(value.dtype), value) + + +def _ref_dsa_jax( + inputs_q, inputs_kv, + W_q_kernel, W_k_kernel, W_v_kernel, + W_uq, W_dq, W_k_idx, W_w, + *, + head_dim, k, causal, +): + """Pure-JAX reference matching ``deep_sparse_attention_core``.""" + T_t = inputs_q.shape[2] + T_s = inputs_kv.shape[2] + + q = jnp.einsum("bhtd,dk->bhtk", inputs_q, W_q_kernel) + kk = jnp.einsum("bhsd,dk->bhsk", inputs_kv, W_k_kernel) + v = jnp.einsum("bhsd,dk->bhsk", inputs_kv, W_v_kernel) + + scores = indexer( + inputs_q, inputs_kv, W_uq, W_dq, W_k_idx, W_w, + backend="reference", out_dtype=jnp.float32, + ) + if causal: + ckeep = _causal_keep_mask(T_t, T_s)[None, None, :, :] + scores = jnp.where(ckeep, scores, jnp.asarray(-jnp.inf, jnp.float32)) + _, topk_idx = jax.lax.top_k(scores, min(k, T_s)) + mask_out = _topk_indices_to_attn_mask(topk_idx, T_s, causal=causal) + return _ref_dense_softmax_per_head( + q, kk, v, mask_out, scale=1.0 / jnp.sqrt(head_dim).astype(q.dtype), + ) + + # ----------------------------------------------------------------------------- # Mask helpers # ----------------------------------------------------------------------------- @@ -183,9 +229,8 @@ def test_dsa_topk_count_equals_kept_count_under_causal(T_t, T_s, k): module = _make_dsa_module(oH=oH, D=8, iH=1, idc=8, idi=8, k=k) params = module.init(keys[0], inputs, inputs, deterministic=True) - from transformer_engine.jax.indexer import indexer as _idx p = nn.meta.unbox(params)["params"] - scores = _idx( + scores = indexer( inputs, inputs, p["indexer_W_uq"], p["indexer_W_dq"], p["indexer_W_k"], p["indexer_W_w"], backend="reference", out_dtype=jnp.float32, diff --git a/transformer_engine/jax/__init__.py 
b/transformer_engine/jax/__init__.py index d0afc1ff2..bd2bf44f2 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -34,6 +34,11 @@ from . import flax from . import quantize +# AMD lightning-indexer / sparse-attention staging modules. +from . import indexer +from . import sparse_attention +from . import compressed_attention + from .quantize import autocast, fp8_autocast, update_collections from .quantize import NVTE_FP8_COLLECTION_NAME @@ -51,4 +56,7 @@ "MeshResource", "flax", "quantize", + "indexer", + "sparse_attention", + "compressed_attention", ] diff --git a/transformer_engine/jax/indexer.py b/transformer_engine/jax/indexer.py index 347e5ead6..a01bb60b9 100644 --- a/transformer_engine/jax/indexer.py +++ b/transformer_engine/jax/indexer.py @@ -1,3 +1,6 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. """Indexer op (forward only), bf16 inputs. Two canonical backends: @@ -28,6 +31,19 @@ import jax.numpy as jnp +def _indexer_projections(Q, K, W_uq, W_dq, W_k, W_w): + """Low-rank indexer projections shared by every backend. + + Returns (H_q, H_k, W_o) with shapes + (..., T, H, d_i), (..., S, d_i), (..., T, H). 
+ """ + C_q = jnp.einsum("...td,dc->...tc", Q, W_dq) + H_q = jnp.einsum("...tc,hci->...thi", C_q, W_uq) + H_k = jnp.einsum("...sd,di->...si", K, W_k) + W_o = jnp.einsum("...td,dh->...th", Q, W_w) + return H_q, H_k, W_o + + def _indexer_impl_reference(Q, K, W_uq, W_dq, W_k, W_w, out_dtype=None): """ Q [..., T, d] @@ -37,12 +53,9 @@ def _indexer_impl_reference(Q, K, W_uq, W_dq, W_k, W_w, out_dtype=None): W_k [d, d_i] W_w [..., d, H] # leading dims must match Q's """ - C_q = jnp.einsum("...td,dc->...tc", Q, W_dq) # (..., T, d_c) - H_q = jnp.einsum("...tc,hci->...thi", C_q, W_uq) # (..., T, H, d_i) - H_k = jnp.einsum("...sd,di->...si", K, W_k) # (..., S, d_i) + H_q, H_k, W_o = _indexer_projections(Q, K, W_uq, W_dq, W_k, W_w) H = jax.nn.relu(jnp.einsum("...thi,...si->...ths", H_q, H_k)) # (..., T, H, S) - W_o = jnp.einsum("...td,dh->...th", Q, W_w) - O = jnp.einsum("...ths,...th->...ts", H, W_o) # (..., T, S) + O = jnp.einsum("...ths,...th->...ts", H, W_o) # (..., T, S) if out_dtype is not None: O = O.astype(out_dtype) return O @@ -59,11 +72,7 @@ def _indexer_impl_hybrid(Q, K, W_uq, W_dq, W_k, W_w, out_dtype=None): """ from transformer_engine.jax.triton_extensions.indexer import score_reduce_triton - C_q = jnp.einsum("...td,dc->...tc", Q, W_dq) # (..., T, d_c) - H_q = jnp.einsum("...tc,hci->...thi", C_q, W_uq) # (..., T, H, d_i) - H_k = jnp.einsum("...sd,di->...si", K, W_k) # (..., S, d_i) - W_o = jnp.einsum("...td,dh->...th", Q, W_w) # (..., T, H) - + H_q, H_k, W_o = _indexer_projections(Q, K, W_uq, W_dq, W_k, W_w) return score_reduce_triton(H_q, H_k, W_o, out_dtype=out_dtype if out_dtype else Q.dtype) @@ -87,10 +96,7 @@ def indexer_topk(Q, K, W_uq, W_dq, W_k, weights, *, k): in descending score order. 
""" from transformer_engine.jax.triton_extensions.indexer import score_topk_triton - C_q = jnp.einsum("...td,dc->...tc", Q, W_dq) # (..., T, d_c) - H_q = jnp.einsum("...tc,hci->...thi", C_q, W_uq) # (..., T, H, d_i) - H_k = jnp.einsum("...sd,di->...si", K, W_k) # (..., S, d_i) - W_o = jnp.einsum("...td,dh->...th", Q, weights) # (..., T, H) + H_q, H_k, W_o = _indexer_projections(Q, K, W_uq, W_dq, W_k, weights) return score_topk_triton(H_q, H_k, W_o, k=k) @@ -120,105 +126,3 @@ def indexer(Q, K, W_uq, W_dq, W_k, weights, *, out_dtype=None, backend="referenc raise ValueError( f"unknown backend {backend!r}; expected 'reference' or 'hybrid'" ) - - -# --- Tests ---------------------------------------------------------------------- - -def _run_test(leading_shape, seed, backend): - # The hybrid backend's Triton kernel requires rank-4 BHSD inputs. - T_t, T_s, d, d_c, H, d_i = 64, 64, 32, 32, 8, 32 - keys = jax.random.split(jax.random.PRNGKey(seed), 6) - Q = jax.random.normal(keys[0], (*leading_shape, T_t, d), dtype=jnp.bfloat16) - K = jax.random.normal(keys[1], (*leading_shape, T_s, d), dtype=jnp.bfloat16) - W_uq = jax.random.normal(keys[2], (H, d_c, d_i), dtype=jnp.bfloat16) - W_dq = jax.random.normal(keys[3], (d, d_c), dtype=jnp.bfloat16) - W_k = jax.random.normal(keys[4], (d, d_i), dtype=jnp.bfloat16) - W_w = jax.random.normal(keys[5], (d, H), dtype=jnp.bfloat16) - - O_ref = indexer(Q, K, W_uq, W_dq, W_k, W_w, backend="reference") - O_b = indexer(Q, K, W_uq, W_dq, W_k, W_w, backend=backend) - - diff = (O_ref.astype(jnp.float32) - O_b.astype(jnp.float32)) - rel_err = float(jnp.linalg.norm(diff) / - (jnp.linalg.norm(O_ref.astype(jnp.float32)) + 1e-30)) - tag = "OK" if rel_err < 5e-3 else "FAIL" - print(f" backend={backend:<10s} leading={str(leading_shape):10s} " - f"O.shape={O_b.shape} rel.err={rel_err:.2e} [{tag}]") - - -def _run_topk_test(leading_shape, seed, k): - # H=16 to keep the matmul in [BLOCK_S, H] friendly to MFMA tile sizes. 
- T_t, T_s, d, d_c, H, d_i = 64, 128, 32, 32, 16, 32 - keys = jax.random.split(jax.random.PRNGKey(seed), 6) - Q = jax.random.normal(keys[0], (*leading_shape, T_t, d), dtype=jnp.bfloat16) - K = jax.random.normal(keys[1], (*leading_shape, T_s, d), dtype=jnp.bfloat16) - W_uq = jax.random.normal(keys[2], (H, d_c, d_i), dtype=jnp.bfloat16) - W_dq = jax.random.normal(keys[3], (d, d_c), dtype=jnp.bfloat16) - W_k = jax.random.normal(keys[4], (d, d_i), dtype=jnp.bfloat16) - W_w = jax.random.normal(keys[5], (d, H), dtype=jnp.bfloat16) - - O_ref = indexer(Q, K, W_uq, W_dq, W_k, W_w, backend="reference") - topk_fused = indexer_topk(Q, K, W_uq, W_dq, W_k, W_w, k=k) - - # Correctness check: the scores at the fused-picked indices should equal the - # top-k scores from the reference (within bf16 noise). Set-equality of indices - # is too strict — different backends break ties differently. - O_ref32 = O_ref.astype(jnp.float32) - ref_topk_vals = jax.lax.top_k(O_ref32, k=k)[0] # [..., T_t, k] sorted desc - fused_picked_vals = jnp.take_along_axis(O_ref32, topk_fused, axis=-1) - fused_picked_sorted = jnp.sort(fused_picked_vals, axis=-1)[..., ::-1] - rel_diff = jnp.abs(ref_topk_vals - fused_picked_sorted) / (jnp.abs(ref_topk_vals) + 1e-6) - max_rel = float(rel_diff.max()) - tag = "OK" if max_rel < 1e-2 else f"FAIL (max_rel={max_rel:.2e})" - print(f" topk leading={str(leading_shape):10s} k={k:<4d} " - f"out.shape={topk_fused.shape} max_rel={max_rel:.2e} [{tag}]") - - -def _run_bwd_test(leading_shape, seed): - """Compare hybrid backward against jax.grad on the reference impl.""" - T_t, T_s, d, d_c, H, d_i = 32, 32, 32, 32, 8, 32 - keys = jax.random.split(jax.random.PRNGKey(seed), 6) - Q = jax.random.normal(keys[0], (*leading_shape, T_t, d), dtype=jnp.bfloat16) - K = jax.random.normal(keys[1], (*leading_shape, T_s, d), dtype=jnp.bfloat16) - W_uq = jax.random.normal(keys[2], (H, d_c, d_i), dtype=jnp.bfloat16) - W_dq = jax.random.normal(keys[3], (d, d_c), dtype=jnp.bfloat16) - W_k = 
jax.random.normal(keys[4], (d, d_i), dtype=jnp.bfloat16) - W_w = jax.random.normal(keys[5], (d, H), dtype=jnp.bfloat16) - - def loss_ref(Q, K, W_uq, W_dq, W_k, W_w): - O = indexer(Q, K, W_uq, W_dq, W_k, W_w, backend="reference") - return jnp.sum(O.astype(jnp.float32)) - - def loss_hyb(Q, K, W_uq, W_dq, W_k, W_w): - O = indexer(Q, K, W_uq, W_dq, W_k, W_w, backend="hybrid") - return jnp.sum(O.astype(jnp.float32)) - - grads_ref = jax.grad(loss_ref, argnums=(0, 1, 2, 3, 4, 5))(Q, K, W_uq, W_dq, W_k, W_w) - grads_hyb = jax.grad(loss_hyb, argnums=(0, 1, 2, 3, 4, 5))(Q, K, W_uq, W_dq, W_k, W_w) - - names = ("dQ", "dK", "dW_uq", "dW_dq", "dW_k", "dW_w") - all_ok = True - for name, gr, gh in zip(names, grads_ref, grads_hyb): - diff = (gr.astype(jnp.float32) - gh.astype(jnp.float32)) - rel = float(jnp.linalg.norm(diff) / - (jnp.linalg.norm(gr.astype(jnp.float32)) + 1e-30)) - ok = rel < 5e-2 - all_ok = all_ok and ok - tag = "OK" if ok else "FAIL" - print(f" {name:<6} shape={str(gh.shape):<22s} rel.err={rel:.2e} [{tag}]") - overall = "OK" if all_ok else "FAIL" - print(f" bwd leading={str(leading_shape):10s} overall=[{overall}]") - - -if __name__ == "__main__": - print("=== reference vs reference (sanity) ===") - _run_test((2, 3), seed=0, backend="reference") - - print("\n=== hybrid vs reference ===") - _run_test((2, 3), seed=100, backend="hybrid") - - print("\n=== indexer_topk vs reference + jax.lax.top_k ===") - _run_topk_test((2, 3), seed=200, k=32) - - print("\n=== backward: hybrid vs jax.grad(reference) ===") - _run_bwd_test((2, 3), seed=300) diff --git a/transformer_engine/jax/sparse_attention.py b/transformer_engine/jax/sparse_attention.py index 7c3337556..3a7df3901 100644 --- a/transformer_engine/jax/sparse_attention.py +++ b/transformer_engine/jax/sparse_attention.py @@ -48,7 +48,7 @@ from transformer_engine.jax.flax.module import DenseGeneral from transformer_engine.jax.flax.transformer import DotProductAttention -from transformer_engine.jax.indexer import indexer as 
_indexer_fn +from transformer_engine.jax.indexer import indexer as _indexer_fn, _indexer_projections # Backends supported by deep_sparse_attention_core. @@ -189,10 +189,10 @@ def deep_sparse_attention_core( ) # Project the indexer side so the fused primitive sees Hq/Hk/W_o tensors. # (Scaffold lowering raises; the projections are computed for shape only.) - C_q = jnp.einsum("...td,dc->...tc", indexer_inputs_q, indexer_W_dq) - Hq = jnp.einsum("...tc,hci->...thi", C_q, indexer_W_uq) - Hk = jnp.einsum("...sd,di->...si", indexer_inputs_kv, indexer_W_k) - W_o = jnp.einsum("...td,dh->...th", indexer_inputs_q, indexer_W_w) + Hq, Hk, W_o = _indexer_projections( + indexer_inputs_q, indexer_inputs_kv, + indexer_W_uq, indexer_W_dq, indexer_W_k, indexer_W_w, + ) return fused_sparse_attention_triton( query, key, value, Hq, Hk, W_o, k=k, ) @@ -414,129 +414,3 @@ def __call__( backend=self.backend, indexer_backend=self.indexer_backend, ) # [B, oH, T_t, head_dim] - - -# ----------------------------------------------------------------------------- -# Reference / smoke test -# ----------------------------------------------------------------------------- - - -def _ref_dense_softmax_per_head( - query, key, value, mask_out, scale, -): - """Reference: per-head dense softmax attention with arbitrary mask (no DPA). 
- - Args: - query, key, value: ``[B, oH, T, head_dim]`` - mask_out: ``[B, oH, T_t, T_s]`` uint8 (1 = mask out) - scale: scalar - Returns: ``[B, oH, T_t, head_dim]`` - """ - logits = jnp.einsum("bhtd,bhsd->bhts", query, key) * scale # [B, oH, T_t, T_s] - logits = logits.astype(jnp.float32) - logits = jnp.where( - mask_out.astype(jnp.bool_), - jnp.asarray(-jnp.inf, jnp.float32), - logits, - ) - weights = jax.nn.softmax(logits, axis=-1) - out = jnp.einsum("bhts,bhsd->bhtd", weights.astype(value.dtype), value) - return out - - -def _ref_dsa_jax( - inputs_q, inputs_kv, - W_q_kernel, W_k_kernel, W_v_kernel, - W_uq, W_dq, W_k_idx, W_w, - *, - head_dim, k, causal, -): - """Pure-JAX reference matching ``deep_sparse_attention_core``.""" - # inputs_q: [B, oH, T_t, hidden]; inputs_kv: [B, oH, T_s, hidden] - B, oH, T_t, hidden = inputs_q.shape - T_s = inputs_kv.shape[2] - - # Per-head Q/K/V projections (kernel shape [hidden, head_dim], shared across oH). - q = jnp.einsum("bhtd,dk->bhtk", inputs_q, W_q_kernel) # [B, oH, T_t, D] - kk = jnp.einsum("bhsd,dk->bhsk", inputs_kv, W_k_kernel) - v = jnp.einsum("bhsd,dk->bhsk", inputs_kv, W_v_kernel) - - scores = _indexer_fn( - inputs_q, inputs_kv, W_uq, W_dq, W_k_idx, W_w, - backend="reference", out_dtype=jnp.float32, - ) # [B, oH, T_t, T_s] - if causal: - ckeep = _causal_keep_mask(T_t, T_s)[None, None, :, :] - scores = jnp.where(ckeep, scores, jnp.asarray(-jnp.inf, jnp.float32)) - _, topk_idx = jax.lax.top_k(scores, min(k, T_s)) - mask_out = _topk_indices_to_attn_mask(topk_idx, T_s, causal=causal) - out = _ref_dense_softmax_per_head( - q, kk, v, mask_out, scale=1.0 / jnp.sqrt(head_dim).astype(q.dtype), - ) - return out - - -def _smoke_test(seed=0): - """Self-attention smoke test: DSA composition vs hand-rolled JAX reference.""" - B, oH, T, hidden = 1, 4, 16, 32 - head_dim = 8 - iH, idc, idi = 2, 16, 16 - k = 4 - keys = jax.random.split(jax.random.PRNGKey(seed), 2) - inputs = jax.random.normal(keys[0], (B, oH, T, hidden), 
dtype=jnp.bfloat16) - - module = DeepSparseAttention( - head_dim=head_dim, - num_attention_heads=oH, - indexer_num_heads=iH, - indexer_d_c=idc, - indexer_d_i=idi, - topk=k, - dtype=jnp.bfloat16, - ) - params = module.init(keys[1], inputs, inputs, deterministic=True) - out = module.apply(params, inputs, inputs, deterministic=True) - print(f" DSA composition out.shape = {out.shape} dtype = {out.dtype} [OK]") - assert out.shape == (B, oH, T, head_dim), f"Unexpected shape {out.shape}" - - params_unboxed = nn.meta.unbox(params) - p = params_unboxed["params"] - out_ref = _ref_dsa_jax( - inputs, inputs, - p["query"]["kernel"], p["key"]["kernel"], p["value"]["kernel"], - p["indexer_W_uq"], p["indexer_W_dq"], p["indexer_W_k"], p["indexer_W_w"], - head_dim=head_dim, k=k, causal=True, - ) - - diff = (out.astype(jnp.float32) - out_ref.astype(jnp.float32)) - rel = float(jnp.linalg.norm(diff) / - (jnp.linalg.norm(out_ref.astype(jnp.float32)) + 1e-30)) - tag = "OK" if rel < 5e-2 else "FAIL" - print(f" DSA vs hand-rolled JAX reference: rel.err = {rel:.2e} [{tag}]") - - -def _scaffold_test(): - """Confirm backend='fused' raises NotImplementedError (scaffold contract).""" - B, oH, T, hidden = 1, 2, 8, 16 - head_dim = 8 - iH, idc, idi = 2, 8, 8 - keys = jax.random.split(jax.random.PRNGKey(42), 2) - inputs = jax.random.normal(keys[0], (B, oH, T, hidden), dtype=jnp.bfloat16) - module = DeepSparseAttention( - head_dim=head_dim, num_attention_heads=oH, - indexer_num_heads=iH, indexer_d_c=idc, indexer_d_i=idi, - topk=4, backend="fused", - ) - try: - module.init(keys[1], inputs, inputs, deterministic=True) - print(" FAIL: fused backend should have raised NotImplementedError") - except NotImplementedError as e: - msg = str(e).splitlines()[0] - print(f" OK: fused backend raises NotImplementedError ({msg!r})") - - -if __name__ == "__main__": - print("=== DSA composition smoke test (hybrid indexer) ===") - _smoke_test(seed=0) - print("\n=== DSA fused-backend scaffold contract ===") - 
_scaffold_test() diff --git a/transformer_engine/jax/triton_extensions/indexer.py b/transformer_engine/jax/triton_extensions/indexer.py index 2c012b685..e6477d6ae 100644 --- a/transformer_engine/jax/triton_extensions/indexer.py +++ b/transformer_engine/jax/triton_extensions/indexer.py @@ -1,5 +1,4 @@ # Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Triton score-relu-reduce kernel for the lightning-indexer hybrid backend. @@ -384,7 +383,7 @@ def _score_reduce_bwd(out_dtype, residuals, dO): H_CHUNK = _BWD_H_CHUNK else: H_CHUNK = 1 - for c in (8, 4, 2): + for c in (4, 2): if H % c == 0: H_CHUNK = c break diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 219b286b6..69d939020 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -545,8 +545,7 @@ def _normalize_grid(g): kernel_calls = [] actual_kernel_fn = kernel_fn.fn - for idx, config in enumerate(kernel_fn.configs): - print(f"DEBUG *** Running config {idx+1}/{len(kernel_fn.configs)}") + for config in kernel_fn.configs: # Extract parameters from config config_num_warps = config.num_warps if config.num_warps is not None else num_warps config_num_stages = config.num_stages if config.num_stages is not None else num_stages From 31b9a8d3add5a063c20ec844c037007e1e3da5f3 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 1 Jun 2026 20:50:45 +0000 Subject: [PATCH 11/17] Added benchmarks for indexer --- benchmarks/profile_indexer_bwd.py | 111 ++++++++++++++++++++--------- benchmarks/profile_indexer_topk.py | 2 +- 2 files changed, 80 insertions(+), 33 deletions(-) diff --git a/benchmarks/profile_indexer_bwd.py b/benchmarks/profile_indexer_bwd.py index 070c6a7aa..85363141f 100644 --- a/benchmarks/profile_indexer_bwd.py +++ 
b/benchmarks/profile_indexer_bwd.py @@ -6,8 +6,13 @@ Run inside the container: docker exec zain-w2 sh -c 'cd /workspace && python benchmarks/profile_indexer_bwd.py' + +Select backends and passes via flags: + --backends reference hybrid + --passes fwd bwd vag """ +import argparse import time import jax @@ -23,6 +28,10 @@ _HYBRID_IMPORT_ERROR = _e +ALL_BACKENDS = ["reference", "hybrid"] +ALL_PASSES = ["fwd", "bwd", "vag"] + + def make_inputs(B, oH, T, S, d, d_c, H, d_i, dtype, seed=0): keys = jax.random.split(jax.random.PRNGKey(seed), 6) Q = jax.random.normal(keys[0], (B, oH, T, d), dtype=dtype) @@ -82,52 +91,90 @@ def _build_value_and_grad(backend): return jax.jit(jax.value_and_grad(fwd, argnums=(0, 1, 2, 3, 4, 5))) +PASS_SPECS = { + "fwd": ("forward", _build_fwd, 1), + "bwd": ("backward", _build_bwd, 2), + "vag": ("value_and_grad", _build_value_and_grad, 3), +} + + +def parse_args(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument( + "--backends", + nargs="+", + choices=ALL_BACKENDS, + default=None, + help=( + "Backends to benchmark. Default: 'reference' plus 'hybrid' if importable." + ), + ) + p.add_argument( + "--passes", + nargs="+", + choices=ALL_PASSES, + default=ALL_PASSES, + help="Which passes to run: fwd, bwd, vag. Default: all three.", + ) + return p.parse_args() + + +def resolve_backends(requested): + if requested is None: + backends = ["reference"] + if _HAVE_HYBRID: + backends.append("hybrid") + return backends + if "hybrid" in requested and not _HAVE_HYBRID: + print( + f"WARNING: 'hybrid' backend requested but unavailable " + f"({type(_HYBRID_IMPORT_ERROR).__name__}: {_HYBRID_IMPORT_ERROR}). " + "Running it anyway — expect failure." 
+ ) + return requested + + def main(): - print(f"jax devices: {jax.devices()}\n") + args = parse_args() + backends = resolve_backends(args.backends) + passes = args.passes + + print(f"jax devices: {jax.devices()}") + print(f"backends: {backends}") + print(f"passes: {passes}\n") + for B, oH, T, S, d, d_c, H, d_i in CONFIGS: Q, K, W_uq, W_dq, W_k, W_w = make_inputs( B, oH, T, S, d, d_c, H, d_i, jnp.bfloat16 ) - args = (Q, K, W_uq, W_dq, W_k, W_w) + fn_args = (Q, K, W_uq, W_dq, W_k, W_w) fwd_flops = theoretical_fwd_flops(B, oH, T, S, d, d_c, H, d_i) print(f"--- B={B} oH={oH} T={T} S={S} d={d} d_c={d_c} H={H} d_i={d_i} bfloat16 ---") print(f" forward GFLOPs/call: {fwd_flops/1e9:.2f}") - print(f" bwd GFLOPs/call (~2x): {2*fwd_flops/1e9:.2f}") - print(f" f+b GFLOPs/call (~3x): {3*fwd_flops/1e9:.2f}") + if "bwd" in passes: + print(f" bwd GFLOPs/call (~2x): {2*fwd_flops/1e9:.2f}") + if "vag" in passes: + print(f" f+b GFLOPs/call (~3x): {3*fwd_flops/1e9:.2f}") print() - backends = ["reference"] - if _HAVE_HYBRID: - backends.append("hybrid") - - # Headers print(f" {'backend':<10s} {'pass':<14s} {'ms':>8s} {'TFLOP/s':>8s}") for backend in backends: - try: - # Forward (loss only) - fwd = _build_fwd(backend) - sec = time_fn(fwd, args) - ms = sec * 1e3 - tflops = fwd_flops / sec / 1e12 - print(f" {backend:<10s} {'forward':<14s} {ms:8.3f} {tflops:8.2f}") - - # Backward only (jax.grad — XLA may re-trace forward inside) - bwd = _build_bwd(backend) - sec = time_fn(bwd, args) - ms = sec * 1e3 - tflops = 2 * fwd_flops / sec / 1e12 # bwd ~= 2x fwd - print(f" {backend:<10s} {'backward':<14s} {ms:8.3f} {tflops:8.2f}") - - # value_and_grad (forward + backward, single pass) - vag = _build_value_and_grad(backend) - sec = time_fn(vag, args) - ms = sec * 1e3 - tflops = 3 * fwd_flops / sec / 1e12 # f+b ~= 3x fwd - print(f" {backend:<10s} {'value_and_grad':<14s} {ms:8.3f} {tflops:8.2f}") - except Exception as e: # noqa: BLE001 - print(f" {backend:<10s} FAILED: {type(e).__name__}: 
{str(e).splitlines()[0]}") + for pass_key in passes: + label, builder, flop_mult = PASS_SPECS[pass_key] + try: + fn = builder(backend) + sec = time_fn(fn, fn_args) + ms = sec * 1e3 + tflops = flop_mult * fwd_flops / sec / 1e12 + print(f" {backend:<10s} {label:<14s} {ms:8.3f} {tflops:8.2f}") + except Exception as e: # noqa: BLE001 + msg = str(e).splitlines()[0] if str(e) else "" + print( + f" {backend:<10s} {label:<14s} FAILED: " + f"{type(e).__name__}: {msg}" + ) print() diff --git a/benchmarks/profile_indexer_topk.py b/benchmarks/profile_indexer_topk.py index 82f39b177..70f59698d 100644 --- a/benchmarks/profile_indexer_topk.py +++ b/benchmarks/profile_indexer_topk.py @@ -66,7 +66,7 @@ def time_fn(fn, args, n_warmup=15, n_iter=50): CONFIGS = [ #(B, oH, T, S, d, d_c, H, d_i) - ( 2, 64, 4096, 4096, 512, 1024, 64, 128), + ( 2, 64, 1024, 1024, 512, 1024, 64, 128), ] K_TOPK = 512 From 2a27163213e629b0f8dddfdb41d079f47b884fba Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 2 Jun 2026 17:53:44 +0000 Subject: [PATCH 12/17] Minimized diff --- .../jax/triton_extensions/utils.py | 128 ++++++++---------- 1 file changed, 60 insertions(+), 68 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 69d939020..35a72738d 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -40,6 +40,7 @@ from jax import core import jax import jax.numpy as jnp +from transformer_engine.jax.util import is_hip_extension # Placeholder package version on PyPI that should never be used @@ -152,6 +153,7 @@ def _check_triton_compatibility(): try: from jax._src.lib import gpu_triton from triton.compiler import compiler as tc + from triton.backends.nvidia import compiler as cb from triton.runtime import autotuner except ImportError as e: raise ImportError( @@ -161,22 +163,10 @@ def _check_triton_compatibility(): ) from e -# Detect target platform once at 
import time. AMD/HIP returns an arch string -# like "gfx950:sramecc+:xnack-"; NVIDIA returns something else (or this call -# falls through to the CUDA path). -try: - _ARCH_DETAILS = gpu_triton.get_arch_details(0) -except Exception: # noqa: BLE001 - _ARCH_DETAILS = "" -_IS_HIP = _ARCH_DETAILS.startswith("gfx") - -# Lazy backend imports — only pull in what the active platform needs so that -# AMD-only or NVIDIA-only environments don't fail at module load. -if _IS_HIP: +# AMD/HIP backend imports are additive: the NVIDIA path above is left untouched. +if is_hip_extension(): from triton.backends.amd import compiler as cb_hip # noqa: E402 from triton.backends.compiler import GPUTarget as _TritonGPUTarget # noqa: E402 -else: - from triton.backends.nvidia import compiler as cb # noqa: E402 __all__ = ["triton_call_lowering", "get_triton_info"] @@ -230,7 +220,7 @@ def get_triton_dtype(aval): jnp.dtype("float16"): "fp16", jnp.dtype("float8_e4m3fn"): "fp8e4nv", jnp.dtype("float8_e5m2"): "fp8e5", - # AMD MI300 (gfx942) "FNUZ" variants — Triton calls these fp8e4b8/fp8e5b16. + # AMD gfx942 "FNUZ" variants — Triton calls these fp8e4b8/fp8e5b16. jnp.dtype("float8_e4m3fnuz"): "fp8e4b8", jnp.dtype("float8_e5m2fnuz"): "fp8e5b16", jnp.dtype("int64"): "i64", @@ -294,29 +284,12 @@ def compile_triton( if cache_key in _TRITON_KERNEL_CACHE: return _TRITON_KERNEL_CACHE[cache_key] - # Mark constants as constexpr in signature (defensive — tensor signatures - # built by triton_call_lowering won't contain constexpr names, but other - # callers might). - signature_with_constexpr = dict(signature) - for const_name in constants.keys(): - if const_name in signature_with_constexpr: - signature_with_constexpr[const_name] = "constexpr" - - if _IS_HIP: + # AMD/HIP uses a separate compilation path; the NVIDIA path below is the + # unchanged upstream implementation. 
+ if is_hip_extension(): kernel = _compile_triton_hip( kernel_fn, - signature_with_constexpr, - constants, - num_warps, - num_stages, - num_ctas, - compute_capability, - enable_fp_fusion, - ) - else: - kernel = _compile_triton_cuda( - kernel_fn, - signature_with_constexpr, + signature, constants, num_warps, num_stages, @@ -324,21 +297,10 @@ def compile_triton( compute_capability, enable_fp_fusion, ) + _TRITON_KERNEL_CACHE[cache_key] = kernel + return kernel - _TRITON_KERNEL_CACHE[cache_key] = kernel - return kernel - - -def _compile_triton_cuda( - kernel_fn, - signature, - constants, - num_warps, - num_stages, - num_ctas, - compute_capability, - enable_fp_fusion, -): + # Compile kernel options = cb.CUDAOptions( num_warps=num_warps, num_stages=num_stages, @@ -347,36 +309,54 @@ def _compile_triton_cuda( debug=False, enable_fp_fusion=enable_fp_fusion, ) - src = tc.ASTSource(fn=kernel_fn, constexprs=constants, signature=signature) + + # Mark constants as constexpr in signature + signature_with_constexpr = dict(signature) + for const_name in constants.keys(): + if const_name in signature_with_constexpr: + signature_with_constexpr[const_name] = "constexpr" + + src = tc.ASTSource( + fn=kernel_fn, + constexprs=constants, + signature=signature_with_constexpr, + ) + compiled = tc.compile( src, target=tc.GPUTarget("cuda", compute_capability, 32), options=options.__dict__, ) + # Create kernel object for JAX + # From jax/jaxlib/gpu/triton_kernels.cc: from packaging import version if version.parse(jax.__version__) >= version.parse("0.8.2"): - return gpu_triton.TritonKernel( + kernel = gpu_triton.TritonKernel( + compiled.name, # arg0: kernel_name (str) + num_warps, # arg1: num_warps (int) + num_ctas, # arg2: num_ctas (int) + compiled.metadata.shared, # arg3: shared_mem_bytes (int) + compiled.asm["ptx"], # arg4: ptx (str) + "", # arg5: ttir (str) - empty + compute_capability, # arg6: compute_capability (int) + ) + else: + kernel = gpu_triton.TritonKernel( compiled.name, num_warps, 
- num_ctas, compiled.metadata.shared, compiled.asm["ptx"], - "", + "", # ttir compute_capability, + 1, + 1, + 1, ) - return gpu_triton.TritonKernel( - compiled.name, - num_warps, - compiled.metadata.shared, - compiled.asm["ptx"], - "", - compute_capability, - 1, - 1, - 1, - ) + + _TRITON_KERNEL_CACHE[cache_key] = kernel + return kernel # Track HSACO temp files for the lifetime of the process so the kernel paths @@ -394,8 +374,9 @@ def _compile_triton_hip( compute_capability, enable_fp_fusion, ): - # Strip target-feature suffix: "gfx950:sramecc+:xnack-" -> "gfx950". - arch = _ARCH_DETAILS.split(":", 1)[0] + # AMD/HIP returns an arch string like "gfx950:sramecc+:xnack-"; strip the + # target-feature suffix -> "gfx950". + arch = gpu_triton.get_arch_details(0).split(":", 1)[0] # Mirror what triton's parse_options would do per-arch: the default # HIPOptions.supported_fp8_dtypes is just ("fp8e5",), and constructing # HIPOptions directly bypasses the per-arch augmentation. Set it @@ -416,7 +397,18 @@ def _compile_triton_hip( arch=arch, supported_fp8_dtypes=fp8_dtypes, ) - src = tc.ASTSource(fn=kernel_fn, constexprs=constants, signature=signature) + + # Mark constants as constexpr in signature (mirrors the NVIDIA path). 
+ signature_with_constexpr = dict(signature) + for const_name in constants.keys(): + if const_name in signature_with_constexpr: + signature_with_constexpr[const_name] = "constexpr" + + src = tc.ASTSource( + fn=kernel_fn, + constexprs=constants, + signature=signature_with_constexpr, + ) compiled = tc.compile( src, target=_TritonGPUTarget("hip", arch, warp_size=64), From 40bd8cc48a1469ef82d4874eb96b41bbe97760d8 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 2 Jun 2026 17:56:07 +0000 Subject: [PATCH 13/17] Removed comment --- transformer_engine/jax/triton_extensions/utils.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 35a72738d..2cf01e19f 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -421,13 +421,6 @@ def _compile_triton_hip( f.write(compiled.asm["hsaco"]) _HSACO_TEMP_FILES.append(hsaco_path) - # The HIP TritonKernel constructor on this jax/jaxlib (0.8.0) takes - # `shared_mem_bytes` in slot 2 — not slot 5 as the public sample code - # suggests. The sample only works for kernels whose `shared` is 0 - # (e.g. simple element-wise kernels), because there the misplaced 0 in - # slot 2 coincidentally matches the expected layout. Kernels using - # tl.dot need real LDS allocation and silently produce garbage when - # `shared` lands in the wrong constructor slot. 
return gpu_triton.TritonKernel( compiled.name, num_warps, From e27f4d5151101a89656577043cdc264ae5f04aa7 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 2 Jun 2026 18:08:49 +0000 Subject: [PATCH 14/17] Minimize diff --- .../jax/triton_extensions/utils.py | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 2cf01e19f..8a75a1cc4 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -441,9 +441,9 @@ def triton_call_lowering( grid, input_output_aliases: Mapping[int, int] = None, constexprs: Mapping[str, Any] = None, - num_warps: int = 32, - num_stages: int = 1, - num_ctas: int = 1, + num_warps: int = None, + num_stages: int = None, + num_ctas: int = None, ): """Helper for MLIR lowering that calls a Triton kernel. @@ -518,9 +518,16 @@ def _normalize_grid(g): # For non-autotune fallback, evaluate with just the user constexprs. grid_tuple = _normalize_grid(grid_fn(constexprs or {})) - # Caller-supplied num_warps/num_stages/num_ctas (defaults match the - # historical hardcoded values: 32/1/1). + # Default values for the kernel actual_kernel_fn = kernel_fn + if num_warps is None: + num_warps = 32 + if num_stages is None: + num_stages = ( + 1 # TODO(Phuong): consider if it is beneficial to expose num_warps, num_stages, num_ctas + ) + if num_ctas is None: + num_ctas = 1 kernel_constexprs = constexprs if constexprs is not None else {} # Handle autotuned kernels - compile all configs @@ -541,11 +548,9 @@ def _normalize_grid(g): # Per-config grid: re-evaluate grid_fn with this config's merged # kwargs so configs that vary BLOCK_T/BLOCK_S launch at the right - # cdiv(T_t, BLOCK_T) etc. + # cdiv(T_t, BLOCK_T) etc. (grid_tuple is otherwise the fixed grid.) 
if grid_fn is not None: - config_grid = _normalize_grid(grid_fn(config_constexprs)) - else: - config_grid = grid_tuple + grid_tuple = _normalize_grid(grid_fn(config_constexprs)) # Compile this config config_kernel = compile_triton( @@ -566,9 +571,9 @@ def _normalize_grid(g): config_call = gpu_triton.TritonKernelCall( config_kernel, - config_grid[0], - config_grid[1], - config_grid[2], + grid_tuple[0], + grid_tuple[1], + grid_tuple[2], config_params, ) From 30faf3c1be05635741a54593607a3a45693da81e Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 2 Jun 2026 19:16:49 +0000 Subject: [PATCH 15/17] Refactored package structure for new sparse attention components --- tests/jax/test_indexer.py | 4 +- tests/jax/test_sparse_attention.py | 4 +- transformer_engine/jax/__init__.py | 6 +-- .../jax/sparse_attention/__init__.py | 43 +++++++++++++++++++ .../compressed_attention.py | 4 +- .../dsa.py} | 4 +- .../jax/{ => sparse_attention}/indexer.py | 0 .../jax/triton_extensions/utils.py | 2 +- 8 files changed, 53 insertions(+), 14 deletions(-) create mode 100644 transformer_engine/jax/sparse_attention/__init__.py rename transformer_engine/jax/{ => sparse_attention}/compressed_attention.py (96%) rename transformer_engine/jax/{sparse_attention.py => sparse_attention/dsa.py} (99%) rename transformer_engine/jax/{ => sparse_attention}/indexer.py (100%) diff --git a/tests/jax/test_indexer.py b/tests/jax/test_indexer.py index c04d92a2d..66476206e 100644 --- a/tests/jax/test_indexer.py +++ b/tests/jax/test_indexer.py @@ -4,7 +4,7 @@ """Correctness tests for the lightning-indexer JAX ops. Ported from the in-module ``__main__`` smoke tests of -``transformer_engine.jax.indexer``. The hybrid and top-k backends require +``transformer_engine.jax.sparse_attention.indexer``. The hybrid and top-k backends require rank-4 ``(B, oH, T, d)`` inputs, so every leading shape here is length-2. 
""" @@ -12,7 +12,7 @@ import jax.numpy as jnp import pytest -from transformer_engine.jax.indexer import indexer, indexer_topk +from transformer_engine.jax.sparse_attention.indexer import indexer, indexer_topk def _indexer_inputs(B, oH, T_t, T_s, d, d_c, H, d_i, seed): diff --git a/tests/jax/test_sparse_attention.py b/tests/jax/test_sparse_attention.py index d486f24c8..a911b6049 100644 --- a/tests/jax/test_sparse_attention.py +++ b/tests/jax/test_sparse_attention.py @@ -14,11 +14,11 @@ _causal_keep_mask, _topk_indices_to_attn_mask, ) -from transformer_engine.jax.compressed_attention import ( +from transformer_engine.jax.sparse_attention.compressed_attention import ( HeavilyCompressedAttention, heavily_compressed_attention, ) -from transformer_engine.jax.indexer import indexer +from transformer_engine.jax.sparse_attention.indexer import indexer from transformer_engine.jax.triton_extensions import fused_sparse_attention_triton diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index bd2bf44f2..9172a9867 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -34,10 +34,8 @@ from . import flax from . import quantize -# AMD lightning-indexer / sparse-attention staging modules. -from . import indexer +# AMD lightning-indexer / sparse-attention staging module. from . import sparse_attention -from . import compressed_attention from .quantize import autocast, fp8_autocast, update_collections from .quantize import NVTE_FP8_COLLECTION_NAME @@ -56,7 +54,5 @@ "MeshResource", "flax", "quantize", - "indexer", "sparse_attention", - "compressed_attention", ] diff --git a/transformer_engine/jax/sparse_attention/__init__.py b/transformer_engine/jax/sparse_attention/__init__.py new file mode 100644 index 000000000..e8e049284 --- /dev/null +++ b/transformer_engine/jax/sparse_attention/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. 
+# +# See LICENSE for license information. +"""Deep Sparse Attention (DSA) family. + +Bundles the lightning indexer and the attention modules built on top of it: + + * :mod:`~transformer_engine.jax.sparse_attention.indexer` — the lightning + indexer op (``indexer`` / ``indexer_topk``). + * :mod:`~transformer_engine.jax.sparse_attention.dsa` — Deep Sparse + Attention, which composes the indexer with dense attention. + * :mod:`~transformer_engine.jax.sparse_attention.compressed_attention` — + Heavily Compressed Attention (MLA-style scaffold, design deferred). + +The Triton kernel backends live in +:mod:`transformer_engine.jax.triton_extensions` alongside the other Triton +kernels. +""" + +from . import indexer +from . import dsa +from . import compressed_attention + +from .dsa import ( + DeepSparseAttention, + deep_sparse_attention_core, + _causal_keep_mask, + _topk_indices_to_attn_mask, +) +from .compressed_attention import ( + HeavilyCompressedAttention, + heavily_compressed_attention, +) + +__all__ = [ + "indexer", + "dsa", + "compressed_attention", + "DeepSparseAttention", + "deep_sparse_attention_core", + "HeavilyCompressedAttention", + "heavily_compressed_attention", +] diff --git a/transformer_engine/jax/compressed_attention.py b/transformer_engine/jax/sparse_attention/compressed_attention.py similarity index 96% rename from transformer_engine/jax/compressed_attention.py rename to transformer_engine/jax/sparse_attention/compressed_attention.py index 16de17457..3d7aee172 100644 --- a/transformer_engine/jax/compressed_attention.py +++ b/transformer_engine/jax/sparse_attention/compressed_attention.py @@ -19,7 +19,7 @@ apply RoPE to (Q_rope, K_rope) O = softmax(Q @ K^T / sqrt(d)) @ V -See ``transformer_engine.jax.sparse_attention`` for the sibling DSA module +See ``transformer_engine.jax.sparse_attention.dsa`` for the sibling DSA module that is implemented today. """ @@ -42,7 +42,7 @@ " 5. 
Interaction with TE's existing fused-attn backends — does any of " "CK/AITER/cuDNN support split (RoPE/no-RoPE) head dims natively?\n" "Pin these before filling in. See " - "transformer_engine.jax.sparse_attention for the working DSA module." + "transformer_engine.jax.sparse_attention.dsa for the working DSA module." ) diff --git a/transformer_engine/jax/sparse_attention.py b/transformer_engine/jax/sparse_attention/dsa.py similarity index 99% rename from transformer_engine/jax/sparse_attention.py rename to transformer_engine/jax/sparse_attention/dsa.py index 3a7df3901..e3ee8d741 100644 --- a/transformer_engine/jax/sparse_attention.py +++ b/transformer_engine/jax/sparse_attention/dsa.py @@ -48,7 +48,7 @@ from transformer_engine.jax.flax.module import DenseGeneral from transformer_engine.jax.flax.transformer import DotProductAttention -from transformer_engine.jax.indexer import indexer as _indexer_fn, _indexer_projections +from .indexer import indexer as _indexer_fn, _indexer_projections # Backends supported by deep_sparse_attention_core. @@ -376,7 +376,7 @@ def __call__( )(inputs_kv) # [B, oH, T_s, head_dim] # ---- indexer projections (shared across oH) ---- - # Shapes mirror transformer_engine.jax.indexer:31-48. + # Shapes mirror transformer_engine.jax.sparse_attention.indexer:31-48. 
W_dq = self.param( "indexer_W_dq", nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal"), diff --git a/transformer_engine/jax/indexer.py b/transformer_engine/jax/sparse_attention/indexer.py similarity index 100% rename from transformer_engine/jax/indexer.py rename to transformer_engine/jax/sparse_attention/indexer.py diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 8a75a1cc4..1c3baf2c4 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -518,7 +518,7 @@ def _normalize_grid(g): # For non-autotune fallback, evaluate with just the user constexprs. grid_tuple = _normalize_grid(grid_fn(constexprs or {})) - # Default values for the kernel + # Default values for the kernel (used unless the caller overrides them). actual_kernel_fn = kernel_fn if num_warps is None: num_warps = 32 From 03a7447992043ca31783e11f298115d4a30c1f74 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 2 Jun 2026 20:03:41 +0000 Subject: [PATCH 16/17] Updated benchmark scripts, added indexer class --- benchmarks/profile_indexer.py | 2 +- benchmarks/profile_indexer_topk.py | 2 +- tests/jax/test_indexer.py | 38 +++++++++- .../jax/sparse_attention/__init__.py | 2 + .../jax/sparse_attention/indexer.py | 74 ++++++++++++++++++- 5 files changed, 114 insertions(+), 4 deletions(-) diff --git a/benchmarks/profile_indexer.py b/benchmarks/profile_indexer.py index 481987931..43f824139 100644 --- a/benchmarks/profile_indexer.py +++ b/benchmarks/profile_indexer.py @@ -12,7 +12,7 @@ import jax import jax.numpy as jnp -from transformer_engine.jax.indexer import indexer +from transformer_engine.jax.sparse_attention.indexer import indexer try: from transformer_engine.jax.triton_extensions.indexer import score_reduce_triton # noqa: F401 diff --git a/benchmarks/profile_indexer_topk.py b/benchmarks/profile_indexer_topk.py index 70f59698d..68b9c86a5 100644 --- 
a/benchmarks/profile_indexer_topk.py +++ b/benchmarks/profile_indexer_topk.py @@ -14,7 +14,7 @@ import jax import jax.numpy as jnp -from transformer_engine.jax.indexer import indexer, indexer_topk +from transformer_engine.jax.sparse_attention.indexer import indexer, indexer_topk try: from transformer_engine.jax.triton_extensions.indexer import score_reduce_triton # noqa: F401 diff --git a/tests/jax/test_indexer.py b/tests/jax/test_indexer.py index 66476206e..0f3597777 100644 --- a/tests/jax/test_indexer.py +++ b/tests/jax/test_indexer.py @@ -12,7 +12,11 @@ import jax.numpy as jnp import pytest -from transformer_engine.jax.sparse_attention.indexer import indexer, indexer_topk +from transformer_engine.jax.sparse_attention.indexer import ( + LightningIndexer, + indexer, + indexer_topk, +) def _indexer_inputs(B, oH, T_t, T_s, d, d_c, H, d_i, seed): @@ -85,3 +89,35 @@ def inner(*a): grads_hyb = jax.grad(_loss("hybrid"), argnums=argnums)(*args) for gr, gh in zip(grads_ref, grads_hyb): assert _rel_err(gh, gr) < 5e-2 + + +def test_lightning_indexer_module_matches_functional(): + """``LightningIndexer`` (Flax module) reproduces the functional ``indexer`` + when fed the module's own initialized weights.""" + B, oH, T_t, T_s, d, d_c, H, d_i = 2, 3, 64, 64, 32, 32, 8, 32 + keys = jax.random.split(jax.random.PRNGKey(7), 3) + Q = jax.random.normal(keys[0], (B, oH, T_t, d), dtype=jnp.bfloat16) + K = jax.random.normal(keys[1], (B, oH, T_s, d), dtype=jnp.bfloat16) + + mod = LightningIndexer(num_heads=H, d_c=d_c, d_i=d_i, backend="reference") + variables = mod.init(keys[2], Q, K) + o_mod = mod.apply(variables, Q, K) + assert o_mod.shape == (B, oH, T_t, T_s) + + p = variables["params"] + o_fn = indexer(Q, K, p["W_uq"], p["W_dq"], p["W_k"], p["W_w"], backend="reference") + assert _rel_err(o_mod, o_fn) < 1e-5 + + +def test_lightning_indexer_topk_mode(): + """``LightningIndexer(topk=k)`` returns fused top-k indices of shape (..., T, k).""" + B, oH, T_t, T_s, d, d_c, H, d_i, k = 2, 3, 
64, 128, 32, 32, 16, 32, 32 + keys = jax.random.split(jax.random.PRNGKey(9), 2) + Q = jax.random.normal(keys[0], (B, oH, T_t, d), dtype=jnp.bfloat16) + K = jax.random.normal(keys[1], (B, oH, T_s, d), dtype=jnp.bfloat16) + + mod = LightningIndexer(num_heads=H, d_c=d_c, d_i=d_i, topk=k) + variables = mod.init(jax.random.PRNGKey(0), Q, K) + idx = mod.apply(variables, Q, K) + assert idx.shape == (B, oH, T_t, k) + assert idx.dtype == jnp.int32 diff --git a/transformer_engine/jax/sparse_attention/__init__.py b/transformer_engine/jax/sparse_attention/__init__.py index e8e049284..8a66d05d4 100644 --- a/transformer_engine/jax/sparse_attention/__init__.py +++ b/transformer_engine/jax/sparse_attention/__init__.py @@ -21,6 +21,7 @@ from . import dsa from . import compressed_attention +from .indexer import LightningIndexer from .dsa import ( DeepSparseAttention, deep_sparse_attention_core, @@ -36,6 +37,7 @@ "indexer", "dsa", "compressed_attention", + "LightningIndexer", "DeepSparseAttention", "deep_sparse_attention_core", "HeavilyCompressedAttention", diff --git a/transformer_engine/jax/sparse_attention/indexer.py b/transformer_engine/jax/sparse_attention/indexer.py index a01bb60b9..98e93f5fa 100644 --- a/transformer_engine/jax/sparse_attention/indexer.py +++ b/transformer_engine/jax/sparse_attention/indexer.py @@ -11,7 +11,9 @@ registers. Avoids the score-tensor HBM round-trip that dominates the reference path. -Top-level entry point: ``indexer(Q, K, W_uq, W_dq, W_k, W_w, *, backend=...)``. +Functional entry point: ``indexer(Q, K, W_uq, W_dq, W_k, W_w, *, backend=...)``. +User-facing Flax module: :class:`LightningIndexer`, which owns the projection +weights and delegates to ``indexer`` / ``indexer_topk``. 
Math (low-rank form: Q is hidden state; query heads are produced by a down-projection (d -> d_c) followed by an up-projection (d_c -> H * d_i); @@ -26,9 +28,11 @@ """ import functools +from typing import Optional import jax import jax.numpy as jnp +from flax import linen as nn def _indexer_projections(Q, K, W_uq, W_dq, W_k, W_w): @@ -126,3 +130,71 @@ def indexer(Q, K, W_uq, W_dq, W_k, weights, *, out_dtype=None, backend="referenc raise ValueError( f"unknown backend {backend!r}; expected 'reference' or 'hybrid'" ) + + +class LightningIndexer(nn.Module): # pylint: disable=too-few-public-methods + """Lightning-indexer Flax module — the user-facing indexer API. + + Owns the low-rank indexer projection weights (``W_dq``, ``W_uq``, ``W_k``, + ``W_w``) and delegates to the functional :func:`indexer` / :func:`indexer_topk` + ops. Weight shapes mirror :func:`indexer`'s ``Args`` and are inferred from the + trailing hidden dimension ``d`` of ``Q`` at call time. + + Parameters + ---------- + num_heads : int + Number of indexer-internal heads (``H``). + d_c : int + Down-projection rank (``d -> d_c``). + d_i : int + Inner head dimension (``d_i``). + topk : Optional[int], default ``None`` + If set, :meth:`__call__` returns the fused top-``k`` indices + (``(..., T, k)`` int32) via :func:`indexer_topk`, and ``backend`` / + ``out_dtype`` are ignored (top-k always uses the fused Triton kernel). + If ``None``, :meth:`__call__` returns the full score tensor + ``(..., T, S)``. + backend : str, default ``"reference"`` + ``"reference"`` (pure einsum) or ``"hybrid"`` (Triton score-relu-reduce). + Only used when ``topk is None``. + out_dtype : Optional[jnp.dtype] + Output dtype override; defaults to ``Q.dtype``. Unused when ``topk`` is set. + dtype : Optional[jnp.dtype] + Parameter dtype. Defaults to the input dtype. 
+ """ + + num_heads: int + d_c: int + d_i: int + topk: Optional[int] = None + backend: str = "reference" + out_dtype: Optional[jnp.dtype] = None + dtype: Optional[jnp.dtype] = None + + @nn.compact + def __call__(self, Q: jax.Array, K: jax.Array) -> jax.Array: + """Run the indexer on ``Q`` / ``K``. + + Args: + Q: ``(..., T, d)`` query-side hidden state. + K: ``(..., S, d)`` key-side hidden state. + + Returns: + ``(..., T, S)`` scores if ``topk is None``, else ``(..., T, k)`` + int32 top-k indices (in descending score order). + """ + d = Q.shape[-1] + param_dtype = self.dtype if self.dtype is not None else Q.dtype + init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") + + W_dq = self.param("W_dq", init, (d, self.d_c), param_dtype) + W_uq = self.param("W_uq", init, (self.num_heads, self.d_c, self.d_i), param_dtype) + W_k = self.param("W_k", init, (d, self.d_i), param_dtype) + W_w = self.param("W_w", init, (d, self.num_heads), param_dtype) + + if self.topk is not None: + return indexer_topk(Q, K, W_uq, W_dq, W_k, W_w, k=self.topk) + return indexer( + Q, K, W_uq, W_dq, W_k, W_w, + out_dtype=self.out_dtype, backend=self.backend, + ) From 223ca5a69109c7385995317bd090d187b1e48166 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 5 Jun 2026 19:37:41 +0000 Subject: [PATCH 17/17] Corrected import --- benchmarks/profile_indexer_bwd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/profile_indexer_bwd.py b/benchmarks/profile_indexer_bwd.py index 85363141f..4e2688078 100644 --- a/benchmarks/profile_indexer_bwd.py +++ b/benchmarks/profile_indexer_bwd.py @@ -18,7 +18,7 @@ import jax import jax.numpy as jnp -from transformer_engine.jax.indexer import indexer +from transformer_engine.jax.sparse_attention.indexer import indexer try: from transformer_engine.jax.triton_extensions.indexer import score_reduce_triton # noqa: F401