
feat(ck-tile): TE to dispatcher GEMM bridge for fp8/bf8/int8 (all layouts)#8998

Open
ozturkosu wants to merge 16 commits into develop from users/muozturk/ck-tile/gemm-bridge-all-layout-fp8-bf8-int8


@ozturkosu ozturkosu commented Jul 1, 2026

Contributor

Re-opened from #8887 with a policy-compliant branch name (users/muozturk/ck-tile/gemm-bridge-all-layout-fp8-bf8-int8). Supersedes #8887.

Summary

Extends the Tile Engine ↔ Dispatcher GEMM bridge to the remaining data types TE's plain GEMM has MFMA warp tiles for, beyond the fp16/bf16 surface of #8479:

  • fp8 (E4M3) and bf8 (E5M2) → fp16 output, fp32 accumulate
  • int8 → int32 output and accumulate (gfx942)

All four A/B layout combinations are covered for each dtype (row-major C only, matching #8479). fp32/fp64 are intentionally excluded: they appear in TE's dtype-string map but have no MFMA warp tiles in GEMM_WARP_TILE_SUPPORTED_COMBINATIONS, so no kernel can be generated or run.

Depends on the fp16/bf16 bridge in #8997 (users/muozturk/ck-tile/gemm-bridge-all-layout-bf16-fp16), which carries the bridge infrastructure and is not yet merged. This PR targets develop, so until #8997 merges, its diff also includes the base bridge changes; please merge #8997 first.

Changes

  • Codegen (codegen_common.py, unified_gemm_codegen.py): add int32 to the dtype maps; get_output_dtype int8→int32; new get_acc_dtype (int8→int32, else fp32); derive AccDataType/CDataType, the GEMM_KEY_DTYPE_{C,ACC} macros, and the registry dtype_c/dtype_acc from the dtype instead of hard-coding float/fp32.
  • Host harness (gemm_utils.py): fp8/bf8 FNUZ (gfx942) uint8 codecs — exact decode (matches device fp8_t/bf8_t), nearest-representable saturating encode (same pattern as the existing bf16 helper); GpuGemmRunner.run encodes A/B and sizes the C buffer per dtype; expand_sweep sets dtype_c/dtype_acc.
  • Tests: test_gemm_utils.py adds CPU-only fp8/bf8 codec + output-dtype tests (all green); test_gemm_parity.py adds fp8/bf8/int8 cases with dtype-aware inputs/references/tolerances (int8 is bit-exact), GPU-gated like the existing cases.
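The dtype derivation described in the Codegen bullet can be sketched roughly as follows. This is a minimal illustration, not the code in codegen_common.py: the function bodies and the pass-through fallback for fp16/bf16 are assumptions. Only the int8 → int32 and fp8/bf8 → fp16 output mappings and the "int8 → int32, else fp32" accumulator rule come from the description above.

```python
def get_output_dtype(dtype: str) -> str:
    """C (output) dtype for a given A/B input dtype."""
    if dtype == "int8":
        return "int32"
    if dtype in ("fp8", "bf8"):
        return "fp16"
    # Assumption: remaining dtypes (fp16, bf16) write C in their own type.
    return dtype


def get_acc_dtype(dtype: str) -> str:
    """Accumulator dtype: int32 for int8, fp32 for everything else."""
    return "int32" if dtype == "int8" else "fp32"


if __name__ == "__main__":
    for d in ("fp16", "bf16", "fp8", "bf8", "int8"):
        print(d, "-> C:", get_output_dtype(d), "acc:", get_acc_dtype(d))
```

Deriving both values from the input dtype in one place is what lets AccDataType, CDataType, and the registry key stay consistent instead of each hard-coding float/fp32.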

Verification done

  • test_gemm_utils.py + test_codegen_common.py: 54 passed (CPU).
  • Codegen smoke: fp8/int8/fp16 each generate 1 kernel + 1 wrapper, 0 failed; emitted ADataType/CDataType/AccDataType and GEMM_KEY_* macros are correct (int8→int32_t acc/C; fp8→fp16_t C).
  • test_gemm_parity.py collects 60 cases and skips cleanly without a GPU.
  • The 16 unrelated failures in test_examples_integration / test_grouped_conv_codegen / test_library_caching are pre-existing (verified identical on the base branch; they require a built dispatcher .a / GPU).

Test plan

  • Merge #8997 (feat(ck-tile): TE to dispatcher GEMM bridge (fp16/bf16, all layouts)) first; this PR then reduces to just the fp8/bf8/int8 delta on develop.
  • On an MI300X (gfx942) node: run python3 tests/test_gemm_parity.py and confirm fp8/bf8/int8 parity; tune the fp8/bf8 tolerances if needed (current values are first-cut headroom).
  • FNUZ vs OCP: the fp8/bf8 host codec targets the gfx942 FNUZ format; validate / extend for gfx950 (OCP) before enabling there.
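For readers unfamiliar with the FNUZ byte layout, here is a minimal pure-Python sketch of an exact decode and a nearest-representable saturating encode. It is not the gemm_utils.py codec: the function names, the brute-force search over all 256 codes, and the tie-breaking (lowest code wins) are assumptions made for illustration. The format parameters (E4M3 with bias 8, E5M2 with bias 16, a single NaN at 0x80, no infinities, no negative zero) are the standard FNUZ definitions.

```python
import math

# (exponent bits, mantissa bits, exponent bias) for the FNUZ variants.
FNUZ_FORMATS = {"fp8": (4, 3, 8), "bf8": (5, 2, 16)}


def decode_fnuz(byte: int, fmt: str) -> float:
    """Exactly decode one FNUZ byte to a Python float."""
    exp_bits, man_bits, bias = FNUZ_FORMATS[fmt]
    if byte == 0x80:  # the only NaN code; FNUZ has no -0 and no inf
        return math.nan
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
    frac = (byte & ((1 << man_bits) - 1)) / (1 << man_bits)
    if exp == 0:  # subnormal: no implicit leading 1
        return sign * frac * 2.0 ** (1 - bias)
    return sign * (1.0 + frac) * 2.0 ** (exp - bias)


def encode_fnuz(value: float, fmt: str) -> int:
    """Nearest-representable encode that saturates at the format maximum."""
    if math.isnan(value):
        return 0x80
    max_val = decode_fnuz(0x7F, fmt)
    value = max(-max_val, min(max_val, value))  # saturate, including +/-inf
    best_byte, best_err = 0, math.inf
    for byte in range(256):
        if byte == 0x80:
            continue
        err = abs(decode_fnuz(byte, fmt) - value)
        if err < best_err:  # ties keep the lowest code (an arbitrary choice)
            best_byte, best_err = byte, err
    return best_byte


if __name__ == "__main__":
    print(decode_fnuz(0x7F, "fp8"), decode_fnuz(0x7F, "bf8"))
    print(hex(encode_fnuz(1.0, "fp8")), hex(encode_fnuz(1e9, "fp8")))
```

An OCP codec would differ in bias and in the NaN/inf encodings, which is why the host codec needs validating before it is reused on gfx950.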

ozturkosu and others added 15 commits June 16, 2026 00:25
Consolidated, single-commit GEMM bridge routing the Tile Engine regular-GEMM
sweep through the Dispatcher (codegen -> build -> runtime), so the Dispatcher is
the single source of truth and the Tile Engine owns only the config search space
and the benchmark loop. Mirrors the FMHA/Conv reference binding end to end.

Scope:
- Regular GEMM bridge: unified_gemm_codegen.py, gemm_ctypes_lib.cpp (flat
  extern "C" ABI, host-pointer model), gemm_utils.py (GemmKernelConfig with
  byte-exact .name, one-.so-per-kernel build), 3-phase TE driver + subprocess
  worker (gemm_full_benchmark.py / run_one_gemm_kernel.py).
- Trait-derived registry KernelKey (replaces the hard-coded fp16/rcr key).
- bf16 support and all four layouts (rcr/rrr/crr/ccr; row-major C only).
- Tile Engine AMDGPU -mllvm codegen-flag parity + arch-validated tile filtering.
- --verify fp32-reference correctness gate; multi-GPU fan-out.
- Runnable example (examples/gemm/python/12_te_bridge.py) and parity/unit tests.
- Removes the legacy standalone gemm_universal build path and the old
  test/ck_tile/gemm_tile_engine harness; promotes sweep configs to the op-root
  flat configs/ directory (fmha/grouped_conv convention).

Validated on gfx942 / MI300X (fp16 + bf16, all four layouts) against an fp32
numpy reference via --verify.

The bridge dispatcher's tile-divisibility gate rejected any problem where
M % TileM != 0 for every layout, returning status -2 ("No suitable kernel")
at runtime even though the .so built fine. This wrongly excluded bf16 rcr/rrr
kernels with a non-power-of-two TileM (e.g. 192) on standard shapes like
1024^3 -- cases Old-TE compiles, runs, and verifies as correct.

Root cause: supports() was layout-blind, while the underlying
ck_tile::GemmKernel::IsSupportedArgument only constrains a dimension when an
operand whose inner axis is that dimension participates without padding:

  RowMajor A -> K, ColMajor A -> M
  RowMajor B -> N, ColMajor B -> K
  RowMajor C -> N, ColMajor C -> M

So for rcr (RowMajor A & C) M is never gated, which is why Old-TE runs M=192
tiles on M-indivisible problems.

Make supports() compute require_m/n/k from the kernel key's A/B/C layouts so
it mirrors IsSupportedArgument exactly (also honoring k_batch in the K grain).
Anything it now lets through is still validated by the kernel's own
IsSupportedArgument inside launch(), so the bridge stays a strict functional
equivalent of Old-TE. Applied to both generated_tile_backend.hpp (the GEMM
.so path) and the sibling tile_backend.hpp.
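The layout-aware gate can be sketched in Python as below. This is an illustration of the inner-axis table above, not the C++ in generated_tile_backend.hpp: the function names, the single-letter layout encoding ("r"/"c"), the assumption that no operand is padded, and the choice of tile_k * k_batch as the K grain are all hypothetical.

```python
# Inner (contiguous) axis of each operand, per the table in the commit message.
INNER_AXIS = {
    ("a", "r"): "K", ("a", "c"): "M",
    ("b", "r"): "N", ("b", "c"): "K",
    ("c", "r"): "N", ("c", "c"): "M",
}


def required_dims(layout: str) -> set:
    """Dimensions that must be tile-divisible for a layout such as 'rcr'."""
    return {INNER_AXIS[(op, lay)] for op, lay in zip("abc", layout)}


def supports(layout: str, problem: tuple, tile: tuple, k_batch: int = 1) -> bool:
    """problem and tile are (M, N, K) tuples; padding is assumed disabled."""
    m, n, k = problem
    tile_m, tile_n, tile_k = tile
    need = required_dims(layout)
    if "M" in need and m % tile_m != 0:
        return False
    if "N" in need and n % tile_n != 0:
        return False
    # Assumption: with split-K the K grain is tile_k * k_batch.
    if "K" in need and k % (tile_k * k_batch) != 0:
        return False
    return True


if __name__ == "__main__":
    cube = (1024, 1024, 1024)
    print(supports("rcr", cube, (192, 128, 64)))  # M is not gated for rcr
    print(supports("rcr", cube, (128, 192, 64)))  # N is gated for rcr
```

Under these assumptions rcr only gates N and K, which matches the observation that TileM=192 runs at 1024^3 while TileN=192 is still rejected.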

Validated on gfx942 (MI300X): 85 previously status-2 rcr/rrr bf16 192-tile
.so now run at 1024^3 (Old-TE runs the same, verification correct); the 8
remaining rejects are tile N=192 cases that Old-TE also reports "Arguments
not supported" at N=1024 -- parity preserved in both directions.
…oding rcr

dispatcher_initialize() in gemm_ctypes_lib.cpp hardcoded the KernelKey layout to
rcr (RowMajor/ColMajor/RowMajor) for every kernel. Now that supports() is
layout-aware, that wrong key layout makes the dispatcher reject valid problems:
a crr kernel does not gate K (neither A=ColMajor nor B=RowMajor has K as its
inner axis), but with a hardcoded rcr key supports() applies rcr's K-gate and
returns status -2 for TileK=192 problems (e.g. crr 64x64x192 at 1024^3) that
Old-TE compiles, runs, and verifies (~87 TFLOPS).

Derive signature.layout_a/b/c from the force-included kernel's own
ALayout/BLayout/CLayout types via std::is_same_v with tensor_layout::gemm::RowMajor.
The key now matches the kernel, so the layout-aware gate is correct for all four
layouts. Execution was already layout-correct (the kernel uses its own compile-time
layouts); only the host-side selection metadata was wrong.

Validated on gfx942 (MI300X): crr 64x64x192 now runs on the bridge (93 TFLOPS),
restoring parity with Old-TE.

The >=20% bridge-vs-old-TE perf gaps in the parity sweep are a harness
artifact: the sweep timed the bridge in-process but timed old-TE via its
separate standalone benchmark binary, which runs the byte-identical kernel
at a lower sustained SCLK. Measured through one harness the gap is <1%.

ab_same_harness.py removed that artifact but hardcoded the old-TE header dir
to fp16/rcr. Derive it per stem as <base>/<dtype>/<layout> so one run covers
rcr/rrr/ccr/crr and fp16+bf16, add a --stems-file/--csv resume-aware sweep
mode, and use the median (not max) per point.

For a full ~2000-stem sweep on a single GPU: batch all shapes into one worker
call per side (5x fewer process startups), cache the compiled old-TE .so, and
add a parallel --build-only pre-pass so hipcc compilation uses all CPU cores
while GPU measurement stays serial.
… guard)

The bridge-vs-old-TE A/B reported phantom regressions from two MEASUREMENT
bugs, not real codegen gaps:

- ab_same_harness.py built the old-TE side WITHOUT the TE codegen flags the
  bridge (and real old-TE's own CMake) use, so -enable-post-misched defaulted
  back on and old-TE ran ~10-40% faster -> the bridge looked regressed when it
  is at parity. Now both sides build with identical flags.

- ab_efficient_sweep.py measured whatever libgemm_<stem>.so existed with no
  freshness check, so 3-day-old binaries built from an obsolete codegen showed
  up as -78%/+703% gaps. Added a guard: skip any .so older than its generated
  header (treated as missing) instead of reporting a phantom gap.

With both fixes the 41 former >15% outlier stems measure within +/-10%
(median +0.01%); no bridge codegen regression exists.
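The freshness guard described in the second bullet amounts to an mtime comparison. A minimal sketch follows, assuming a hypothetical helper name and assuming that a missing header is also treated as "not fresh"; the actual check in ab_efficient_sweep.py may differ.

```python
import os


def is_fresh(so_path: str, header_path: str) -> bool:
    """True only if the .so exists and is at least as new as its generated header.

    A stale or missing .so is reported as not fresh so the caller can skip it
    (or rebuild) instead of measuring an obsolete binary.
    """
    if not (os.path.exists(so_path) and os.path.exists(header_path)):
        return False
    return os.path.getmtime(so_path) >= os.path.getmtime(header_path)
```

A sweep loop would call this before timing each stem and record the stem as missing when it returns False.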

Note: a separate, deliberately UNCOMMITTED perf change in gemm_utils.py (gate
-enable-post-misched=0 on persistent) gives non-persistent large tiles a ~9-40%
speedup; it is held back pending a broader persistent-kernel no-regression sweep.
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
The bridge compiles each kernel .so with a hand-maintained hipcc flag list
(dispatcher/python/gemm_utils.py) that had drifted from Tile Engine's CMake
flags, so the bridge .so and the TE benchmark were not compiled apples-to-apples:

  * MISSING  -mllvm -amdgpu-coerce-illegal-types=1  (TE's CMakeLists.txt adds it
             when the compiler accepts it; the bridge build never did)
  * EXTRA    -mllvm -enable-noalias-to-md-conversion=0  (not a TE GEMM flag; it
             only appears in standalone CK examples/tests, never the TE gemm path)

Align the bridge's backend codegen flags with the exact set the TE
gemm_universal benchmark TU is built with. The coerce flag is added through a
cached hipcc probe that mirrors TE's check_cxx_compiler_flag, so the bridge stays
matched to TE on every toolchain (present where TE has it, skipped where TE's
CMake would skip it too).
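A cached compiler-flag probe in the spirit of check_cxx_compiler_flag could look like the sketch below. The function name, the trivial probe source, and the exact compile command are assumptions; the real probe in gemm_utils.py is not shown in this PR description.

```python
import functools
import os
import subprocess
import tempfile


@functools.lru_cache(maxsize=None)
def compiler_accepts_flags(compiler: str, flags: tuple) -> bool:
    """Return True if `compiler` compiles a trivial TU with `flags` added.

    Results are cached per (compiler, flags), so the probe runs once per process.
    """
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "probe.cpp")
        with open(src, "w") as f:
            f.write("int main() { return 0; }\n")
        cmd = [compiler, *flags, "-c", src, "-o", os.path.join(tmp, "probe.o")]
        try:
            result = subprocess.run(cmd, capture_output=True, timeout=120)
        except (OSError, subprocess.TimeoutExpired):
            return False  # compiler missing or hung: treat the flag as unsupported
        return result.returncode == 0


if __name__ == "__main__":
    coerce = ("-mllvm", "-amdgpu-coerce-illegal-types=1")
    print(compiler_accepts_flags("hipcc", coerce))
```

The flag list would then append the coerce flag only when the probe returns True, mirroring the conditional in TE's CMakeLists.txt.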

The generated kernel source was already identical between the two engines; this
makes their compilation identical as well.
…y_diag/regression

Old-TE must remain until the dispatcher bridge implements every datatype Old-TE
supports, so revert the Old-TE removal from the bridge commit and re-wire its build:

  * restore test/ck_tile/gemm_tile_engine/* (10 files)
  * restore tile_engine/ops/gemm/gemm_universal/* (6 files: benchmark / instance
    builder / profiler / single-bench / CMakeLists)
  * re-add `add_subdirectory(gemm_universal EXCLUDE_FROM_ALL)` in
    tile_engine/ops/gemm/CMakeLists.txt; restore test/ck_tile/CMakeLists.txt to the
    develop state (gemm_tile_engine entry kept commented, as in develop)

Also drop the parity_diag/regression dev scripts that should not ship in the PR:
  * dispatcher/parity_diag/regression/ab_efficient_sweep.py
  * dispatcher/parity_diag/regression/ab_same_harness.py
…whitespace

- Add AMD copyright/SPDX header to gemm_full_benchmark.py and
  run_one_gemm_kernel.py (CK requires a header on every source file).
- Remove a trailing-whitespace blank line in generated_tile_backend.hpp
  that would trip the whitespace/clang-format CI gate.
….g. 192)

The CShuffle epilogue stores the accumulator back through LDS in power-of-two
MRepeat/NRepeat chunks, where MRepeat = tile_m / (wave_m * warp_tile_m) (and
likewise N). A tile whose per-wave repeat is not a power of two (or whose tile
dim is not divisible by wave*warp_tile) is mis-stored and produces numerically
WRONG results at runtime -- yet it still passes the ctypes validator and the
epilogue's static_asserts, so it compiles and silently returns garbage.

Observed on MI350 for tile_m=192 (MRepeat = 192/(2*32) = 3) and tile_n=192
(e.g. 64x192x64_1x4x1, 192 not divisible by 4*32): both verified incorrect
(fp32 reference, max_rel ~1.2-1.4) on the bridge AND Tile Engine, at every
shape including shapes divisible by 192. Power-of-two tiles (64/128/256) are
unaffected; a control 256-tile verifies cleanly (max_rel ~4e-4).

Add a validity gate in both tile-expansion paths:
  * unified_gemm_codegen.py::_get_tile_configs (codegen CLI path)
  * gemm_utils.py::expand_sweep (bridge .so build path; this path only ran the
    ctypes validate_kernel_config, which does not catch this)
so invalid tiles are dropped instead of emitted/run. tile_k is unaffected (the
K reduction has no CShuffle store constraint).
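The validity rule above reduces to two checks per dimension: the tile must divide evenly by wave * warp_tile, and the resulting repeat count must be a power of two. A minimal sketch, with hypothetical function names (the actual gate lives in _get_tile_configs and expand_sweep):

```python
def is_power_of_two(x: int) -> bool:
    return x > 0 and (x & (x - 1)) == 0


def repeat_is_valid(tile: int, waves: int, warp_tile: int) -> bool:
    """Check one dimension: repeat = tile / (waves * warp_tile)."""
    span = waves * warp_tile
    if tile % span != 0:
        return False  # tile not divisible by wave * warp_tile
    return is_power_of_two(tile // span)


def tile_is_valid(tile_m, tile_n, wave_m, wave_n, warp_tile_m, warp_tile_n) -> bool:
    """tile_k is deliberately not checked: the K reduction has no CShuffle store."""
    return (repeat_is_valid(tile_m, wave_m, warp_tile_m)
            and repeat_is_valid(tile_n, wave_n, warp_tile_n))


if __name__ == "__main__":
    print(repeat_is_valid(192, 2, 32))  # MRepeat = 3, not a power of two
    print(repeat_is_valid(256, 2, 32))  # MRepeat = 4
```

Both failing cases reported on MI350 are rejected by this check: tile_m=192 with wave_m=2 gives a repeat of 3, and tile_n=192 with wave_n=4 is not divisible by 128.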
…(all layouts)

Adds the remaining data types Tile Engine's plain GEMM has MFMA warp tiles for
beyond the fp16/bf16 surface of PR #8479: fp8 (E4M3) and bf8 (E5M2) with fp16
output and fp32 accumulation, and int8 with int32 output and accumulation
(gfx942). Covers all four A/B layout combinations per dtype (row-major C only,
as ck_tile rejects column-major C).

Codegen (codegen_common.py, unified_gemm_codegen.py):
- add int32 to the CK / qualified / dispatcher dtype maps
- get_output_dtype: int8 -> int32 (fp8/bf8 -> fp16 unchanged)
- new get_acc_dtype: int8 -> int32, else fp32
- derive AccDataType, CDataType, the GEMM_KEY_DTYPE_{C,ACC} macros and the
  registry dtype_c/dtype_acc from the dtype instead of hard-coding float/fp32

Host harness (gemm_utils.py):
- fp8/bf8 FNUZ (gfx942) uint8 codecs: exact decode (matches device fp8_t/bf8_t),
  nearest-representable saturating encode, mirroring the existing bf16 helper
- GpuGemmRunner.run encodes A/B and sizes the C buffer per dtype (fp16 for
  fp8/bf8, int32 for int8)
- expand_sweep sets dtype_c/dtype_acc from the input dtype

Tests:
- test_gemm_utils.py: fp8/bf8 codec round-trip, format ranges, NaN/zero slots,
  saturation, byte size; output-dtype mapping (CPU-only)
- test_gemm_parity.py: fp8/bf8/int8 cases with dtype-aware inputs, references and
  tolerances (int8 exact); GPU-gated like the existing fp16/bf16 cases

GPU parity validation deferred to a follow-up run on an MI300X node.

therock-pr-bot Bot commented Jul 1, 2026


✅ All Checks Passed — Ready for Review

| Check | Status | Details |
| --- | --- | --- |
| 🌿 Branch Name | ✅ Pass | |
| 📝 PR Title/Description | ✅ Pass | |
| Forbidden Files | ✅ Pass | |
| 🧪 Unit Test | ✅ Pass | |
| 🔎 pre-commit | ✅ Pass | |
| 🚫 Draft PR | 🔜 To Be Enabled | |
| 🚩 Feature Flag | 🔜 To Be Enabled | |
| 📊 Code Coverage | 🔜 To Be Enabled | |
| 🤖 therock-pr-bot | ✅ Pass | |

🎉 All checks passed! This PR is ready for review.

📖 Need help? See the Policy FAQ for details on every check and how to fix failures.


@ozturkosu ozturkosu requested a review from yraparti July 1, 2026 05:22
@ozturkosu ozturkosu changed the title from "feat(ck_tile): gemm bridge TE<->Dispatcher for fp8/bf8/int8 (all layouts)" to "feat(ck-tile): TE to dispatcher GEMM bridge for fp8/bf8/int8 (all layouts)" Jul 1, 2026
@ozturkosu ozturkosu self-assigned this Jul 1, 2026
…ffle epilogue

unified_gemm_codegen forced CShuffleEpilogueProblem trailing template args
(false, 1, 1, DoubleSmemBuffer) for the gemm_universal cshuffle epilogue, while
Old-TE's gemm_universal_instance_builder stops at NumWaveGroups (letting the
epilogue defaults apply). For RowMajor-A those forced values equal the defaults
(parity), but for ColMajor-A and 4x1x1 block-maps they yield a higher-VGPR
kernel (120/128 vs Old-TE 92/100) -> lower occupancy -> 30-75% slower.

Drop the forced args so the bridge emits the same epilogue as Old-TE. On MI300X
all 18 affected fp8 stems recover: 50/51 formerly >15% (M/N/K sweep) rows now
within +/-15% (median 0.01%). multi_d variant left unchanged.
@ozturkosu

Contributor Author

Fix: fp8 ColMajor-A / 4x1x1 perf regression — match Old-TE cshuffle epilogue (commit e3bde1b)

Symptom

On MI300X (gfx942), an apples-to-apples bridge-vs-Old-TE sweep (default_config, matched warmup=50/repeat=100/flush/rotating=1000 both sides, Old-TE built via develop CMake, flag-audit PASS — identical -mllvm codegen flags) showed a real, standalone-reproducible perf gap on a subset of fp8 configs:

  • fp8 RowMajor-A (rcr): at parity (median −0.3%).
  • fp8 ColMajor-A (ccr, crr) and the 4x1x1 block-map (even on RowMajor): −30% to −75% (bridge slower), persisted under standalone median-of-3.

Root cause

dispatcher/codegen/unified_gemm_codegen.py::_epilogue_code emitted the gemm_universal cshuffle epilogue with four forced trailing template args:

CShuffleEpilogueProblem<..., TransposeC, NumWaveGroups, false, 1, 1, DoubleSmemBuffer>

whereas Old-TE's gemm_universal_instance_builder stops at NumWaveGroups, letting the epilogue's computed defaults apply. For RowMajor-A those forced values happen to equal the defaults (→ parity); for ColMajor-A / 4x1x1 they don't, producing a heavier kernel. rocprof (same kernel symbol, same LDS/SGPR) confirmed the mechanism: bridge VGPR 120/128 vs Old-TE 92/100 → lower occupancy → the slowdown.

Fix

Drop the four forced args so the bridge emits the identical epilogue as Old-TE:

CShuffleEpilogueProblem<..., TransposeC, NumWaveGroups>

(1-line change; the multi_d variant is intentionally left unchanged.)

Result (rerun of all 51 formerly >15% rows, MI300X, same fair harness)

  • 50/51 now within ±15%; median gap 0.01%, mean −0.18%.
  • Examples (orig → fixed): ccr 128x128x256_4x1x1 @1024³ −71.7% → −1.1%; crr 128x128x128_1x4x1 @4096³ −43.3% → +1.1%; rrr 128x128x128_4x1x1 @2048³ −37.7% → −0.2%.
  • One residual: ccr 256x128x128_1x4x1 @1024³ improved −34.5% → −19.8% (its 2048³/4096³ shapes are now at parity) — likely a small-shape/wave-quant effect, still investigating.

int8 remains bridge-only (Old-TE builder rejects int8). A gfx950/MI350 parity run is in progress separately.
