feat(ck-tile): TE to dispatcher GEMM bridge for fp8/bf8/int8 (all layouts)#8998
feat(ck-tile): TE to dispatcher GEMM bridge for fp8/bf8/int8 (all layouts)#8998ozturkosu wants to merge 16 commits into
Conversation
Consolidated, single-commit GEMM bridge routing the Tile Engine regular-GEMM sweep through the Dispatcher (codegen -> build -> runtime), so the Dispatcher is the single source of truth and the Tile Engine owns only the config search space and the benchmark loop. Mirrors the FMHA/Conv reference binding end to end. Scope: - Regular GEMM bridge: unified_gemm_codegen.py, gemm_ctypes_lib.cpp (flat extern "C" ABI, host-pointer model), gemm_utils.py (GemmKernelConfig with byte-exact .name, one-.so-per-kernel build), 3-phase TE driver + subprocess worker (gemm_full_benchmark.py / run_one_gemm_kernel.py). - Trait-derived registry KernelKey (replaces the hard-coded fp16/rcr key). - bf16 support and all four layouts (rcr/rrr/crr/ccr; row-major C only). - Tile Engine AMDGPU -mllvm codegen-flag parity + arch-validated tile filtering. - --verify fp32-reference correctness gate; multi-GPU fan-out. - Runnable example (examples/gemm/python/12_te_bridge.py) and parity/unit tests. - Removes the legacy standalone gemm_universal build path and the old test/ck_tile/gemm_tile_engine harness; promotes sweep configs to the op-root flat configs/ directory (fmha/grouped_conv convention). Validated on gfx942 / MI300X (fp16 + bf16, all four layouts) against an fp32 numpy reference via --verify.
The bridge dispatcher's tile-divisibility gate rejected any problem where
M % TileM != 0 for every layout, returning status -2 ("No suitable kernel")
at runtime even though the .so built fine. This wrongly excluded bf16 rcr/rrr
kernels with a non-power-of-two TileM (e.g. 192) on standard shapes like
1024^3 -- cases Old-TE compiles, runs, and verifies as correct.
Root cause: supports() was layout-blind, while the underlying
ck_tile::GemmKernel::IsSupportedArgument only constrains a dimension when an
operand whose inner axis is that dimension participates without padding:
RowMajor A -> K, ColMajor A -> M
RowMajor B -> N, ColMajor B -> K
RowMajor C -> N, ColMajor C -> M
So for rcr (RowMajor A & C) M is never gated, which is why Old-TE runs M=192
tiles on M-indivisible problems.
Make supports() compute require_m/n/k from the kernel key's A/B/C layouts so
it mirrors IsSupportedArgument exactly (also honoring k_batch in the K grain).
Anything it now lets through is still validated by the kernel's own
IsSupportedArgument inside launch(), so the bridge stays a strict functional
equivalent of Old-TE. Applied to both generated_tile_backend.hpp (the GEMM
.so path) and the sibling tile_backend.hpp.
Validated on gfx942 (MI300X): 85 previously status-2 rcr/rrr bf16 192-tile
.so now run at 1024^3 (Old-TE runs the same, verification correct); the 8
remaining rejects are tile N=192 cases that Old-TE also reports "Arguments
not supported" at N=1024 -- parity preserved in both directions.
…oding rcr dispatcher_initialize() in gemm_ctypes_lib.cpp hardcoded the KernelKey layout to rcr (RowMajor/ColMajor/RowMajor) for every kernel. Now that supports() is layout-aware, that wrong key layout makes the dispatcher reject valid problems: a crr kernel does not gate K (neither A=ColMajor nor B=RowMajor has K as its inner axis), but with a hardcoded rcr key supports() applies rcr's K-gate and returns status -2 for TileK=192 problems (e.g. crr 64x64x192 at 1024^3) that Old-TE compiles, runs, and verifies (~87 TFLOPS). Derive signature.layout_a/b/c from the force-included kernel's own ALayout/BLayout/CLayout types via std::is_same_v with tensor_layout::gemm::RowMajor. The key now matches the kernel, so the layout-aware gate is correct for all four layouts. Execution was already layout-correct (the kernel uses its own compile-time layouts); only the host-side selection metadata was wrong. Validated on gfx942 (MI300X): crr 64x64x192 now runs on the bridge (93 TFLOPS), restoring parity with Old-TE.
The >=20% bridge-vs-old-TE perf gaps in the parity sweep are a harness artifact: the sweep timed the bridge in-process but timed old-TE via its separate standalone benchmark binary, which runs the byte-identical kernel at a lower sustained SCLK. Measured through one harness the gap is <1%. ab_same_harness.py removed that artifact but hardcoded the old-TE header dir to fp16/rcr. Derive it per stem as <base>/<dtype>/<layout> so one run covers rcr/rrr/ccr/crr and fp16+bf16, add a --stems-file/--csv resume-aware sweep mode, and use the median (not max) per point.
For a full ~2000-stem sweep on a single GPU: batch all shapes into one worker call per side (5x fewer process startups), cache the compiled old-TE .so, and add a parallel --build-only pre-pass so hipcc compilation uses all CPU cores while GPU measurement stays serial.
… guard) The bridge-vs-old-TE A/B reported phantom regressions from two MEASUREMENT bugs, not real codegen gaps: - ab_same_harness.py built the old-TE side WITHOUT the TE codegen flags the bridge (and real old-TE's own CMake) use, so -enable-post-misched defaulted back on and old-TE ran ~10-40% faster -> the bridge looked regressed when it is at parity. Now both sides build with identical flags. - ab_efficient_sweep.py measured whatever libgemm_<stem>.so existed with no freshness check, so 3-day-old binaries built from an obsolete codegen showed up as -78%/+703% gaps. Added a guard: skip any .so older than its generated header (treated as missing) instead of reporting a phantom gap. With both fixes the 41 former >15% outlier stems measure within +/-10% (median +0.01%); no bridge codegen regression exists. Note: a separate, deliberately UNCOMMITTED perf change in gemm_utils.py (gate -enable-post-misched=0 on persistent) gives non-persistent large tiles ~9-40%; held back pending a broader persistent-kernel no-regression sweep.
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
The bridge compiles each kernel .so with a hand-maintained hipcc flag list
(dispatcher/python/gemm_utils.py) that had drifted from Tile Engine's CMake
flags, so the bridge .so and the TE benchmark were not compiled apples-to-apples:
* MISSING -mllvm -amdgpu-coerce-illegal-types=1 (TE's CMakeLists.txt adds it
when the compiler accepts it; the bridge build never did)
* EXTRA -mllvm -enable-noalias-to-md-conversion=0 (not a TE GEMM flag; it
only appears in standalone CK examples/tests, never the TE gemm path)
Align the bridge's backend codegen flags with the exact set the TE
gemm_universal benchmark TU is built with. The coerce flag is added through a
cached hipcc probe that mirrors TE's check_cxx_compiler_flag, so the bridge stays
matched to TE on every toolchain (present where TE has it, skipped where TE's
CMake would skip it too).
The generated kernel source was already identical between the two engines; this
makes their compilation identical as well.
…y_diag/regression
Old-TE must remain until the dispatcher bridge implements every datatype Old-TE
supports, so revert the Old-TE removal from the bridge commit and re-wire its build:
* restore test/ck_tile/gemm_tile_engine/* (10 files)
* restore tile_engine/ops/gemm/gemm_universal/* (6 files: benchmark / instance
builder / profiler / single-bench / CMakeLists)
* re-add `add_subdirectory(gemm_universal EXCLUDE_FROM_ALL)` in
tile_engine/ops/gemm/CMakeLists.txt; restore test/ck_tile/CMakeLists.txt to the
develop state (gemm_tile_engine entry kept commented, as in develop)
Also drop the parity_diag/regression dev scripts that should not ship in the PR:
* dispatcher/parity_diag/regression/ab_efficient_sweep.py
* dispatcher/parity_diag/regression/ab_same_harness.py
…whitespace - Add AMD copyright/SPDX header to gemm_full_benchmark.py and run_one_gemm_kernel.py (CK requires a header on every source file). - Remove a trailing-whitespace blank line in generated_tile_backend.hpp that would trip the whitespace/clang-format CI gate.
….g. 192)
The CShuffle epilogue stores the accumulator back through LDS in power-of-two
MRepeat/NRepeat chunks, where MRepeat = tile_m / (wave_m * warp_tile_m) (and
likewise N). A tile whose per-wave repeat is not a power of two (or whose tile
dim is not divisible by wave*warp_tile) is mis-stored and produces numerically
WRONG results at runtime -- yet it still passes the ctypes validator and the
epilogue's static_asserts, so it compiles and silently returns garbage.
Observed on MI350 for tile_m=192 (MRepeat = 192/(2*32) = 3) and tile_n=192
(e.g. 64x192x64_1x4x1, 192 not divisible by 4*32): both verified incorrect
(fp32 reference, max_rel ~1.2-1.4) on the bridge AND Tile Engine, at every
shape including shapes divisible by 192. Power-of-two tiles (64/128/256) are
unaffected; a control 256-tile verifies cleanly (max_rel ~4e-4).
Add a validity gate in both tile-expansion paths:
* unified_gemm_codegen.py::_get_tile_configs (codegen CLI path)
* gemm_utils.py::expand_sweep (bridge .so build path; this path only ran the
ctypes validate_kernel_config, which does not catch this)
so invalid tiles are dropped instead of emitted/run. tile_k is unaffected (the
K reduction has no CShuffle store constraint).
…(all layouts) Adds the remaining data types Tile Engine's plain GEMM has MFMA warp tiles for beyond the fp16/bf16 surface of PR #8479: fp8 (E4M3) and bf8 (E5M2) accumulating into fp16, and int8 accumulating into int32 (gfx942). Covers all four A/B layout combinations per dtype (row-major C only, as ck_tile rejects column-major C). Codegen (codegen_common.py, unified_gemm_codegen.py): - add int32 to the CK / qualified / dispatcher dtype maps - get_output_dtype: int8 -> int32 (fp8/bf8 -> fp16 unchanged) - new get_acc_dtype: int8 -> int32, else fp32 - derive AccDataType, CDataType, the GEMM_KEY_DTYPE_{C,ACC} macros and the registry dtype_c/dtype_acc from the dtype instead of hard-coding float/fp32 Host harness (gemm_utils.py): - fp8/bf8 FNUZ (gfx942) uint8 codecs: exact decode (matches device fp8_t/bf8_t), nearest-representable saturating encode, mirroring the existing bf16 helper - GpuGemmRunner.run encodes A/B and sizes the C buffer per dtype (fp16 for fp8/bf8, int32 for int8) - expand_sweep sets dtype_c/dtype_acc from the input dtype Tests: - test_gemm_utils.py: fp8/bf8 codec round-trip, format ranges, NaN/zero slots, saturation, byte size; output-dtype mapping (CPU-only) - test_gemm_parity.py: fp8/bf8/int8 cases with dtype-aware inputs, references and tolerances (int8 exact); GPU-gated like the existing fp16/bf16 cases GPU parity validation deferred to a follow-up run on an MI300X node.
✅ All Checks Passed — Ready for Review
📖 Need help? See the Policy FAQ for details on every check and how to fix failures. |
|
🎉 All checks passed! This PR is ready for review. |
…ffle epilogue unified_gemm_codegen forced CShuffleEpilogueProblem trailing template args (false, 1, 1, DoubleSmemBuffer) for the gemm_universal cshuffle epilogue, while Old-TE's gemm_universal_instance_builder stops at NumWaveGroups (letting the epilogue defaults apply). For RowMajor-A those forced values equal the defaults (parity), but for ColMajor-A and 4x1x1 block-maps they yield a higher-VGPR kernel (120/128 vs Old-TE 92/100) -> lower occupancy -> 30-75% slower. Drop the forced args so the bridge emits the same epilogue as Old-TE. On MI300X all 18 affected fp8 stems recover: 50/51 formerly >15% (M/N/K sweep) rows now within +/-15% (median 0.01%). multi_d variant left unchanged.
Fix: fp8 ColMajor-A /
|
Summary
Extends the Tile Engine ↔ Dispatcher GEMM bridge to the remaining data types TE's plain GEMM has MFMA warp tiles for, beyond the fp16/bf16 surface of #8479:
All four A/B layout combinations per dtype (row-major C only, matching #8479).
fp32/fp64are intentionally excluded — they appear in TE's dtype-string map but have no MFMA warp tiles inGEMM_WARP_TILE_SUPPORTED_COMBINATIONS, so no kernel can be generated/run.Depends on the fp16/bf16 bridge in #8997 (
users/muozturk/ck-tile/gemm-bridge-all-layout-bf16-fp16), which carries the bridge infrastructure and is not yet merged. This PR targetsdevelop, so until #8997 merges its diff also includes the base bridge changes; please merge #8997 first.Changes
codegen_common.py,unified_gemm_codegen.py): addint32to the dtype maps;get_output_dtypeint8→int32; newget_acc_dtype(int8→int32, else fp32); deriveAccDataType/CDataType, theGEMM_KEY_DTYPE_{C,ACC}macros, and the registrydtype_c/dtype_accfrom the dtype instead of hard-codingfloat/fp32.gemm_utils.py): fp8/bf8 FNUZ (gfx942) uint8 codecs — exact decode (matches devicefp8_t/bf8_t), nearest-representable saturating encode (same pattern as the existing bf16 helper);GpuGemmRunner.runencodes A/B and sizes the C buffer per dtype;expand_sweepsetsdtype_c/dtype_acc.test_gemm_utils.pyadds CPU-only fp8/bf8 codec + output-dtype tests (all green);test_gemm_parity.pyadds fp8/bf8/int8 cases with dtype-aware inputs/references/tolerances (int8 is bit-exact), GPU-gated like the existing cases.Verification done
test_gemm_utils.py+test_codegen_common.py: 54 passed (CPU).ADataType/CDataType/AccDataTypeandGEMM_KEY_*macros are correct (int8→int32_t acc/C; fp8→fp16_t C).test_gemm_parity.pycollects 60 cases and skips cleanly without a GPU.test_examples_integration/test_grouped_conv_codegen/test_library_cachingare pre-existing (verified identical on the base branch; they require a built dispatcher.a/ GPU).Test plan
develop.python3 tests/test_gemm_parity.pyand confirm fp8/bf8/int8 parity; tune the fp8/bf8 tolerances if needed (current values are first-cut headroom).