From f8506a4a6f0dafc5eafc31131914f02def309bd6 Mon Sep 17 00:00:00 2001 From: ozturkosu Date: Tue, 16 Jun 2026 00:25:33 -0400 Subject: [PATCH 01/16] [CK_TILE] TE -> Dispatcher GEMM bridge: all layouts + fp16/bf16 Consolidated, single-commit GEMM bridge routing the Tile Engine regular-GEMM sweep through the Dispatcher (codegen -> build -> runtime), so the Dispatcher is the single source of truth and the Tile Engine owns only the config search space and the benchmark loop. Mirrors the FMHA/Conv reference binding end to end. Scope: - Regular GEMM bridge: unified_gemm_codegen.py, gemm_ctypes_lib.cpp (flat extern "C" ABI, host-pointer model), gemm_utils.py (GemmKernelConfig with byte-exact .name, one-.so-per-kernel build), 3-phase TE driver + subprocess worker (gemm_full_benchmark.py / run_one_gemm_kernel.py). - Trait-derived registry KernelKey (replaces the hard-coded fp16/rcr key). - bf16 support and all four layouts (rcr/rrr/crr/ccr; row-major C only). - Tile Engine AMDGPU -mllvm codegen-flag parity + arch-validated tile filtering. - --verify fp32-reference correctness gate; multi-GPU fan-out. - Runnable example (examples/gemm/python/12_te_bridge.py) and parity/unit tests. - Removes the legacy standalone gemm_universal build path and the old test/ck_tile/gemm_tile_engine harness; promotes sweep configs to the op-root flat configs/ directory (fmha/grouped_conv convention). Validated on gfx942 / MI300X (fp16 + bf16, all four layouts) against an fp32 numpy reference via --verify. --- .../bindings/ctypes/gemm_ctypes_lib.cpp | 83 +- .../codegen/unified_gemm_codegen.py | 81 +- .../examples/gemm/python/12_te_bridge.py | 279 ++++++ .../backends/generated_tile_backend.hpp | 51 +- .../include/ck_tile/dispatcher/kernel_key.hpp | 5 + .../parity_diag/regression/ab_same_harness.py | 128 +++ .../dispatcher/python/ctypes_utils.py | 2 +- .../dispatcher/python/gemm_utils.py | 825 ++++++++++++++++++ .../dispatcher/tests/test_gemm_parity.py | 265 ++++++ .../dispatcher/tests/test_gemm_utils.py | 132 +++ .../test/ck_tile/CMakeLists.txt | 3 - .../ck_tile/gemm_tile_engine/CMakeLists.txt | 348 -------- .../test/ck_tile/gemm_tile_engine/README.md | 85 -- .../comprehensive_coverage_config.json | 37 - .../configs/large_datatype_config.json | 34 - .../configs/padding_coverage_config.json | 34 - .../configs/quick_coverage_config.json | 34 - .../configs/simple_test_config.json | 34 - .../configs/small_datatype_config.json | 35 - .../gemm_tile_engine/extract_test_params.py | 74 -- .../gemm_tile_engine/test_gemm_simple.cpp | 241 ----- .../tile_engine/ops/gemm/CMakeLists.txt | 5 +- .../tile_engine/ops/gemm/README.md | 167 +++- .../configs/default_ci_config.json | 6 +- .../configs/default_config.json | 0 .../ops/gemm/configs/example_problems.json | 9 + .../configs/user_provided_config.json | 0 .../ops/gemm/gemm_full_benchmark.py | 504 +++++++++++ .../ops/gemm/gemm_universal/CMakeLists.txt | 338 ------- .../gemm_universal_benchmark.hpp | 73 -- .../gemm_universal_benchmark.py | 149 ---- .../gemm_universal_benchmark_single.cpp | 102 --- .../gemm_universal_instance_builder.py | 344 -------- .../gemm_universal_profiler.hpp | 147 ---- .../ops/gemm/run_one_gemm_kernel.py | 140 +++ 35 files changed, 2642 insertions(+), 2152 deletions(-) create mode 100644 projects/composablekernel/dispatcher/examples/gemm/python/12_te_bridge.py create mode 100644 projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py create mode 100644 projects/composablekernel/dispatcher/python/gemm_utils.py create mode 100644 projects/composablekernel/dispatcher/tests/test_gemm_parity.py create mode 100644 projects/composablekernel/dispatcher/tests/test_gemm_utils.py delete mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/CMakeLists.txt delete mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/README.md delete mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/comprehensive_coverage_config.json delete mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/large_datatype_config.json delete mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/padding_coverage_config.json delete mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/quick_coverage_config.json delete mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json delete mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/small_datatype_config.json delete mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/extract_test_params.py delete mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp rename projects/composablekernel/tile_engine/ops/gemm/{gemm_universal => }/configs/default_ci_config.json (98%) rename projects/composablekernel/tile_engine/ops/gemm/{gemm_universal => }/configs/default_config.json (100%) create mode 100644 projects/composablekernel/tile_engine/ops/gemm/configs/example_problems.json rename projects/composablekernel/tile_engine/ops/gemm/{gemm_universal => }/configs/user_provided_config.json (100%) create mode 100644 projects/composablekernel/tile_engine/ops/gemm/gemm_full_benchmark.py delete mode 100644 projects/composablekernel/tile_engine/ops/gemm/gemm_universal/CMakeLists.txt delete mode 100644 projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.hpp delete mode 100755 projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py delete mode 100644 projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark_single.cpp delete mode 100644 projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_instance_builder.py delete mode 100644 projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_profiler.hpp create mode 100644 projects/composablekernel/tile_engine/ops/gemm/run_one_gemm_kernel.py diff --git a/projects/composablekernel/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp b/projects/composablekernel/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp index 85c0c2f2c13a..57b98a8df135 100644 --- a/projects/composablekernel/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp +++ b/projects/composablekernel/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp @@ -65,8 +65,56 @@ int dispatcher_initialize() return 0; // Already initialized } - // Create kernel key from the force-included kernel header + // Create kernel key from the force-included kernel header. + // + // The GEMM_KEY_* macros are emitted by the codegen into the force-included + // header (see unified_gemm_codegen.py, CK_TILE_SINGLE_KERNEL_INCLUDE block). + // Building the key from them makes the registry entry truthful: it reflects + // THIS kernel's real dtypes/layouts/tile/traits instead of a hard-coded + // fp16/rcr/128x128x32 default. Enum fields use the string_to_* helpers from + // kernel_key.hpp, whose accepted strings match the codegen's emitted values + // byte-for-byte. KernelKey key; +#ifdef GEMM_KEY_DTYPE_A + key.signature.dtype_a = string_to_dtype(GEMM_KEY_DTYPE_A); + key.signature.dtype_b = string_to_dtype(GEMM_KEY_DTYPE_B); + key.signature.dtype_c = string_to_dtype(GEMM_KEY_DTYPE_C); + key.signature.dtype_acc = string_to_dtype(GEMM_KEY_DTYPE_ACC); + key.signature.layout_a = string_to_layout(GEMM_KEY_LAYOUT_A); + key.signature.layout_b = string_to_layout(GEMM_KEY_LAYOUT_B); + key.signature.layout_c = string_to_layout(GEMM_KEY_LAYOUT_C); + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = (GEMM_KEY_GROUPED != 0); + key.signature.split_k = GEMM_KEY_SPLIT_K; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {GEMM_KEY_TILE_M, GEMM_KEY_TILE_N, GEMM_KEY_TILE_K}; + key.algorithm.wave_shape = {GEMM_KEY_WAVE_M, GEMM_KEY_WAVE_N, GEMM_KEY_WAVE_K}; + key.algorithm.warp_tile_shape = {GEMM_KEY_WARP_TILE_M, GEMM_KEY_WARP_TILE_N, GEMM_KEY_WARP_TILE_K}; + key.algorithm.pipeline = string_to_pipeline(GEMM_KEY_PIPELINE); + key.algorithm.scheduler = string_to_scheduler(GEMM_KEY_SCHEDULER); + key.algorithm.epilogue = string_to_epilogue(GEMM_KEY_EPILOGUE); + key.algorithm.block_size = GEMM_KEY_BLOCK_SIZE; + key.algorithm.double_buffer = (GEMM_KEY_DOUBLE_BUFFER != 0); + key.algorithm.persistent = (GEMM_KEY_PERSISTENT != 0); + key.algorithm.preshuffle = (GEMM_KEY_PRESHUFFLE != 0); + key.algorithm.transpose_c = (GEMM_KEY_TRANSPOSE_C != 0); + key.algorithm.num_wave_groups = GEMM_KEY_NUM_WAVE_GROUPS; + // pad_m/n/k participate in both the key's hash/equality and the kernel + // name, so they must be derived from the codegen macros too -- otherwise a + // kernel built with padding disabled would register under a key claiming + // pad=true and disagree with its own name. + key.algorithm.pad_m = (GEMM_KEY_PAD_M != 0); + key.algorithm.pad_n = (GEMM_KEY_PAD_N != 0); + key.algorithm.pad_k = (GEMM_KEY_PAD_K != 0); + key.gfx_arch = GFX_ARCH; +#else + // Fallback default for headers generated before GEMM_KEY_* macros existed + // (fp16 / rcr / compv4-cshuffle-intrawave, 128x128x32). The macro path + // above is the source of truth for any freshly generated kernel. key.signature.dtype_a = DataType::FP16; key.signature.dtype_b = DataType::FP16; key.signature.dtype_c = DataType::FP16; @@ -95,6 +143,7 @@ int dispatcher_initialize() key.algorithm.transpose_c = false; key.algorithm.num_wave_groups = 1; key.gfx_arch = GFX_ARCH; +#endif // GEMM_KEY_DTYPE_A // Register kernel using types from force-included header auto kernel = @@ -310,10 +359,40 @@ int dispatcher_run_gemm( } /** - * Get kernel information + * Get kernel information (legacy single-kernel ABI). + * + * Returns the compile-time KERNEL_NAME of the force-included kernel header. + * Kept for backward compatibility with one-kernel-per-.so callers. */ const char* dispatcher_get_kernel_name() { return KERNEL_NAME; } +/** + * Get the name of the kernel at a given registry index (multi-kernel ABI). + * + * Mirrors the conv/fmha ctypes libs: copies the index-th registered kernel's + * name into the caller-provided buffer so one .so can report a whole batch and + * be selected by name at runtime. Returns 0 on success, -1 on bad args or + * out-of-range index. + */ +int dispatcher_get_kernel_name_at(int index, char* buffer, int buffer_size) +{ + if(!buffer || buffer_size <= 0) + { + return -1; + } + + auto kernels = Registry::instance().get_all(); + if(index < 0 || index >= static_cast(kernels.size())) + { + return -1; + } + + std::string name = kernels[index]->get_name(); + std::strncpy(buffer, name.c_str(), static_cast(buffer_size) - 1); + buffer[buffer_size - 1] = '\0'; + return 0; +} + /** * Initialize dispatcher (alias) */ diff --git a/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py b/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py index c0fb08aa4436..ec525ddd5c4c 100755 --- a/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py +++ b/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py @@ -520,6 +520,43 @@ def _selected_kernel_struct(self, config: KernelConfig, kernel_name: str) -> str using BDataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.datatype]}; using CDataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.tm.get_output_dtype(self.datatype)]}; using AccDataType = float; + +// KernelKey field descriptors for the force-included kernel. +// The ctypes library builds the registry KernelKey from these so the +// registered entry reflects this kernel's real traits (not a hard-coded +// fp16/rcr default). Enum-valued fields are emitted as the exact strings +// consumed by string_to_dtype/layout/pipeline/scheduler/epilogue in +// kernel_key.hpp; shape/flag fields are emitted as numeric/0-1 literals. +#define GEMM_KEY_DTYPE_A "{self.datatype}" +#define GEMM_KEY_DTYPE_B "{self.datatype}" +#define GEMM_KEY_DTYPE_C "{output_dtype}" +#define GEMM_KEY_DTYPE_ACC "fp32" +#define GEMM_KEY_LAYOUT_A "{self.layout[0]}" +#define GEMM_KEY_LAYOUT_B "{self.layout[1]}" +#define GEMM_KEY_LAYOUT_C "{self.layout[2]}" +#define GEMM_KEY_PIPELINE "{tr.pipeline}" +#define GEMM_KEY_SCHEDULER "{tr.scheduler}" +#define GEMM_KEY_EPILOGUE "{tr.epilogue}" +#define GEMM_KEY_TILE_M {t.tile_m} +#define GEMM_KEY_TILE_N {t.tile_n} +#define GEMM_KEY_TILE_K {t.tile_k} +#define GEMM_KEY_WAVE_M {t.warp_m} +#define GEMM_KEY_WAVE_N {t.warp_n} +#define GEMM_KEY_WAVE_K {t.warp_k} +#define GEMM_KEY_WARP_TILE_M {t.warp_tile_m} +#define GEMM_KEY_WARP_TILE_N {t.warp_tile_n} +#define GEMM_KEY_WARP_TILE_K {t.warp_tile_k} +#define GEMM_KEY_BLOCK_SIZE {config.block_size} +#define GEMM_KEY_NUM_WAVE_GROUPS {config.num_wave_groups} +#define GEMM_KEY_PAD_M {int(tr.pad_m)} +#define GEMM_KEY_PAD_N {int(tr.pad_n)} +#define GEMM_KEY_PAD_K {int(tr.pad_k)} +#define GEMM_KEY_PERSISTENT {int(tr.persistent)} +#define GEMM_KEY_DOUBLE_BUFFER {int(tr.pipeline == "compv4" or tr.pipeline == "preshufflev2")} +#define GEMM_KEY_PRESHUFFLE {int(config.preshuffle)} +#define GEMM_KEY_TRANSPOSE_C 0 +#define GEMM_KEY_GROUPED 0 +#define GEMM_KEY_SPLIT_K 1 #endif // CK_TILE_SINGLE_KERNEL_INCLUDE """ @@ -1030,9 +1067,15 @@ def _get_configs_for_variant(self, variant: GemmVariant) -> List[KernelConfig]: trait_configs = self._get_trait_configs() for tile, trait in itertools.product(tile_configs, trait_configs): - # Perform variant-specific architecture validation + # Perform variant-specific architecture validation against the + # trait's ACTUAL pipeline/scheduler (not a hard-coded compv4). if self.arch_filter and HAS_ARCH_FILTER: - if not self._is_tile_arch_valid(tile, variant): + if not self._is_tile_arch_valid( + tile, + variant, + pipeline=trait.pipeline, + scheduler=trait.scheduler, + ): continue if variant == GemmVariant.STANDARD: @@ -1105,9 +1148,21 @@ def _get_tile_configs(self) -> List[TileConfig]: rejected_count += 1 continue - # Architecture-specific validation + # Architecture-specific validation. This is a pre-filter run before + # tiles are paired with traits, so keep a tile if it is legal under + # ANY configured pipeline/scheduler; the precise per-trait check + # happens later in _get_configs_for_variant. Filtering here with a + # single hard-coded pipeline (compv4) wrongly dropped tiles that are + # legal under mem/compv3. if self.arch_filter and HAS_ARCH_FILTER: - if not self._is_tile_arch_valid(tile): + trait_cfg = self.config.get("trait_config", {}) + pipelines = trait_cfg.get("pipeline") or ["compv4"] + schedulers = trait_cfg.get("scheduler") or ["intrawave"] + if not any( + self._is_tile_arch_valid(tile, pipeline=pl, scheduler=sc) + for pl in pipelines + for sc in schedulers + ): rejected_count += 1 continue @@ -1119,13 +1174,23 @@ def _get_tile_configs(self) -> List[TileConfig]: return configs def _is_tile_arch_valid( - self, tile: TileConfig, variant: GemmVariant = None + self, + tile: TileConfig, + variant: GemmVariant = None, + pipeline: str = None, + scheduler: str = None, ) -> bool: """Check if tile configuration is valid for target architecture Args: tile: Tile configuration to validate variant: GEMM variant (affects operator-specific constraints) + pipeline: Trait pipeline to validate against. Pass the config's + actual pipeline -- omitting it falls back to ``compv4``, whose + MFMA constraints are stricter than ``mem``/``compv3`` and would + wrongly reject tiles that are legal under those pipelines. + scheduler: Trait scheduler to validate against (defaults to + ``intrawave`` for the same reason). """ if not self.arch_filter or not HAS_ARCH_FILTER: return True @@ -1146,8 +1211,10 @@ def _is_tile_arch_valid( # Map GEMM variant to operator type for validation operator = None - pipeline = "compv4" # Default - scheduler = "intrawave" # Default + if pipeline is None: + pipeline = "compv4" # Default (representative compute pipeline) + if scheduler is None: + scheduler = "intrawave" # Default if OperatorType is not None and variant is not None: variant_to_operator = { diff --git a/projects/composablekernel/dispatcher/examples/gemm/python/12_te_bridge.py b/projects/composablekernel/dispatcher/examples/gemm/python/12_te_bridge.py new file mode 100644 index 000000000000..53a44060ef35 --- /dev/null +++ b/projects/composablekernel/dispatcher/examples/gemm/python/12_te_bridge.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 12: Tile Engine -> Dispatcher bridge (gallery) + +Unlike examples 01-11 (which drive the Dispatcher's native ctypes Registry), +this example exercises the *Tile Engine -> Dispatcher bridge* in +``dispatcher/python/gemm_utils.py``. The bridge is the path the Tile Engine +itself uses: one common ``GemmKernelConfig`` feeds codegen, force-include +compile, and a flat extern "C" ABI, and ``GpuGemmRunner`` runs the resulting +.so against a NumPy reference. + +It is a small gallery of three demos that together cover the surface the bridge +gained over its original fp16/rcr-only slice: + + matrix every (dtype, layout) pair the universal GEMM supports -- fp16 and + bf16 across the four row/col A/B combinations (row-major C only). + shapes why padding matters: a padded kernel accepts an awkward, non-tile- + aligned problem (M, N not divisible by the tile) while the equivalent + no-pad kernel rejects it -- the same selection rule the Tile Engine + sees when it sweeps pad on/off. + sweep the "search space" idea: one fixed signature, several *algorithms* + (tile / wave / pipeline), built and ranked by measured TFLOPS -- a + miniature of what the Tile Engine driver does at scale. + +(Kernel *variants* such as Stream-K and grouped GEMM ride the same bridge but +live on separate branches; this example stays within the regular GEMM stack.) + +Usage: + python3 12_te_bridge.py # runs all three demos + python3 12_te_bridge.py --demo matrix + python3 12_te_bridge.py --demo shapes + python3 12_te_bridge.py --demo sweep + python3 12_te_bridge.py --size 1024 --rtol 2e-2 --arch gfx950 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np # noqa: E402 + +from gemm_utils import ( # noqa: E402 + GemmKernelConfig, + GemmProblem, + GpuGemmRunner, + setup_multiple_gemm_dispatchers, +) +from ctypes_utils import detect_gpu_arch # noqa: E402 + +# A single algorithm known to compile and run on gfx942. Only the Signature +# (dtype + layout) varies in the matrix demo; the Algorithm is held fixed so the +# demo isolates the bridge's dtype/layout generality. +_ALGO = dict( + tile_m=64, tile_n=64, tile_k=64, + wave_m=4, wave_n=1, wave_k=1, + warp_tile_m=16, warp_tile_n=16, warp_tile_k=16, + pipeline="compv3", scheduler="intrawave", epilogue="cshuffle", + pad_m=False, pad_n=False, pad_k=False, +) + +# (dtype, layout) pairs. Column-major C (e.g. rcc) is rejected at build by the +# universal GEMM, so every case keeps row-major C -- which leaves exactly four +# A/B combinations (rcr/rrr/ccr/crr). Both dtypes cover all four. +_CASES = [ + ("fp16", "rcr"), ("fp16", "rrr"), ("fp16", "ccr"), ("fp16", "crr"), + ("bf16", "rcr"), ("bf16", "rrr"), ("bf16", "ccr"), ("bf16", "crr"), +] + +_LAYOUT_WORD = {"r": "row", "c": "col"} + + +def _emulate(x: np.ndarray, dtype: str) -> np.ndarray: + """Round fp32 inputs to the kernel's storage dtype so the CPU reference + matches what the GPU actually multiplies.""" + if dtype == "bf16": + u32 = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32) + rounded = (u32 + ((u32 >> 16) & 1) + np.uint32(0x7FFF)) >> 16 + return (rounded.astype(np.uint32) << 16).view(np.float32) + return x.astype(np.float16).astype(np.float32) + + +def _max_rel(out: np.ndarray, ref: np.ndarray) -> float: + # Global relative error (normalize by the largest reference magnitude): + # per-element ratios explode on the near-zero entries that K-length + # accumulation of zero-mean data produces, so they are not meaningful. + denom = float(np.max(np.abs(ref))) + 1e-12 + return float(np.max(np.abs(out - ref))) / denom + + +def _config(dtype: str, layout: str, arch: str, **algo) -> GemmKernelConfig: + la, lb, lc = layout + return GemmKernelConfig( + dtype_a=dtype, dtype_b=dtype, dtype_c=dtype, + layout_a=_LAYOUT_WORD[la], layout_b=_LAYOUT_WORD[lb], layout_c=_LAYOUT_WORD[lc], + gfx_arch=arch, **(algo or _ALGO), + ) + + +def _reference(A, B, dtype): + # Emulate both input quantization (A,B stored as dtype) and the output store + # (GPU writes C back as dtype_c), so round on both ends before comparing. + return _emulate(_emulate(A, dtype) @ _emulate(B, dtype), dtype) + + +# --------------------------------------------------------------------------- +# Demo 1: dtype x layout matrix +# --------------------------------------------------------------------------- +def demo_matrix(size, rtol, arch): + print(f"\n[matrix] dtype x layout, M=N=K={size}, rtol={rtol:g}") + problem = GemmProblem(M=size, N=size, K=size) + configs = [_config(dt, lay, arch) for dt, lay in _CASES] + so_paths = setup_multiple_gemm_dispatchers(configs, verbose=False) + + rng = np.random.default_rng(42) + A = (rng.standard_normal((problem.M, problem.K)) * 0.1).astype(np.float32) + B = (rng.standard_normal((problem.K, problem.N)) * 0.1).astype(np.float32) + + n_pass = 0 + for (dtype, layout), so in zip(_CASES, so_paths): + tag = f"{dtype}/{layout}" + if so is None: + print(f" {tag:10s} BUILD FAILED") + continue + result = GpuGemmRunner(lib_path=so).run(A, B, problem) + if not result.success: + print(f" {tag:10s} RUN FAILED (status {result.status})") + continue + mr = _max_rel(result.output, _reference(A, B, dtype)) + ok = mr <= rtol + n_pass += ok + print(f" {tag:10s} tflops={result.tflops:7.1f} max_rel={mr:.2e} " + f"{'PASS' if ok else 'FAIL'}") + print(f" -> {n_pass}/{len(_CASES)} passed") + return n_pass, len(_CASES) + + +# --------------------------------------------------------------------------- +# Demo 2: padding vs an awkward (non-tile-aligned) shape +# --------------------------------------------------------------------------- +def demo_shapes(rtol, arch): + print("\n[shapes] padding lets a kernel accept a non-tile-aligned problem") + # 128-tile kernels; awkward M, N do not divide 128 (K stays divisible by 8 + # for the fp16 vectorized reduction load). + algo_pad = dict( + tile_m=128, tile_n=128, tile_k=32, wave_m=2, wave_n=2, wave_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, + pipeline="compv4", scheduler="intrawave", epilogue="cshuffle", + pad_m=True, pad_n=True, pad_k=True, + ) + algo_nopad = dict(algo_pad, pad_m=False, pad_n=False, pad_k=False) + + cfg_pad = _config("fp16", "rcr", arch, **algo_pad) + cfg_nopad = _config("fp16", "rcr", arch, **algo_nopad) + so_pad, so_nopad = setup_multiple_gemm_dispatchers([cfg_pad, cfg_nopad], verbose=False) + + M, N, K = 257, 129, 512 # awkward: 257, 129 not divisible by 128 + problem = GemmProblem(M=M, N=N, K=K) + rng = np.random.default_rng(7) + A = (rng.standard_normal((M, K)) * 0.1).astype(np.float32) + B = (rng.standard_normal((K, N)) * 0.1).astype(np.float32) + ref = _reference(A, B, "fp16") + + print(f" awkward problem M={M} N={N} K={K} (M,N not divisible by tile 128)") + n_pass = 0 + expectations = [("padded", so_pad, True), ("no-pad", so_nopad, False)] + for label, so, should_pass in expectations: + if so is None: + print(f" {label:8s} BUILD FAILED") + continue + result = GpuGemmRunner(lib_path=so).run(A, B, problem) + if result.success: + mr = _max_rel(result.output, ref) + accepted = mr <= rtol + outcome = f"ACCEPTED tflops={result.tflops:7.1f} max_rel={mr:.2e}" + else: + accepted = False + # status -2 == select_kernel found no kernel whose tiling fits. + outcome = f"REJECTED (status {result.status})" + # "Correct" = the no-pad kernel rejects and the padded one accepts. + correct = accepted == should_pass + n_pass += correct + print(f" {label:8s} {outcome:42s} {'as expected' if correct else 'UNEXPECTED'}") + print(f" -> {n_pass}/2 behaved as expected (padded accepts, no-pad rejects)") + return n_pass, 2 + + +# --------------------------------------------------------------------------- +# Demo 3: algorithm sweep over one fixed signature +# --------------------------------------------------------------------------- +def demo_sweep(size, rtol, arch): + print(f"\n[sweep] fixed fp16/rcr signature, several algorithms, M=N=K={size}") + # A handful of distinct algorithms (the Tile Engine sweeps thousands of + # these). Each is a different tile / wave / pipeline point in the search + # space; padding is on so any size is accepted. + algos = [ + dict(tile_m=128, tile_n=128, tile_k=32, wave_m=2, wave_n=2, wave_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, pipeline="compv4"), + dict(tile_m=256, tile_n=128, tile_k=32, wave_m=2, wave_n=2, wave_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, pipeline="compv4"), + dict(tile_m=128, tile_n=128, tile_k=64, wave_m=2, wave_n=2, wave_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, pipeline="compv3"), + dict(tile_m=64, tile_n=64, tile_k=64, wave_m=4, wave_n=1, wave_k=1, + warp_tile_m=16, warp_tile_n=16, warp_tile_k=16, pipeline="compv3"), + ] + common = dict(scheduler="intrawave", epilogue="cshuffle", + pad_m=True, pad_n=True, pad_k=True) + configs = [_config("fp16", "rcr", arch, **dict(a, **common)) for a in algos] + so_paths = setup_multiple_gemm_dispatchers(configs, verbose=False) + + problem = GemmProblem(M=size, N=size, K=size) + rng = np.random.default_rng(123) + A = (rng.standard_normal((size, size)) * 0.1).astype(np.float32) + B = (rng.standard_normal((size, size)) * 0.1).astype(np.float32) + ref = _reference(A, B, "fp16") + + rows = [] + for cfg, so in zip(configs, so_paths): + label = f"{cfg.tile_m}x{cfg.tile_n}x{cfg.tile_k}/{cfg.pipeline}" + if so is None: + rows.append((label, None, None)) + continue + result = GpuGemmRunner(lib_path=so).run(A, B, problem) + if not result.success: + rows.append((label, None, None)) + continue + rows.append((label, result.tflops, _max_rel(result.output, ref))) + + ranked = sorted((r for r in rows if r[1] is not None), + key=lambda r: r[1], reverse=True) + print(f" {'rank':>4} {'algorithm':<24} {'tflops':>9} {'max_rel':>10}") + for i, (label, tflops, mr) in enumerate(ranked, 1): + print(f" {i:>4} {label:<24} {tflops:>9.1f} {mr:>10.2e}") + for label, tflops, _ in rows: + if tflops is None: + print(f" {'-':>4} {label:<24} {'BUILD/RUN FAILED':>20}") + if ranked: + print(f" -> fastest: {ranked[0][0]} at {ranked[0][1]:.1f} TFLOPS") + n_ok = sum(1 for _, _, mr in ranked if mr <= rtol) + return n_ok, len(configs) + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Tile Engine -> Dispatcher bridge example (gallery)", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--demo", choices=["matrix", "shapes", "sweep", "all"], + default="all", help="which demo to run (default: all)") + parser.add_argument("--size", type=int, default=512, help="M=N=K (default 512)") + parser.add_argument("--rtol", type=float, default=2e-2, + help="relative tolerance (default 2e-2)") + parser.add_argument("--arch", default=detect_gpu_arch(), + help="GPU target arch (default: auto-detected via rocminfo)") + args = parser.parse_args() + + demos = ["matrix", "shapes", "sweep"] if args.demo == "all" else [args.demo] + total_pass = 0 + total = 0 + for d in demos: + if d == "matrix": + p, t = demo_matrix(args.size, args.rtol, args.arch) + elif d == "shapes": + p, t = demo_shapes(args.rtol, args.arch) + else: + p, t = demo_sweep(args.size, args.rtol, args.arch) + total_pass += p + total += t + + print(f"\n{total_pass}/{total} checks passed across {len(demos)} demo(s)") + return 0 if total_pass == total else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp b/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp index be22d94b3331..0b072d761742 100644 --- a/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp +++ b/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp @@ -12,6 +12,8 @@ #include #include #include +#include +#include namespace ck_tile { namespace dispatcher { @@ -101,16 +103,30 @@ class GeneratedTileKernelInstance : public KernelInstance problem.N // stride_E/C (row-major C: stride = N) ); + // Benchmark parameters. Defaults mirror old Tile Engine's + // gemm_common.hpp (warmup=50, repeat=100, flush_cache=true, + // rotating_count=1000), and a generous warmup keeps the GPU clock + // ramped. NOTE: matching these knobs does NOT by itself make + // bridge-vs-old-TE numbers comparable -- the byte-identical kernel + // measures ~18-20% faster here than through old TE's *standalone + // benchmark binary* at e.g. 1024^3/compv4, purely because that + // separate process runs the kernel at a lower sustained SCLK (+ more + // memory-stall cycles), not because of any bench knob, compiler, or + // kernel difference (rocprof-confirmed). For an honest A/B, measure + // BOTH kernels through the SAME harness (build the old-TE kernel into a + // .so and run it via run_one_gemm_kernel.py) -- the gap then collapses + // to ~1%. Each knob is env-overridable so a caller can match another + // harness without recompiling. const bool bench = this->benchmarking_; ck_tile::stream_config stream_cfg; stream_cfg.stream_id_ = reinterpret_cast(stream); stream_cfg.time_kernel_ = bench; stream_cfg.log_level_ = 0; - stream_cfg.cold_niters_ = bench ? 5 : 0; - stream_cfg.nrepeat_ = bench ? 10 : 1; + stream_cfg.cold_niters_ = bench ? env_int("CK_TILE_BENCH_WARMUP", 50) : 0; + stream_cfg.nrepeat_ = bench ? env_int("CK_TILE_BENCH_REPEAT", 100) : 1; stream_cfg.is_gpu_timer_ = bench; - stream_cfg.flush_cache_ = false; - stream_cfg.rotating_count_ = 1; + stream_cfg.flush_cache_ = bench && env_bool("CK_TILE_BENCH_FLUSH", true); + stream_cfg.rotating_count_ = bench ? env_int("CK_TILE_BENCH_ROTATING", 1000) : 1; // Call the generated kernel's launch method return SelectedKernel::launch(args, stream_cfg); @@ -134,6 +150,33 @@ class GeneratedTileKernelInstance : public KernelInstance } private: + // Read an integer benchmark knob from the environment, falling back to + // `fallback` when unset or unparseable. + static int env_int(const char* name, int fallback) + { + const char* v = std::getenv(name); + if(v == nullptr || *v == '\0') + return fallback; + char* end = nullptr; + const long out = std::strtol(v, &end, 10); + if(end == v) + return fallback; + return static_cast(out); + } + + // Read a boolean benchmark knob ("0"/"false"/"off", any case => false, else true). + static bool env_bool(const char* name, bool fallback) + { + const char* v = std::getenv(name); + if(v == nullptr || *v == '\0') + return fallback; + std::string s(v); + for(char& c : s) + if(c >= 'A' && c <= 'Z') + c = static_cast(c - 'A' + 'a'); + return !(s == "0" || s == "false" || s == "off"); + } + KernelKey key_; std::string name_; }; diff --git a/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp b/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp index 24b20ecd9b8f..24d67cbf437a 100644 --- a/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp +++ b/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp @@ -369,6 +369,11 @@ inline Scheduler string_to_scheduler(const std::string& str) { if(str == "auto") return Scheduler::Auto; + // Preshuffle kernels emit "default"; the codegen maps it to Scheduler::Auto + // (see codegen_common.py SCHEDULER_TO_DISPATCHER), so mirror that here + // instead of silently falling through to Intrawave. + if(str == "default") + return Scheduler::Auto; if(str == "intrawave") return Scheduler::Intrawave; if(str == "interwave") diff --git a/projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py b/projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py new file mode 100644 index 000000000000..04e89e84d08b --- /dev/null +++ b/projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +"""Apples-to-apples GEMM A/B: bridge kernel vs old-TE kernel, ONE harness. + +Why this exists +--------------- +The earlier sweep (allsweep6144rcrfp16.py) compared the bridge's dispatcher +measurement against old TE's *standalone benchmark binary* +(benchmark_gemm_universal_). That comparison is NOT apples-to-apples: +the device kernel is byte-identical, yet old TE's standalone binary reports +~18-20% lower TFLOPS at e.g. 1024^3 / compv4. rocprof shows the identical +kernel genuinely runs longer in that process -- ~+8% cycles plus a lower +sustained SCLK -- a power/clock + execution-environment artifact of that +binary, NOT a bridge speedup, compiler difference, or kernel difference. +(See diagnose.md sec.4.) + +This harness removes the artifact: it builds the OLD-TE kernel into a .so from +old TE's own generated header and runs BOTH the bridge kernel and the old-TE +kernel through the SAME worker (run_one_gemm_kernel.py). Measured this way the +gap collapses to ~1%, which is the honest result. + +Usage: + python3 ab_same_harness.py # default kernel list + shapes + python3 ab_same_harness.py [...] +""" +import json +import os +import subprocess +import sys +from pathlib import Path + +# composablekernel root: .../composablekernel/dispatcher/parity_diag/regression/ +ROOT = Path(__file__).resolve().parents[3] +DISP = ROOT / "dispatcher" +GEN = DISP / "build" / "generated_kernels" +SRC = DISP / "bindings" / "ctypes" / "gemm_ctypes_lib.cpp" +STATIC = DISP / "build" / "libck_tile_dispatcher.a" +BR_SO_DIR = DISP / "build" / "examples" +WORKER = ROOT / "tile_engine/ops/gemm/run_one_gemm_kernel.py" +# old-TE generated single-kernel headers. Override with OLD_TE_GEN; the default +# points at a sibling develop-parity worktree under the rocm-libraries root. +OLD_GEN = Path(os.environ.get( + "OLD_TE_GEN", + str(ROOT.parents[1] / ".claude/worktrees/develop-parity" + "/projects/composablekernel/build/tile_engine/ops/gemm/gemm_universal/fp16/rcr"), +)) +OUT = DISP / "parity_diag" / "regression" / "_ab_same_harness_build" +ARCH = os.environ.get("GFX_ARCH", "gfx942") +DEVICE = os.environ.get("PARITY_DEVICE", "0") +REPEATS = int(os.environ.get("AB_REPEATS", "3")) + +SHAPES = [(512, 512, 512), (1024, 1024, 1024), (2048, 2048, 2048), + (1024, 512, 256), (4096, 4096, 4096)] + +DEFAULT_STEMS = [ + "fp16_rcr_compv4_default_intrawave_False_False_False_False_64x128x64_2x2x1_32x32x16", + "fp16_rcr_compv4_cshuffle_intrawave_False_False_False_False_64x128x64_1x4x1_32x32x16", + "fp16_rcr_compv4_default_intrawave_False_False_False_False_128x128x64_4x1x1_32x32x16", +] + +PYPATH = os.pathsep.join([str(DISP / "python"), str(ROOT / "tile_engine/ops/gemm")]) + + +def build_old_so(stem: str) -> Path | None: + """Compile old TE's generated kernel header into a bridge-loadable .so.""" + hdr = OLD_GEN / f"gemm_universal_single_{stem}.hpp" + if not hdr.exists(): + return None + OUT.mkdir(parents=True, exist_ok=True) + obj = OUT / f"{stem}.o" + lib = OUT / f"libold_{stem}.so" + common = [ + "-fPIC", "-O3", + f"-I{DISP / 'include'}", f"-I{ROOT / 'include'}", f"-I{ROOT}", f"-I{GEN}", + "-DCK_TILE_SINGLE_KERNEL_INCLUDE", f"-include{hdr}", "-D__HIP_PLATFORM_AMD__", + f"--offload-arch={ARCH}", f'-DGFX_ARCH="{ARCH}"', + "-Wno-undefined-func-template", "-Wno-float-equal", + ] + cc = subprocess.run(["/opt/rocm/bin/hipcc", "-c", *common, str(SRC), "-o", str(obj)], + capture_output=True) + if cc.returncode != 0: + return None + ln = subprocess.run(["/opt/rocm/bin/hipcc", "-shared", "-fPIC", + f"--offload-arch={ARCH}", "--hip-link", + str(obj), str(STATIC), "-o", str(lib)], capture_output=True) + return lib if ln.returncode == 0 else None + + +def meas(so: Path, M: int, N: int, K: int) -> float | None: + if not so or not Path(so).exists(): + return None + payload = json.dumps({"so_path": str(so), "problem": {"M": M, "N": N, "K": K}, + "kernel_name": "x"}) + env = os.environ.copy() + env["HIP_VISIBLE_DEVICES"] = DEVICE + env["GEMM_PYPATH"] = PYPATH + best = None + for _ in range(REPEATS): + p = subprocess.run([sys.executable, str(WORKER)], input=payload.encode(), + stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, env=env) + for line in p.stdout.decode().splitlines(): + try: + d = json.loads(line) + except json.JSONDecodeError: + continue + if d.get("ok"): + best = d["tflops"] if best is None else max(best, d["tflops"]) + return best + + +def main(): + stems = sys.argv[1:] or DEFAULT_STEMS + print(f"{'shape':>14} {'bridge':>9} {'oldTE':>9} {'gap%':>7} kernel") + for stem in stems: + old_so = build_old_so(stem) + br_so = BR_SO_DIR / f"libgemm_{stem}.so" + if old_so is None: + print(f" [skip: no old-TE header] {stem}") + continue + for (M, N, K) in SHAPES: + b = meas(br_so, M, N, K) + o = meas(old_so, M, N, K) + gap = (b - o) / o * 100 if (b and o) else float("nan") + print(f"{f'{M}x{N}x{K}':>14} {b or float('nan'):9.2f} " + f"{o or float('nan'):9.2f} {gap:7.2f} {stem[:40]}") + + +if __name__ == "__main__": + main() diff --git a/projects/composablekernel/dispatcher/python/ctypes_utils.py b/projects/composablekernel/dispatcher/python/ctypes_utils.py index d719d1405e5d..cc94ede685c9 100644 --- a/projects/composablekernel/dispatcher/python/ctypes_utils.py +++ b/projects/composablekernel/dispatcher/python/ctypes_utils.py @@ -1073,7 +1073,7 @@ def _generate_single_kernel_subprocess(args: dict) -> Tuple[bool, Optional[str], "--config", config_file, "--variants", - "standard", + args.get("variant", "standard"), ] res = subprocess.run(cmd, capture_output=True, text=True, timeout=300) diff --git a/projects/composablekernel/dispatcher/python/gemm_utils.py b/projects/composablekernel/dispatcher/python/gemm_utils.py new file mode 100644 index 000000000000..242620f2a9eb --- /dev/null +++ b/projects/composablekernel/dispatcher/python/gemm_utils.py @@ -0,0 +1,825 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT +""" +GEMM Tile Engine <-> Dispatcher bridge. + +This is the GEMM counterpart of ``grouped_conv_utils.py`` / ``fmha_utils.py``: +a single shared config dataclass (``GemmKernelConfig``) that Tile Engine imports +and hands back to the dispatcher. There is no translator between two +vocabularies -- both sides share the one object whose ``.name`` mirrors the +kernel identifier baked into the generated kernel header. + +Public surface (mirrors the grouped_conv bridge): + + GemmKernelConfig -- the shared contract dataclass + .name -- registry/runtime lookup key (byte-exact) + .to_codegen_json() -- feeds unified_gemm_codegen.py + GemmProblem -- a single (M, N, K) problem + setup_multiple_gemm_dispatchers -- codegen + hipcc -> .so paths (NO GPU) + GemmDispatcherLib -- thin ctypes ABI wrapper + GpuGemmRunner -- GPU memory + run + time (from a .so path) + expand_sweep -- TE JSON sweep config -> [GemmKernelConfig] + +The heavy lifting for codegen and compilation is reused from ``ctypes_utils`` +so there is a single source of truth for how a kernel header is produced and +how it is compiled into a ``.so``. +""" + +from __future__ import annotations + +import ctypes +import itertools +import json +import multiprocessing +from concurrent.futures import ProcessPoolExecutor, as_completed +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +# Reuse the proven codegen/compile leaf helpers from the dispatcher's own +# python layer. gemm_utils is a thin bridge on top of these. +import ctypes_utils as _cu + + +# ============================================================================ +# Layout / dtype helpers +# ============================================================================ + +_LAYOUT_CHAR = {"row": "r", "col": "c", "r": "r", "c": "c"} +_LAYOUT_WORD = {"r": "row", "c": "col"} + + +def _cap(flag: bool) -> str: + """Reproduce Python ``str(bool).capitalize()`` -> 'True' / 'False'.""" + return "True" if flag else "False" + + +# ============================================================================ +# The shared contract: GemmKernelConfig +# ============================================================================ + + +@dataclass +class GemmKernelConfig: + """The common config struct shared by Tile Engine and the Dispatcher. + + Naming convention (the "warp/wave trap" lives here, in ONE place): + * ``wave_m/n/k`` -- warps per block (C++ ``wave_shape``; TE "warp"). + * ``warp_tile_m/n/k`` -- MFMA instruction shape (C++ ``warp_tile_shape``; + TE "warp_tile"). + """ + + # --- Signature: what operation is computed ----------------------------- + dtype_a: str = "fp16" + dtype_b: str = "fp16" + dtype_c: str = "fp16" + dtype_acc: str = "fp32" + layout_a: str = "row" + layout_b: str = "col" + layout_c: str = "row" + + # --- Algorithm: how it is implemented ---------------------------------- + tile_m: int = 128 + tile_n: int = 128 + tile_k: int = 32 + wave_m: int = 2 + wave_n: int = 2 + wave_k: int = 1 + warp_tile_m: int = 32 + warp_tile_n: int = 32 + warp_tile_k: int = 16 + + pipeline: str = "compv4" + scheduler: str = "intrawave" + epilogue: str = "cshuffle" + + pad_m: bool = True + pad_n: bool = True + pad_k: bool = True + persistent: bool = False + + gfx_arch: str = "gfx942" + variant: str = "standard" + + # ------------------------------------------------------------------ # + # Derived string fragments + # ------------------------------------------------------------------ # + @property + def layout(self) -> str: + """3-char layout string, e.g. 'rcr'.""" + return ( + _LAYOUT_CHAR[self.layout_a] + + _LAYOUT_CHAR[self.layout_b] + + _LAYOUT_CHAR[self.layout_c] + ) + + @property + def tile_str(self) -> str: + return f"{self.tile_m}x{self.tile_n}x{self.tile_k}" + + @property + def wave_str(self) -> str: + return f"{self.wave_m}x{self.wave_n}x{self.wave_k}" + + @property + def warp_tile_str(self) -> str: + return f"{self.warp_tile_m}x{self.warp_tile_n}x{self.warp_tile_k}" + + @property + def name(self) -> str: + """Registry / runtime lookup key. + + Reproduces, byte-for-byte, the ``KERNEL_NAME`` that + ``unified_gemm_codegen.py::KernelNaming.generate`` bakes into the + generated kernel header (and that the .so reports via + ``dispatcher_get_kernel_name``). This is the single thread tying + config -> codegen -> runtime together. + """ + name = ( + f"gemm_{self.dtype_a}_{self.layout}" + f"_{self.pipeline}_{self.epilogue}_{self.scheduler}" + f"_{_cap(self.pad_m)}_{_cap(self.pad_n)}_{_cap(self.pad_k)}" + f"_{_cap(self.persistent)}" + f"_{self.tile_str}_{self.wave_str}_{self.warp_tile_str}" + ) + if self.variant == "preshuffle": + name += "_preshuffle" + elif self.variant == "streamk": + name += "_streamk" + return name + + # ------------------------------------------------------------------ # + # Serialization + # ------------------------------------------------------------------ # + def to_codegen_json(self) -> Dict[str, Any]: + """Single-config JSON consumed by unified_gemm_codegen.py. + + Note the warp/wave mapping: the codegen calls the warps-per-block + triple ``warp_*`` and the MFMA triple ``warp_tile_*``. We translate + from dispatcher semantics here so the mapping cannot drift. + """ + return { + "tile_config": { + "tile_m": [self.tile_m], + "tile_n": [self.tile_n], + "tile_k": [self.tile_k], + # dispatcher wave_* -> codegen warp_* (warps per block) + "warp_m": [self.wave_m], + "warp_n": [self.wave_n], + "warp_k": [self.wave_k], + # dispatcher warp_tile_* -> codegen warp_tile_* (MFMA shape) + "warp_tile_m": [self.warp_tile_m], + "warp_tile_n": [self.warp_tile_n], + "warp_tile_k": [self.warp_tile_k], + }, + "trait_config": { + "pipeline": [self.pipeline], + "epilogue": [self.epilogue], + "scheduler": [self.scheduler], + "pad_m": [self.pad_m], + "pad_n": [self.pad_n], + "pad_k": [self.pad_k], + "persistent": [self.persistent], + }, + } + + def to_dict(self) -> Dict[str, Any]: + return { + "dtype_a": self.dtype_a, + "dtype_b": self.dtype_b, + "dtype_c": self.dtype_c, + "dtype_acc": self.dtype_acc, + "layout": self.layout, + "tile": [self.tile_m, self.tile_n, self.tile_k], + "wave": [self.wave_m, self.wave_n, self.wave_k], + "warp_tile": [self.warp_tile_m, self.warp_tile_n, self.warp_tile_k], + "pipeline": self.pipeline, + "scheduler": self.scheduler, + "epilogue": self.epilogue, + "pad": [self.pad_m, self.pad_n, self.pad_k], + "persistent": self.persistent, + "gfx_arch": self.gfx_arch, + "variant": self.variant, + "name": self.name, + } + + def to_ctypes_config(self) -> "_cu.KernelConfig": + """Convert to the ctypes_utils.KernelConfig used by the codegen/validate + helpers. ctypes_utils renames the MFMA triple ``warp_*`` (no _tile).""" + return _cu.KernelConfig( + dtype_a=self.dtype_a, + dtype_b=self.dtype_b, + dtype_c=self.dtype_c, + dtype_acc=self.dtype_acc, + layout_a=_LAYOUT_WORD[_LAYOUT_CHAR[self.layout_a]], + layout_b=_LAYOUT_WORD[_LAYOUT_CHAR[self.layout_b]], + layout_c=_LAYOUT_WORD[_LAYOUT_CHAR[self.layout_c]], + tile_m=self.tile_m, + tile_n=self.tile_n, + tile_k=self.tile_k, + wave_m=self.wave_m, + wave_n=self.wave_n, + wave_k=self.wave_k, + warp_m=self.warp_tile_m, + warp_n=self.warp_tile_n, + warp_k=self.warp_tile_k, + pipeline=self.pipeline, + scheduler=self.scheduler, + epilogue=self.epilogue, + pad_m=self.pad_m, + pad_n=self.pad_n, + pad_k=self.pad_k, + gfx_arch=self.gfx_arch, + variant=self.variant, + ) + + +# ============================================================================ +# Problem +# ============================================================================ + + +@dataclass +class GemmProblem: + """A single GEMM problem: C[MxN] = A[MxK] @ B[KxN].""" + + M: int + N: int + K: int + + @property + def flops(self) -> float: + return 2.0 * self.M * self.N * self.K + + def to_dict(self) -> Dict[str, int]: + return {"M": self.M, "N": self.N, "K": self.K} + + @classmethod + def from_dict(cls, d: Dict[str, int]) -> "GemmProblem": + return cls(M=int(d["M"]), N=int(d["N"]), K=int(d["K"])) + + +@dataclass +class GemmResult: + output: np.ndarray + time_ms: float + status: int + tflops: float + kernel_name: str + + @property + def success(self) -> bool: + return self.status == 0 + + +# ============================================================================ +# ctypes ABI wrapper +# ============================================================================ + + +class GemmDispatcherLib: + """Thin ctypes wrapper around a compiled GEMM dispatcher .so. + + Supports both the legacy single-kernel ABI (``dispatcher_get_kernel_name``) + and the multi-kernel ABI (``dispatcher_get_kernel_name_at(index, buf, n)``) + so one .so can report a whole batch and be selected by name. + """ + + def __init__(self, so_path: Path): + self._path = Path(so_path) + self._lib = ctypes.CDLL(str(self._path)) + self._has_indexed = hasattr(self._lib, "dispatcher_get_kernel_name_at") + self._setup_functions() + + def _setup_functions(self) -> None: + lib = self._lib + + lib.dispatcher_initialize.argtypes = [] + lib.dispatcher_initialize.restype = ctypes.c_int + + lib.dispatcher_get_kernel_count.argtypes = [] + lib.dispatcher_get_kernel_count.restype = ctypes.c_int + + lib.dispatcher_get_kernel_name.argtypes = [] + lib.dispatcher_get_kernel_name.restype = ctypes.c_char_p + + if self._has_indexed: + lib.dispatcher_get_kernel_name_at.argtypes = [ + ctypes.c_int, + ctypes.c_char_p, + ctypes.c_int, + ] + lib.dispatcher_get_kernel_name_at.restype = ctypes.c_int + + lib.dispatcher_run_gemm.argtypes = [ + ctypes.c_void_p, # A (host) + ctypes.c_void_p, # B (host) + ctypes.c_void_p, # C (host) + ctypes.c_int64, # M + ctypes.c_int64, # N + ctypes.c_int64, # K + ctypes.POINTER(ctypes.c_float), # time_ms + ] + lib.dispatcher_run_gemm.restype = ctypes.c_int + + lib.dispatcher_cleanup.argtypes = [] + lib.dispatcher_cleanup.restype = None + + @property + def path(self) -> Path: + return self._path + + def initialize(self) -> bool: + return self._lib.dispatcher_initialize() == 0 + + def get_kernel_count(self) -> int: + return int(self._lib.dispatcher_get_kernel_count()) + + @property + def kernel_names(self) -> List[str]: + """List every kernel the .so exposes, by index when available.""" + if self._has_indexed: + names: List[str] = [] + count = self.get_kernel_count() + buf = ctypes.create_string_buffer(256) + for i in range(count): + if self._lib.dispatcher_get_kernel_name_at(i, buf, 256) == 0: + names.append(buf.value.decode("utf-8")) + if names: + return names + # Legacy single-kernel fallback. + raw = self._lib.dispatcher_get_kernel_name() + return [raw.decode("utf-8")] if raw else [] + + def run( + self, A: np.ndarray, B: np.ndarray, C: np.ndarray, M: int, N: int, K: int + ) -> Tuple[int, float]: + time_ms = ctypes.c_float(0.0) + status = self._lib.dispatcher_run_gemm( + A.ctypes.data_as(ctypes.c_void_p), + B.ctypes.data_as(ctypes.c_void_p), + C.ctypes.data_as(ctypes.c_void_p), + M, + N, + K, + ctypes.byref(time_ms), + ) + return status, time_ms.value + + def cleanup(self) -> None: + self._lib.dispatcher_cleanup() + + +# ============================================================================ +# GPU runner (constructed from a .so path; loaded only inside a worker) +# ============================================================================ + + +def _fp32_to_bf16_u16(x: np.ndarray) -> np.ndarray: + """Encode fp32 -> bfloat16 bit pattern in a uint16 array (round-to-nearest-even). + + numpy has no native bf16, but the C ABI only cares about the 2-byte memory + layout (sizeof(bf16_t) == 2 == sizeof(uint16)). Truncating the low 16 bits of + the fp32 representation with round-to-nearest-even matches ck_tile's bf16. + """ + u32 = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32) + # round-to-nearest-even: add (lsb-of-kept-bits + 0x7FFF) before truncating + rounding = ((u32 >> 16) & 1) + np.uint32(0x7FFF) + return ((u32 + rounding) >> 16).astype(np.uint16) + + +def _bf16_u16_to_fp32(u16: np.ndarray) -> np.ndarray: + """Decode a uint16 bf16 bit pattern back to fp32 (low 16 mantissa bits zero).""" + return (u16.astype(np.uint32) << 16).view(np.float32) + + +def _dtype_from_kernel_name(name: str) -> str: + """Extract the dtype token from a kernel name like ``gemm___...``.""" + parts = name.split("_") + return parts[1] if len(parts) > 1 else "fp16" + + +def _layout_from_kernel_name(name: str) -> str: + """Extract the 3-char layout token (e.g. 'rcr') from a kernel name. + + Name format is ``gemm___...``; each char is 'r' (row-major) + or 'c' (column-major) for operands A, B, C respectively. + """ + parts = name.split("_") + if len(parts) > 2 and len(parts[2]) == 3 and set(parts[2]) <= {"r", "c"}: + return parts[2] + return "rcr" + + +class GpuGemmRunner: + """High-level runner: construct from a .so path, call run(A, B, problem). + + The GEMM ctypes ABI takes HOST pointers and manages GPU memory internally + (hipMalloc/hipMemcpy/hipFree), so this runner stays simple -- it hands + numpy arrays straight to the .so. + """ + + def __init__(self, lib_path: Path): + self.lib = GemmDispatcherLib(lib_path) + if not self.lib.initialize(): + raise RuntimeError(f"Failed to initialize dispatcher .so: {lib_path}") + names = self.lib.kernel_names + self._kernel_name = names[0] if names else "unknown" + + @property + def kernel_name(self) -> str: + return self._kernel_name + + def run( + self, A: np.ndarray, B: np.ndarray, problem: GemmProblem + ) -> GemmResult: + M, N, K = problem.M, problem.N, problem.K + + # Caller passes logical A (MxK) and B (KxN) row-major. The compiled + # kernel dictates both the element dtype and the memory layout of each + # operand (encoded in its name, e.g. gemm_bf16_rcr_...). The C ABI sizes + # its device buffers from sizeof(ADataType) and the kernel computes + # strides from its compiled layout + M,N,K -- so the host buffers must + # be laid out byte-for-byte in the order the kernel expects. + # + # For a 'c' (column-major) operand we transpose so the contiguous host + # buffer's flat memory matches column-major order: + # col-major A (MxK) <=> ascontiguousarray(A.T) (KxM row-major) + # Likewise column-major C (MxN) lands in memory as NxM row-major, so we + # allocate (N,M) and transpose the result back to logical MxN. + dtype = _dtype_from_kernel_name(self._kernel_name) + la, lb, lc = _layout_from_kernel_name(self._kernel_name) + + A_lay = A if la == "r" else A.T + B_lay = B if lb == "r" else B.T + C_shape = (M, N) if lc == "r" else (N, M) + + if dtype == "bf16": + # _fp32_to_bf16_u16 already forces a contiguous float32 buffer, so + # an outer ascontiguousarray here would only add a redundant copy. + A_h = _fp32_to_bf16_u16(A_lay) + B_h = _fp32_to_bf16_u16(B_lay) + C_h = np.zeros(C_shape, dtype=np.uint16) + else: # fp16 (default) + A_h = np.ascontiguousarray(A_lay, dtype=np.float16) + B_h = np.ascontiguousarray(B_lay, dtype=np.float16) + C_h = np.zeros(C_shape, dtype=np.float16) + + status, time_ms = self.lib.run(A_h, B_h, C_h, M, N, K) + + C_dec = _bf16_u16_to_fp32(C_h) if dtype == "bf16" else C_h + C_out = C_dec if lc == "r" else C_dec.T + + tflops = (problem.flops / (time_ms * 1e-3)) / 1e12 if time_ms > 0 else 0.0 + return GemmResult( + output=C_out, + time_ms=time_ms, + status=status, + tflops=tflops, + kernel_name=self._kernel_name, + ) + + +# ============================================================================ +# Build API: codegen + hipcc -> .so paths (no GPU) +# ============================================================================ + +# AMDGPU codegen flags Tile Engine passes to hipcc for GEMM kernels (see +# tile_engine/ops/gemm/gemm_universal CMake flags). They steer inlining and +# register allocation; omitting them changes occupancy and, because persistent +# kernels size their grid by occupancy, produces large perf gaps vs Tile Engine. +# Matching them keeps the bridge byte-for-byte performance-equivalent. +_TILE_ENGINE_CODEGEN_FLAGS = ( + "-mllvm", "--lsr-drop-solution=1", + "-mllvm", "-enable-post-misched=0", + "-mllvm", "-amdgpu-early-inline-all=true", + "-mllvm", "-amdgpu-function-calls=false", + "-fno-offload-uniform-block", +) + + +def _build_compile_jobs( + config: GemmKernelConfig, header: Path +) -> Tuple[Dict[str, Any], Path]: + """Replicate the (validated) compile+link commands from ctypes_utils.""" + root = _cu.get_dispatcher_root() + ck_root = root.parent + build_dir = _cu.get_build_dir() + output_dir = _cu.get_generated_kernels_dir() + ctypes_source = root / "bindings" / "ctypes" / "gemm_ctypes_lib.cpp" + static_lib = build_dir / "libck_tile_dispatcher.a" + + lib_path = build_dir / "examples" / f"lib{config.name}.so" + obj_file = lib_path.with_suffix(".o") + + compile_cmd = [ + "/opt/rocm/bin/hipcc", + "-c", + "-fPIC", + "-O3", + f"-I{root / 'include'}", + f"-I{ck_root / 'include'}", + f"-I{ck_root}", + f"-I{str(output_dir)}", + "-DCK_TILE_SINGLE_KERNEL_INCLUDE", + f"-include{header}", + "-D__HIP_PLATFORM_AMD__", + f"--offload-arch={config.gfx_arch}", + f'-DGFX_ARCH="{config.gfx_arch}"', + "-mllvm", + "-enable-noalias-to-md-conversion=0", + # Match Tile Engine's AMDGPU codegen flags. Without them the kernel is + # compiled with different inlining/register allocation, which changes + # occupancy; persistent kernels size their grid by occupancy + # (UniversalGemmKernel::MaxOccupancyGridSize = #CUs x occupancy), so the + # mismatch shows up as large perf gaps vs Tile Engine on persistent tiles. + *_TILE_ENGINE_CODEGEN_FLAGS, + "-Wno-undefined-func-template", + "-Wno-float-equal", + str(ctypes_source), + "-o", + str(obj_file), + ] + link_cmd = [ + "/opt/rocm/bin/hipcc", + "-shared", + "-fPIC", + f"--offload-arch={config.gfx_arch}", + "--hip-link", + str(obj_file), + str(static_lib), + "-o", + str(lib_path), + ] + job = {"compile_cmd": compile_cmd, "link_cmd": link_cmd, "lib_path": str(lib_path)} + return job, lib_path + + +def setup_multiple_gemm_dispatchers( + configs: List[GemmKernelConfig], + verbose: bool = True, + max_workers: Optional[int] = None, +) -> List[Optional[Path]]: + """Codegen + compile each config into its own .so. Returns .so paths. + + This is the build half of the bridge. It touches NO GPU -- pure CPU + codegen + hipcc, run massively in parallel -- and returns only ``Path`` + objects (``None`` for configs that failed to generate/compile), aligned to + the input order. Benchmarking happens later, in an isolated worker. + """ + import sys + + n = len(configs) + results: List[Optional[Path]] = [None] * n + if n == 0: + return results + + max_workers = max_workers or min(multiprocessing.cpu_count(), 8) + + # Dedupe identical configs by name; compile once, share the path. + first_index: Dict[str, int] = {} + unique: List[int] = [] + for i, c in enumerate(configs): + key = c.name + if key not in first_index: + first_index[key] = i + unique.append(i) + + codegen_script = _cu.get_codegen_path() + output_dir = _cu.get_generated_kernels_dir() + static_lib = _cu.get_build_dir() / "libck_tile_dispatcher.a" + ctypes_source = ( + _cu.get_dispatcher_root() / "bindings" / "ctypes" / "gemm_ctypes_lib.cpp" + ) + if not static_lib.exists() or not ctypes_source.exists(): + raise FileNotFoundError( + "Missing static lib or ctypes source required for compilation:\n" + f" {static_lib}\n {ctypes_source}\n" + "Build the dispatcher first (cmake + make)." + ) + + # -- Step 1: parallel codegen (one header per unique config) ----------- + codegen_args = [] + for i in unique: + c = configs[i] + codegen_args.append( + { + "index": i, + "python": sys.executable, + "codegen_script": str(codegen_script), + "output_dir": str(output_dir), + "dtype": c.dtype_a, + "layout": c.layout, + "gpu_target": c.gfx_arch, + "tile_config_json": c.to_codegen_json(), + "hpp_glob_pattern": f"{c.name}.hpp", + # Honor the config's variant so non-standard kernels are codegen'd + # as themselves; the kernel name (and thus hpp_glob_pattern) already + # carries the variant suffix, so a missing/standard value here would + # produce a header whose name never matches the requested pattern. + "variant": c.variant, + } + ) + + if verbose: + print( + f"[gemm-bridge] codegen: {len(codegen_args)} headers " + f"(workers={max_workers})..." + ) + + headers: Dict[int, Path] = {} + with ProcessPoolExecutor(max_workers=max_workers) as ex: + futs = { + ex.submit(_cu._generate_single_kernel_subprocess, a): a["index"] + for a in codegen_args + } + for fut in as_completed(futs): + i = futs[fut] + ok, hdr, err = fut.result() + if ok and hdr: + headers[i] = Path(hdr) + if verbose: + print(f" OK codegen [{i}] {configs[i].name}") + elif verbose: + print(f" FAIL codegen [{i}] {configs[i].name}: {err}") + + # -- Step 2: parallel compile + link ----------------------------------- + compile_jobs = [] + job_index: List[int] = [] + for i in unique: + hdr = headers.get(i) + if hdr is None: + continue + job, _ = _build_compile_jobs(configs[i], hdr) + compile_jobs.append(job) + job_index.append(i) + + if verbose and compile_jobs: + print( + f"[gemm-bridge] compile: {len(compile_jobs)} .so " + f"(workers={max_workers})..." + ) + + with ProcessPoolExecutor(max_workers=max_workers) as ex: + futs = { + ex.submit(_cu._run_hipcc_subprocess, job): job_index[j] + for j, job in enumerate(compile_jobs) + } + for fut in as_completed(futs): + i = futs[fut] + ok, lp, err = fut.result() + if ok and lp: + results[i] = Path(lp) + if verbose: + print(f" OK compile [{i}] {Path(lp).name}") + elif verbose: + print(f" FAIL compile [{i}] {configs[i].name}: {err}") + + # -- Fan the deduped result back out to every input index -------------- + for i, c in enumerate(configs): + if results[i] is None: + results[i] = results[first_index[c.name]] + + if verbose: + ok_count = sum(1 for r in results if r is not None) + print(f"[gemm-bridge] setup complete: {ok_count}/{n} configs -> .so") + + return results + + +# ============================================================================ +# TE sweep config expansion +# ============================================================================ + + +def _expand_range(entry: Dict[str, Any]) -> List[int]: + """Expand a tile_config entry: either {min,max,step} or {values:[...]}.""" + if "values" in entry: + return list(entry["values"]) + lo = int(entry["min"]) + hi = int(entry["max"]) + step = int(entry.get("step", 1)) + return list(range(lo, hi + 1, step)) + + +def _expand_values(entry: Optional[Dict[str, Any]], default: List[Any]) -> List[Any]: + if entry is None: + return list(default) + return list(entry.get("values", default)) + + +def expand_sweep( + config_path: str, + arch: str, + dtype: str = "fp16", + layout: str = "rcr", +) -> List[GemmKernelConfig]: + """Expand a Tile Engine GEMM JSON sweep config into GemmKernelConfig list. + + The TE config uses ``tile_config`` (ranges/value-lists for tile, warp and + warp_tile triples) and ``trait_config`` (value-lists for pipeline, + scheduler, epilogue, pad_*, persistent). Every valid combination becomes + one GemmKernelConfig. Invalid combinations are dropped via the dispatcher's + own validator, and duplicates (by .name) are collapsed. + + For Phase 1 the signature is fixed to fp16 / rcr. + """ + with open(config_path) as f: + cfg = json.load(f) + + tc = cfg.get("tile_config", {}) + tr = cfg.get("trait_config", {}) + + tile_ms = _expand_range(tc["tile_m"]) + tile_ns = _expand_range(tc["tile_n"]) + tile_ks = _expand_range(tc["tile_k"]) + wave_ms = _expand_range(tc["warp_m"]) # TE "warp" == wave count + wave_ns = _expand_range(tc["warp_n"]) + wave_ks = _expand_range(tc["warp_k"]) + wt_ms = _expand_range(tc["warp_tile_m"]) + wt_ns = _expand_range(tc["warp_tile_n"]) + wt_ks = _expand_range(tc["warp_tile_k"]) + + pipelines = _expand_values(tr.get("pipeline"), ["compv3"]) + schedulers = _expand_values(tr.get("scheduler"), ["intrawave"]) + epilogues = _expand_values(tr.get("epilogue"), ["cshuffle"]) + pad_ms = _expand_values(tr.get("pad_m"), [False]) + pad_ns = _expand_values(tr.get("pad_n"), [False]) + pad_ks = _expand_values(tr.get("pad_k"), [False]) + persistents = _expand_values(tr.get("persistent"), [False]) + + la, lb, lc = layout[0], layout[1], layout[2] + + configs: List[GemmKernelConfig] = [] + seen: set = set() + for ( + tm, + tn, + tk, + wm, + wn, + wk, + wtm, + wtn, + wtk, + pipe, + sched, + epi, + pm, + pn, + pk, + persist, + ) in itertools.product( + tile_ms, + tile_ns, + tile_ks, + wave_ms, + wave_ns, + wave_ks, + wt_ms, + wt_ns, + wt_ks, + pipelines, + schedulers, + epilogues, + pad_ms, + pad_ns, + pad_ks, + persistents, + ): + c = GemmKernelConfig( + dtype_a=dtype, + dtype_b=dtype, + dtype_c=dtype, + layout_a=_LAYOUT_WORD[la], + layout_b=_LAYOUT_WORD[lb], + layout_c=_LAYOUT_WORD[lc], + tile_m=tm, + tile_n=tn, + tile_k=tk, + wave_m=wm, + wave_n=wn, + wave_k=wk, + warp_tile_m=wtm, + warp_tile_n=wtn, + warp_tile_k=wtk, + pipeline=pipe, + scheduler=sched, + epilogue=epi, + pad_m=bool(pm), + pad_n=bool(pn), + pad_k=bool(pk), + persistent=bool(persist), + gfx_arch=arch, + ) + if c.name in seen: + continue + val = _cu.validate_kernel_config(c.to_ctypes_config()) + if not val.is_valid: + continue + seen.add(c.name) + configs.append(c) + + return configs diff --git a/projects/composablekernel/dispatcher/tests/test_gemm_parity.py b/projects/composablekernel/dispatcher/tests/test_gemm_parity.py new file mode 100644 index 000000000000..308c5700672c --- /dev/null +++ b/projects/composablekernel/dispatcher/tests/test_gemm_parity.py @@ -0,0 +1,265 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +"""GEMM bridge parity regression: Dispatcher GPU output vs NumPy reference. + +This is the in-tree, reproducible version of the ad-hoc ``parity/`` sweep used to +validate the Tile Engine -> Dispatcher GEMM bridge. For each (dtype, layout) the +bridge supports it codegens + hipcc-compiles a kernel, runs it through +``GpuGemmRunner``, and compares the result to a NumPy reference across a square, a +rectangular, and an awkward (non-tile-aligned) problem shape. + +Parity is checked as a GLOBAL relative error -- ``max|gpu - ref| / max|ref|`` -- +not per-element: K-length accumulation of zero-mean inputs produces near-zero +entries whose per-element ratios explode and carry no signal. + +The whole suite is GPU-gated: it skips cleanly (not fails) when hipcc, the +dispatcher static lib, or a GPU is unavailable, so CPU-only CI stays green while +GPU runners get real end-to-end coverage. The pure host-side helpers are covered +separately and cheaply by ``test_gemm_utils.py``. + +Run: + python3 -m pytest tests/test_gemm_parity.py -v # discovery / CI + python3 tests/test_gemm_parity.py # readable table +""" + +import os +import sys +import shutil +import unittest +from pathlib import Path + +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "python")) + +import numpy as np # noqa: E402 + +from gemm_utils import ( # noqa: E402 + GemmKernelConfig, + GemmProblem, + GpuGemmRunner, + setup_multiple_gemm_dispatchers, + _fp32_to_bf16_u16, + _bf16_u16_to_fp32, +) +from ctypes_utils import detect_gpu_arch, get_build_dir # noqa: E402 + +# (dtype, layout) surface the regular bridge supports. Column-major C is rejected +# by ck_tile's universal GEMM at build, so every layout keeps row-major C, which +# leaves exactly the four A/B combinations below. Both dtypes cover all four. +_CASES = [ + ("fp16", "rcr"), + ("fp16", "rrr"), + ("fp16", "ccr"), + ("fp16", "crr"), + ("bf16", "rcr"), + ("bf16", "rrr"), + ("bf16", "ccr"), + ("bf16", "crr"), +] + +# Padded default algorithm: pad_* all True so M/N need not divide the tile, which +# is what lets the awkward shape below pass. K must still be a multiple of 8 for +# the fp16/bf16 vectorized contiguous-reduction load, so every K here is divisible +# by 8. +_ALGO = dict( + tile_m=128, tile_n=128, tile_k=32, + wave_m=2, wave_n=2, wave_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, + pipeline="compv4", scheduler="intrawave", epilogue="cshuffle", + pad_m=True, pad_n=True, pad_k=True, +) + +# (name, M, N, K). 'awkward' deliberately uses M, N that do not divide the 128 +# tile to exercise padding; K stays divisible by 8. +_SHAPES = [ + ("square", 512, 512, 512), + ("rectangular", 1024, 512, 256), + ("awkward", 257, 129, 512), +] + +# Global-relative-error gates. fp16 measured ~3-4e-4 and bf16 ~8e-3 on gfx942; +# these leave headroom without masking a real regression. +_TOL = {"fp16": 2e-3, "bf16": 1.5e-2} + +_LAYOUT_WORD = {"r": "row", "c": "col"} + + +def _emulate(x: np.ndarray, dtype: str) -> np.ndarray: + """Round fp32 to the kernel's storage dtype so the CPU reference matches what + the GPU actually multiplies (and stores).""" + if dtype == "bf16": + return _bf16_u16_to_fp32(_fp32_to_bf16_u16(x)) + return x.astype(np.float16).astype(np.float32) + + +def _config(dtype: str, layout: str, arch: str) -> GemmKernelConfig: + la, lb, lc = layout + return GemmKernelConfig( + dtype_a=dtype, dtype_b=dtype, dtype_c=dtype, + layout_a=_LAYOUT_WORD[la], layout_b=_LAYOUT_WORD[lb], layout_c=_LAYOUT_WORD[lc], + gfx_arch=arch, **_ALGO, + ) + + +def _max_rel(out: np.ndarray, ref: np.ndarray) -> float: + denom = float(np.max(np.abs(ref))) + 1e-12 + return float(np.max(np.abs(out - ref))) / denom + + +def _gpu_environment_reason(): + """Return None if the bridge can build+run here, else a human-readable reason + to skip.""" + if not Path("/opt/rocm/bin/hipcc").exists(): + return "hipcc not found at /opt/rocm/bin/hipcc" + if not (get_build_dir() / "libck_tile_dispatcher.a").exists(): + return "dispatcher static lib (libck_tile_dispatcher.a) not built" + if shutil.which("rocminfo") is None: + return "rocminfo not found (no ROCm runtime / GPU)" + return None + + +class GemmBridgeParity(unittest.TestCase): + """End-to-end GPU-vs-NumPy parity across the bridge's dtype/layout surface.""" + + arch = None + built = {} # (dtype, layout) -> Path(.so) + build_failures = {} + + @classmethod + def setUpClass(cls): + reason = _gpu_environment_reason() + if reason: + raise unittest.SkipTest(reason) + cls.arch = detect_gpu_arch() + + configs = [_config(dt, lay, cls.arch) for dt, lay in _CASES] + so_paths = setup_multiple_gemm_dispatchers(configs, verbose=False) + for (dt, lay), so in zip(_CASES, so_paths): + if so is None: + cls.build_failures[(dt, lay)] = "codegen/hipcc returned no .so" + else: + cls.built[(dt, lay)] = so + + if not cls.built: + raise unittest.SkipTest( + f"no bridge kernels built on {cls.arch} " + f"(failures: {cls.build_failures})" + ) + + def _run_case(self, dtype, layout, shape): + so = self.built.get((dtype, layout)) + if so is None: + self.skipTest( + f"{dtype}/{layout} did not build on {self.arch}: " + f"{self.build_failures.get((dtype, layout))}" + ) + + _, M, N, K = shape + problem = GemmProblem(M=M, N=N, K=K) + rng = np.random.default_rng(42) + A = (rng.standard_normal((M, K)) * 0.1).astype(np.float32) + B = (rng.standard_normal((K, N)) * 0.1).astype(np.float32) + + runner = GpuGemmRunner(lib_path=so) + # The .so is the contract endpoint: the name it reports must be the config + # name that drove codegen + the force-include build. + self.assertEqual(runner.kernel_name, GemmKernelConfig( + dtype_a=dtype, dtype_b=dtype, dtype_c=dtype, + layout_a=_LAYOUT_WORD[layout[0]], layout_b=_LAYOUT_WORD[layout[1]], + layout_c=_LAYOUT_WORD[layout[2]], gfx_arch=self.arch, **_ALGO, + ).name) + + result = runner.run(A, B, problem) + self.assertTrue( + result.success, + f"{dtype}/{layout} {shape[0]} run failed (status {result.status})", + ) + + ref = _emulate(_emulate(A, dtype) @ _emulate(B, dtype), dtype) + max_rel = _max_rel(result.output, ref) + self.assertLessEqual( + max_rel, _TOL[dtype], + f"{dtype}/{layout} {shape[0]} max_rel={max_rel:.2e} > {_TOL[dtype]:.0e}", + ) + + +def _add_parity_tests(): + """Generate one test method per (case, shape) so failures pinpoint exactly + which dtype/layout/shape regressed.""" + for dtype, layout in _CASES: + for shape in _SHAPES: + shape_name = shape[0] + + def _method(self, dtype=dtype, layout=layout, shape=shape): + self._run_case(dtype, layout, shape) + + _method.__name__ = f"test_{dtype}_{layout}_{shape_name}" + _method.__doc__ = f"{dtype} {layout} {shape_name} {shape[1:]} parity" + setattr(GemmBridgeParity, _method.__name__, _method) + + +_add_parity_tests() + + +def _main() -> int: + """Readable table run (mirrors test_fmha_parity.py's report style).""" + reason = _gpu_environment_reason() + if reason: + print(f"SKIP: {reason}") + return 0 + + arch = detect_gpu_arch() + print("=" * 78) + print(f"GEMM Bridge Parity: Dispatcher (GPU {arch}) vs NumPy reference") + print("=" * 78) + + configs = [_config(dt, lay, arch) for dt, lay in _CASES] + print(f" Building {len(configs)} bridge kernels (codegen + hipcc)...") + so_paths = setup_multiple_gemm_dispatchers(configs, verbose=False) + + print(f"\n {'case':<12} {'shape':<12} {'tflops':>9} {'max_rel':>10} {'tol':>8} {'':>6}") + print(" " + "-" * 60) + + rng = np.random.default_rng(42) + total = 0 + passed = 0 + for (dtype, layout), so in zip(_CASES, so_paths): + tag = f"{dtype}/{layout}" + if so is None: + print(f" {tag:<12} {'-':<12} {'BUILD FAILED':>35}") + total += len(_SHAPES) + continue + runner = GpuGemmRunner(lib_path=so) + for sname, M, N, K in _SHAPES: + total += 1 + problem = GemmProblem(M=M, N=N, K=K) + A = (rng.standard_normal((M, K)) * 0.1).astype(np.float32) + B = (rng.standard_normal((K, N)) * 0.1).astype(np.float32) + result = runner.run(A, B, problem) + if not result.success: + print(f" {tag:<12} {sname:<12} {'RUN FAILED':>9} status={result.status}") + continue + ref = _emulate(_emulate(A, dtype) @ _emulate(B, dtype), dtype) + mr = _max_rel(result.output, ref) + ok = mr <= _TOL[dtype] + passed += ok + print(f" {tag:<12} {sname:<12} {result.tflops:>9.1f} " + f"{mr:>10.2e} {_TOL[dtype]:>8.0e} {'PASS' if ok else 'FAIL':>6}") + + print("\n" + "=" * 78) + print(f" {passed}/{total} parity checks passed") + print("=" * 78) + return 0 if passed == total else 1 + + +if __name__ == "__main__": + # Default to the readable table; `-m pytest` / `unittest` use the generated + # test methods instead. + if os.environ.get("GEMM_PARITY_UNITTEST"): + unittest.main() + else: + sys.exit(_main()) diff --git a/projects/composablekernel/dispatcher/tests/test_gemm_utils.py b/projects/composablekernel/dispatcher/tests/test_gemm_utils.py new file mode 100644 index 000000000000..34e07ecfcb04 --- /dev/null +++ b/projects/composablekernel/dispatcher/tests/test_gemm_utils.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +"""CPU-only unit tests for python/gemm_utils.py. + +Locks in the bit-level helpers that the TE -> Dispatcher GEMM bridge relies on: + * bf16 <-> uint16 encoding (round-to-nearest-even), since numpy has no native + bf16 and the runner carries bf16 as a uint16 bit pattern. + * dtype / layout parsing from the compiled kernel name, which drives how the + runner lays out host buffers. + +No GPU is touched -- all functions under test are pure host-side logic. +Run: python3 -m pytest tests/test_gemm_utils.py -v +""" + +import sys +import unittest +from pathlib import Path + +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "python")) + +import numpy as np # noqa: E402 + +from gemm_utils import ( # noqa: E402 + GemmKernelConfig, + _fp32_to_bf16_u16, + _bf16_u16_to_fp32, + _dtype_from_kernel_name, + _layout_from_kernel_name, +) + + +class TestBf16Encoding(unittest.TestCase): + """bf16 = top 16 bits of fp32 with round-to-nearest-even.""" + + def test_exactly_representable_roundtrip(self): + # Values whose low 16 fp32 mantissa bits are zero are exact in bf16. + exact = np.array([0.0, 1.0, -1.0, 2.0, 0.5, -0.5, 4.0, 256.0], + dtype=np.float32) + out = _bf16_u16_to_fp32(_fp32_to_bf16_u16(exact)) + np.testing.assert_array_equal(out, exact) + + def test_roundtrip_within_bf16_tolerance(self): + rng = np.random.default_rng(0) + x = (rng.standard_normal(10000) * 100.0).astype(np.float32) + out = _bf16_u16_to_fp32(_fp32_to_bf16_u16(x)) + # bf16 has 8 bits of significand -> relative error <= 2^-8. + rel = np.abs(out - x) / (np.abs(x) + 1e-30) + self.assertLessEqual(float(rel.max()), 2.0 ** -8) + + def test_round_to_nearest_even_ties(self): + # Tie halfway between bf16 1.0 (0x3F80, even) and 0x3F81 (odd): + # fp32 0x3F808000 must round DOWN to the even neighbor 0x3F80. + tie_down = np.array([0x3F808000], dtype=np.uint32).view(np.float32) + self.assertEqual(int(_fp32_to_bf16_u16(tie_down)[0]), 0x3F80) + # Tie halfway between 0x3F81 (odd) and 0x3F82 (even): + # fp32 0x3F818000 must round UP to the even neighbor 0x3F82. + tie_up = np.array([0x3F818000], dtype=np.uint32).view(np.float32) + self.assertEqual(int(_fp32_to_bf16_u16(tie_up)[0]), 0x3F82) + + def test_special_values(self): + inf = np.array([np.inf, -np.inf], dtype=np.float32) + out = _bf16_u16_to_fp32(_fp32_to_bf16_u16(inf)) + self.assertTrue(np.isinf(out[0]) and out[0] > 0) + self.assertTrue(np.isinf(out[1]) and out[1] < 0) + + nan = np.array([np.nan], dtype=np.float32) + out_nan = _bf16_u16_to_fp32(_fp32_to_bf16_u16(nan)) + self.assertTrue(np.isnan(out_nan[0])) + + def test_dtype_and_size(self): + u16 = _fp32_to_bf16_u16(np.zeros(4, dtype=np.float32)) + self.assertEqual(u16.dtype, np.uint16) + self.assertEqual(u16.itemsize, 2) # must match sizeof(bf16_t) on device + + +class TestKernelNameParsing(unittest.TestCase): + """The runner reads dtype + layout straight from the compiled .so name.""" + + _NAME = ("gemm_bf16_rcr_compv3_cshuffle_intrawave_" + "False_False_False_False_64x64x64_4x1x1_16x16x16") + + def test_dtype_from_name(self): + self.assertEqual(_dtype_from_kernel_name(self._NAME), "bf16") + self.assertEqual( + _dtype_from_kernel_name("gemm_fp16_rrr_compv4_cshuffle_intrawave"), + "fp16", + ) + + def test_dtype_fallback(self): + # Malformed / single-token name falls back to fp16. + self.assertEqual(_dtype_from_kernel_name("gemm"), "fp16") + + def test_layout_from_name(self): + self.assertEqual(_layout_from_kernel_name(self._NAME), "rcr") + for lay in ("rrr", "ccr", "crr", "rcc"): + name = f"gemm_fp16_{lay}_compv3_cshuffle_intrawave" + self.assertEqual(_layout_from_kernel_name(name), lay) + + def test_layout_fallback(self): + # A token that is not a 3-char r/c string falls back to rcr. + self.assertEqual( + _layout_from_kernel_name("gemm_fp16_xyz_compv3"), "rcr" + ) + self.assertEqual(_layout_from_kernel_name("gemm"), "rcr") + + +class TestConfigNameContract(unittest.TestCase): + """GemmKernelConfig.name is the single source of truth tying config -> + codegen -> runtime; parsing it back must recover dtype and layout.""" + + def test_name_roundtrips_through_parsers(self): + for dtype in ("fp16", "bf16"): + for la, lb, lc in (("row", "col", "row"), + ("row", "row", "row"), + ("col", "col", "row"), + ("col", "row", "row")): + cfg = GemmKernelConfig( + dtype_a=dtype, dtype_b=dtype, dtype_c=dtype, + layout_a=la, layout_b=lb, layout_c=lc, + ) + name = cfg.name + self.assertEqual(_dtype_from_kernel_name(name), dtype) + self.assertEqual(_layout_from_kernel_name(name), cfg.layout) + + +if __name__ == "__main__": + unittest.main() diff --git a/projects/composablekernel/test/ck_tile/CMakeLists.txt b/projects/composablekernel/test/ck_tile/CMakeLists.txt index 52552d8711ab..d55d3f609e86 100644 --- a/projects/composablekernel/test/ck_tile/CMakeLists.txt +++ b/projects/composablekernel/test/ck_tile/CMakeLists.txt @@ -71,9 +71,6 @@ if(BUILD_CK_TILE_FMHA_TESTS) add_subdirectory(fmha) endif() if(BUILD_CK_TILE_ENGINE_TESTS) -# TODO: The Universal GEMM tile engine test will be either removed -# or moved to the appropriate location in future work. -# add_subdirectory(gemm_tile_engine) add_subdirectory(pooling_tile_engine) endif() add_subdirectory(pooling) diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/CMakeLists.txt b/projects/composablekernel/test/ck_tile/gemm_tile_engine/CMakeLists.txt deleted file mode 100644 index 374370f57076..000000000000 --- a/projects/composablekernel/test/ck_tile/gemm_tile_engine/CMakeLists.txt +++ /dev/null @@ -1,348 +0,0 @@ -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - -# ============================================================================ -# GEMM Tile Engine Unit Tests -# -# This CMake file creates unit tests for tile_engine generated GEMM kernels. -# It follows the exact same build patterns as tile_engine for consistency -# and reliability. Each kernel configuration gets its own test executable. -# ============================================================================ - -# Locate tile_engine GEMM scripts directory -set(TILE_ENGINE_GEMM_DIR "${PROJECT_SOURCE_DIR}/tile_engine/ops/gemm/gemm_universal") - -if(NOT EXISTS ${TILE_ENGINE_GEMM_DIR}) - message(WARNING "Tile engine directory not found: ${TILE_ENGINE_GEMM_DIR}") - return() -endif() - -# ============================================================================ -# create_individual_gemm_test_target -# -# Creates a single test executable for a specific kernel configuration. -# Mirrors tile_engine's create_individual_gemm_target function for consistency. -# -# Parameters: -# datatype - Data type (fp16, bf16, fp32, etc.) -# layout - Matrix layout (rcr, rrr, ccr, crr) -# config_name - Configuration file name without .json extension -# trait - Kernel trait combination string -# tile_config - Tile configuration parameters -# config_json - Full path to JSON configuration file -# ============================================================================ -function(create_individual_gemm_test_target datatype layout config_name trait tile_config config_json) - set(target_name "test_gemm_universal_tile_engine_${datatype}_${layout}_${config_name}_${trait}_${tile_config}") - set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}") - - # Generated header path (already created during cmake configuration) - set(test_header "${working_path}/gemm_universal_single_${datatype}_${layout}_${trait}_${tile_config}.hpp") - set(test_params_header "${working_path}/test_params.hpp") - - # Verify header exists (should have been generated during cmake configuration) - if(NOT EXISTS ${test_header}) - message(WARNING "Generated header not found: ${test_header}") - return() - endif() - - # Verify test parameters header exists - if(NOT EXISTS ${test_params_header}) - message(WARNING "Test parameters header not found: ${test_params_header}") - return() - endif() - - - # Create GTest executable for this kernel configuration - add_gtest_executable(${target_name} - ${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_simple.cpp - ) - - # Configure GPU architectures for HIP compilation - set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_TEST_GPU_TARGETS}) - - # Define preprocessor macros for generated header location and test parameters - target_compile_definitions(${target_name} PRIVATE - GEMM_SINGLE_INSTANCE_HPP="${test_header}" - GEMM_TEST_PARAMS_HPP="${test_params_header}" - ) - - # Include directories for headers and dependencies - target_include_directories(${target_name} PRIVATE - ${PROJECT_SOURCE_DIR}/include - ${PROJECT_BINARY_DIR}/include - ${PROJECT_SOURCE_DIR} # Root directory for tile_engine access - ${GTEST_INCLUDE_DIRS} - ) - - # Compiler options matching tile_engine requirements - target_compile_options(${target_name} PRIVATE - -Wno-undefined-func-template # Suppress template warnings - -Wno-float-equal # Allow floating point comparisons - --offload-compress # Enable GPU code compression - -include ${test_header} # Auto-include generated header - ) - - # Add FP8 format definitions for proper data type interpretation - if(CK_USE_OCP_FP8) - target_compile_options(${target_name} PRIVATE -DCK_TILE_USE_OCP_FP8) - endif() - - message(DEBUG " Created test target: ${target_name}") -endfunction() - -# ============================================================================ -# build_gemm_test_targets -# -# Builds all test targets for a specific datatype/layout/config combination. -# Uses tile_engine's two-step process: list kernels, then generate tests. -# -# Parameters: -# datatype - Data type (fp16, bf16, fp32, etc.) -# layout - Matrix layout (rcr, rrr, ccr, crr) -# config_name - Configuration file name without .json extension -# ============================================================================ -function(build_gemm_test_targets datatype layout config_name) - set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}") - - # Locate and validate configuration file - set(config_filename "${config_name}.json") - set(json_blob "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config_filename}") - - if(NOT EXISTS ${json_blob}) - message(WARNING "Test config file not found: ${json_blob}") - return() - endif() - - # Prepare build directory for this configuration - file(MAKE_DIRECTORY ${working_path}) - - # STEP 1: Discovery phase - list all valid kernel configurations - execute_process( - COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_universal_instance_builder.py - --working_path ${working_path} - --datatype ${datatype} - --layout ${layout} - --config_json ${json_blob} - --list_kernels - --gpu_target "${GEMM_TEST_GPU_TARGETS}" - WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR} - RESULT_VARIABLE ret - OUTPUT_VARIABLE list_output - ERROR_VARIABLE list_error - ) - - if(NOT ret EQUAL 0) - message(WARNING "Failed to list kernels for ${datatype}_${layout}_${config_name}: ${list_error}") - return() - endif() - - # Verify kernel list file was generated - if(NOT EXISTS ${working_path}/gemm_kernel_list.txt) - message(DEBUG "No kernels found for ${datatype}_${layout}_${config_name} (validation filtered out all combinations)") - return() - endif() - - message(DEBUG "Building tests for ${datatype}_${layout}_${config_name}") - - # STEP 2a: Extract test parameters from config - set(test_params_file "${working_path}/test_params.hpp") - execute_process( - COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_SOURCE_DIR}/extract_test_params.py - --config_file ${json_blob} - --output_file ${test_params_file} - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - RESULT_VARIABLE extract_ret - OUTPUT_VARIABLE extract_output - ERROR_VARIABLE extract_error - ) - - if(NOT extract_ret EQUAL 0) - message(WARNING "Failed to extract test parameters for ${datatype}_${layout}: ${extract_error}") - return() - endif() - - # STEP 2b: Header generation phase - generate headers using --gen_single - message(STATUS " Generating headers using --gen_single...") - - file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines) - set(gen_count 0) - - foreach(line IN LISTS kernel_lines) - # Parse kernel specification format: kernel_name|tile_config|trait_combo - string(REPLACE "|" ";" parts "${line}") - list(LENGTH parts parts_len) - if(parts_len EQUAL 3) - list(GET parts 0 kernel_name) - list(GET parts 1 tile_config) - list(GET parts 2 trait_combo) - - # Generate header using --gen_single - execute_process( - COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_universal_instance_builder.py - --working_path ${working_path} - --gpu_target "${GEMM_TEST_GPU_TARGETS}" - --datatype ${datatype} - --layout ${layout} - --config_json ${json_blob} - --gen_single - --kernel_name "${kernel_name}" - --tile_config "${tile_config}" - --trait_combo "${trait_combo}" - WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR} - RESULT_VARIABLE gen_ret - OUTPUT_VARIABLE gen_output - ERROR_VARIABLE gen_error - ) - - if(NOT gen_ret EQUAL 0) - message(WARNING "Failed to generate header for ${kernel_name}: ${gen_error}") - else() - math(EXPR gen_count "${gen_count} + 1") - endif() - endif() - endforeach() - - message(STATUS " Generated ${gen_count} headers for ${datatype}_${layout}") - - # STEP 3: Target creation phase - create test targets - message(STATUS " Creating test targets...") - file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines) - set(test_count 0) - foreach(line IN LISTS kernel_lines) - # Parse kernel specification format: kernel_name|tile_config|trait_combo - string(REPLACE "|" ";" parts "${line}") - list(LENGTH parts parts_len) - if(parts_len EQUAL 3) - list(GET parts 0 kernel_name) - list(GET parts 1 tile_config) - list(GET parts 2 trait_combo) - - # Generate test target for this kernel configuration - create_individual_gemm_test_target("${datatype}" "${layout}" "${config_name}" "${trait_combo}" "${tile_config}" "${json_blob}") - math(EXPR test_count "${test_count} + 1") - endif() - endforeach() - message(STATUS " Created ${test_count} test targets for ${datatype}_${layout}") -endfunction()# ============================================================================ -# MAIN EXECUTION - Test Target Generation -# ============================================================================ - -message(STATUS "=== Starting GEMM Tile Engine Test Configuration ===") -message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") - -# GPU architecture filtering - only build tests for supported architectures -set(GEMM_TEST_GPU_TARGETS "") -set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201;gfx12-generic") - -foreach(target IN LISTS SUPPORTED_GPU_TARGETS) - if(target IN_LIST DESIRED_TARGETS) - list(APPEND GEMM_TEST_GPU_TARGETS ${target}) - message(STATUS " Adding GPU target for tests: ${target}") - endif() -endforeach() - -# Early exit if no compatible GPU architectures are available -if(NOT GEMM_TEST_GPU_TARGETS) - message(WARNING "Skipping GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") - return() -endif() - -message(STATUS "Building GEMM tile engine tests for GPU targets: ${GEMM_TEST_GPU_TARGETS}") - - # Enable parallel compilation optimizations - # Set up job pools for better parallel compilation control - set_property(GLOBAL PROPERTY JOB_POOLS - compile_heavy=4 # Limit heavy compilations to prevent OOM - compile_normal=16 # Allow more parallel normal compilations - ) - - # Enable compiler cache if available and explicitly requested - # Disabled by default due to permission issues in CI environments - option(ENABLE_CCACHE_TESTS "Enable ccache for test compilation" OFF) - if(ENABLE_CCACHE_TESTS) - find_program(CCACHE_PROGRAM ccache) - if(CCACHE_PROGRAM) - set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM}) - message(STATUS "Using ccache for faster test compilation") - else() - message(WARNING "ccache requested but not found") - endif() - else() - message(STATUS "ccache disabled for tests (use -DENABLE_CCACHE_TESTS=ON to enable)") - endif() - -# ============================================================================ -# Test Configuration Matrix - Clean Focused Design -# ============================================================================ - -# All supported data types and layouts for comprehensive testing -# Note: fp64 not included (no MFMA hardware support) -set(TEST_DATATYPES "fp16;fp8;bf16;fp32") -set(TEST_LAYOUTS "rcr;rrr;ccr;crr") - -# ============================================================================ -# Test Target Generation - Datatype-Specific Categories -# ============================================================================ - -# 1. SMALL DATATYPES: Test optimized config for small data types (fp8, fp16, bf16) -# These data types can use larger warp tiles due to smaller memory footprint -set(SMALL_DATATYPE_CONFIG "small_datatype_config") -set(SMALL_DATATYPE_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${SMALL_DATATYPE_CONFIG}.json") -set(SMALL_DATATYPES "fp8;fp16;bf16") - -if(EXISTS ${SMALL_DATATYPE_CONFIG_FILE}) - message(STATUS "Processing small datatype config: ${SMALL_DATATYPE_CONFIG} (fp8, fp16, bf16)") - foreach(datatype IN LISTS SMALL_DATATYPES) - # fp8, fp16, bf16: testing all layouts (rcr, rrr, ccr, crr) - foreach(layout IN LISTS TEST_LAYOUTS) - build_gemm_test_targets("${datatype}" "${layout}" "${SMALL_DATATYPE_CONFIG}") - endforeach() - endforeach() -else() - message(WARNING "Small datatype config file not found: ${SMALL_DATATYPE_CONFIG_FILE}") -endif() - -# 2. PADDING COVERAGE: Test padding combinations with fixed fp16/rcr configuration -# This focuses on padding behavior (pad_m, pad_n, pad_k) -set(PADDING_CONFIG "padding_coverage_config") -set(PADDING_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${PADDING_CONFIG}.json") - -if(EXISTS ${PADDING_CONFIG_FILE}) - message(STATUS "Processing padding config: ${PADDING_CONFIG} (fp16/rcr only)") - build_gemm_test_targets("fp16" "rcr" "${PADDING_CONFIG}") -else() - message(WARNING "Padding config file not found: ${PADDING_CONFIG_FILE}") -endif() - -# 3. COVERAGE LEVEL: Quick or comprehensive testing -# Quick: ~144 kernels with multiple tile sizes and trait combinations -# Comprehensive: Several thousand kernels with extensive tile sizes, warp configurations, and all trait combinations -set(COVERAGE_LEVEL "quick" CACHE STRING "Coverage level: quick or comprehensive") -set_property(CACHE COVERAGE_LEVEL PROPERTY STRINGS "quick" "comprehensive") - -if(COVERAGE_LEVEL STREQUAL "quick") - set(COVERAGE_CONFIG "quick_coverage_config") - set(COVERAGE_DESC "Quick - approximately 144 kernels with trait combinations") -elseif(COVERAGE_LEVEL STREQUAL "comprehensive") - set(COVERAGE_CONFIG "comprehensive_coverage_config") - set(COVERAGE_DESC "Comprehensive - several thousand kernels with extensive tile and trait coverage") -else() - message(FATAL_ERROR "Invalid COVERAGE_LEVEL: ${COVERAGE_LEVEL}. Must be 'quick' or 'comprehensive'") -endif() - -set(COVERAGE_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${COVERAGE_CONFIG}.json") - -if(EXISTS ${COVERAGE_CONFIG_FILE}) - message(STATUS "Processing coverage config: ${COVERAGE_LEVEL} - ${COVERAGE_DESC}") - build_gemm_test_targets("fp16" "rcr" "${COVERAGE_CONFIG}") -else() - message(WARNING "Coverage config file not found: ${COVERAGE_CONFIG_FILE}") -endif() -# ============================================================================ - - -message(STATUS "GEMM tile engine tests configured with datatype-specific design:") -message(STATUS " - Small datatypes: fp8/fp16/bf16 (all layouts)") -message(STATUS " - Padding coverage with fp16/rcr") -message(STATUS " - Coverage level: ${COVERAGE_LEVEL} (~144 kernels quick, several thousand comprehensive)") -message(STATUS " Use -DCOVERAGE_LEVEL=comprehensive for extensive testing") diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/README.md b/projects/composablekernel/test/ck_tile/gemm_tile_engine/README.md deleted file mode 100644 index 87ce0c9fd05c..000000000000 --- a/projects/composablekernel/test/ck_tile/gemm_tile_engine/README.md +++ /dev/null @@ -1,85 +0,0 @@ -# GEMM Tile Engine Unit Tests - -## How It Works - -This unit test system integrates **tile_engine's kernel generation** into automated testing: - -1. **Uses tile_engine scripts directly**: Same Python scripts that generate tile_engine kernels -2. **JSON-based configuration**: Define test parameters in JSON files (like tile_engine) -3. **Build-time generation**: CMake calls tile_engine scripts to generate kernel headers -4. **Individual test executables**: Each kernel configuration becomes a separate test -5. **Tile_engine verification**: Uses exact same error thresholds and validation as tile_engine - -## Tile Engine Integration - -``` -JSON Config → tile_engine Python scripts → Generated Headers → Test Executables -``` - -- **`--list_kernels`**: Get available kernel configurations from JSON -- **`--gen_individual`**: Generate all kernel headers in parallel during CMake configuration -- **`--gen_single`**: Generate individual kernel header for each configuration -- **Same verification**: Uses tile_engine's adaptive error thresholds and reference calculations -- **Same patterns**: Follows tile_engine's tensor initialization, stride calculation, and kernel launching - -### Config-Specific Test Parameters - -Each test configuration can specify optimized problem sizes in its JSON file: -- **`test_params.problem_sizes`**: Array of `{m, n, k, split_k}` configurations -- **CMake extraction**: `extract_test_params.py` generates config-specific test parameter files -- **Build integration**: Each test target uses parameters appropriate for its kernel configuration -- **Optimized testing**: Different configs test different problem sizes that showcase their strengths - - -The key idea: **Unit tests that use tile_engine's exact kernel generation and verification methodology** instead of creating separate test infrastructure. - -## Test Configurations - -### 1. **Simple Test** (`simple_test_config.json`) -- **Purpose**: Basic functionality validation -- **Config**: 128x128x64, warp 2x2x1, warp_tile 16x16x16 -- **Traits**: compv3 + compv4 pipelines -- **Coverage**: ~2 kernels per datatype/layout - -### 2. **Small Datatype** (`small_datatype_config.json`) -- **Purpose**: Optimized for fp8/fp16/bf16 data types -- **Config**: 128x128x32, warp 2x2x1, warp_tile 32x32x16 -- **Traits**: compv3 pipeline only -- **Coverage**: All 4 layouts (rcr, rrr, ccr, crr) for fp8, fp16, bf16 - -### 3. **Padding Coverage** (`padding_coverage_config.json`) -- **Purpose**: Test padding behavior with all padding flags enabled -- **Config**: Fixed 64x64x32, warp 2x2x1, warp_tile 32x32x16 -- **Padding**: All enabled (pad_m=true, pad_n=true, pad_k=true) -- **Problem sizes**: Vector-aligned but not tile-aligned (104×104×56, 200×152×80, 152×200×64) -- **Coverage**: 1 kernel configuration testing padding with irregular sizes - -### 4. **Coverage Testing** (Quick or Comprehensive) -- **Purpose**: Comprehensive testing across tile sizes, warp configurations, and trait combinations -- **Quick** (`quick_coverage_config.json`): Approximately 144 kernels - - tile_m/n: [32, 64, 256], tile_k: [16, 32] - - warp config: 2×2×1, warp_tile 16×16×16 - - Traits: 3 pipelines × 2 epilogues × 2 schedulers (persistent=false only) - - Focused set testing trait combinations with multiple tile sizes -- **Comprehensive** (`comprehensive_coverage_config.json`): Several thousand kernels - - tile_m/n: [16-256 step 16] - - tile_k: [16, 32, 64] - - warp_m/n: [1, 2, 4], warp_tile_m/n: [16, 32], warp_tile_k: [16, 32] - - Traits: 3 pipelines × 2 epilogues × 2 schedulers × 2 persistent - - Extensive coverage across all tile sizes, warp configurations, and trait combinations - - Exact count varies based on validation filtering -- **Note**: Use CMake option `-DCOVERAGE_LEVEL=comprehensive` to enable comprehensive testing (default is quick) - -## Data Type Support -- ✅ **fp8, fp16, bf16**: Fully supported - all layouts (rcr, rrr, ccr, crr) -- ❌ **fp64**: Not supported (hardware MFMA limitation) -- ⏳ **fp32, bf8, pk-int4-t**: Not yet supported by gemm_instance_builder (will be added later) - -## Test Result Behavior - -Tests automatically handle unsupported configurations through runtime validation: -- **PASSED**: Kernel executed correctly with results within error thresholds ✅ -- **SKIPPED**: Kernel validation returned "Arguments not supported" (expected for certain problem sizes/configurations) ⚠️ -- **FAILED**: Actual error or incorrect computation results ❌ - -When a kernel's `IsSupportedArgument()` check fails (e.g., due to vector alignment requirements, dimension constraints, or padding limitations), the test is automatically skipped rather than failed. This allows comprehensive testing across various problem sizes while gracefully handling configurations that don't meet specific kernel requirements. diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/comprehensive_coverage_config.json b/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/comprehensive_coverage_config.json deleted file mode 100644 index f2524e4a619d..000000000000 --- a/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/comprehensive_coverage_config.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "problem": { - "description": "Comprehensive coverage testing - extensive tile size coverage (16-256, step 16) with multiple warp configurations and all trait combinations. Several thousand kernels." - }, - "test_params": { - "problem_sizes": [ - {"m": 512, "n": 512, "k": 256, "split_k": 1}, - {"m": 1024, "n": 512, "k": 512, "split_k": 1}, - {"m": 512, "n": 1024, "k": 512, "split_k": 1}, - {"m": 1024, "n": 1024, "k": 256, "split_k": 1}, - {"m": 1024, "n": 1024, "k": 256, "split_k": 2}, - {"m": 1024, "n": 1024, "k": 256, "split_k": 4} - ] - }, - "tile_config": { - "tile_m": {"values": [16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256]}, - "tile_n": {"values": [16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256]}, - "tile_k": {"values": [16, 32, 64]}, - "warp_m": {"values": [1, 2, 4]}, - "warp_n": {"values": [1, 2, 4]}, - "warp_k": {"values": [1]}, - "warp_tile_m": {"values": [16, 32]}, - "warp_tile_n": {"values": [16, 32]}, - "warp_tile_k": {"values": [8, 16, 32, 64, 128]} - }, - "trait_config": { - "pipeline": {"values": ["mem", "compv3", "compv4"]}, - "epilogue": {"values": ["default", "cshuffle"]}, - "scheduler": {"values": ["intrawave", "interwave"]}, - "pad_m": {"values": [false]}, - "pad_n": {"values": [false]}, - "pad_k": {"values": [false]}, - "persistent": {"values": [true, false]} - }, - "k_block_per_cu": 1, - "permute_n": false -} diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/large_datatype_config.json b/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/large_datatype_config.json deleted file mode 100644 index e9fcb6fb8007..000000000000 --- a/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/large_datatype_config.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "problem": { - "description": "Configuration optimized for large data types (fp32) with smaller warp tiles due to memory constraints" - }, - "test_params": { - "problem_sizes": [ - {"m": 512, "n": 512, "k": 128, "split_k": 1}, - {"m": 512, "n": 256, "k": 192, "split_k": 1}, - {"m": 256, "n": 384, "k": 192, "split_k": 1} - ] - }, - "tile_config": { - "tile_m": {"values": [256]}, - "tile_n": {"values": [128]}, - "tile_k": {"values": [32]}, - "warp_m": {"values": [2]}, - "warp_n": {"values": [2]}, - "warp_k": {"values": [1]}, - "warp_tile_m": {"values": [16]}, - "warp_tile_n": {"values": [16]}, - "warp_tile_k": {"values": [16]} - }, - "trait_config": { - "pipeline": {"values": ["compv3"]}, - "epilogue": {"values": ["default"]}, - "scheduler": {"values": ["intrawave"]}, - "pad_m": {"values": [false]}, - "pad_n": {"values": [false]}, - "pad_k": {"values": [false]}, - "persistent": {"values": [false]} - }, - "k_block_per_cu": 1, - "permute_n": false -} diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/padding_coverage_config.json b/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/padding_coverage_config.json deleted file mode 100644 index 33bada839de5..000000000000 --- a/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/padding_coverage_config.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "problem": { - "description": "Padding coverage testing - fixed config with fp16/rcr, varying only padding combinations" - }, - "test_params": { - "problem_sizes": [ - {"m": 104, "n": 104, "k": 56, "split_k": 1}, - {"m": 200, "n": 152, "k": 80, "split_k": 1}, - {"m": 152, "n": 200, "k": 64, "split_k": 1} - ] - }, - "tile_config": { - "tile_m": {"values": [64]}, - "tile_n": {"values": [64]}, - "tile_k": {"values": [32]}, - "warp_m": {"values": [2]}, - "warp_n": {"values": [2]}, - "warp_k": {"values": [1]}, - "warp_tile_m": {"values": [32]}, - "warp_tile_n": {"values": [32]}, - "warp_tile_k": {"values": [16]} - }, - "trait_config": { - "pipeline": {"values": ["compv3"]}, - "epilogue": {"values": ["default"]}, - "scheduler": {"values": ["intrawave"]}, - "pad_m": {"values": [true]}, - "pad_n": {"values": [true]}, - "pad_k": {"values": [true]}, - "persistent": {"values": [false]} - }, - "k_block_per_cu": 1, - "permute_n": false -} diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/quick_coverage_config.json b/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/quick_coverage_config.json deleted file mode 100644 index dcc6e99aee5a..000000000000 --- a/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/quick_coverage_config.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "problem": { - "description": "Quick coverage testing - tests multiple tile sizes with all trait combinations (pipelines, epilogues, schedulers). Approximately 144 kernels." - }, - "test_params": { - "problem_sizes": [ - {"m": 512, "n": 1024, "k": 512, "split_k": 1}, - {"m": 1024, "n": 1024, "k": 256, "split_k": 2}, - {"m": 1024, "n": 1024, "k": 256, "split_k": 4} - ] - }, - "tile_config": { - "tile_m": {"values": [32, 64, 256]}, - "tile_n": {"values": [32, 64, 256]}, - "tile_k": {"values": [16, 32]}, - "warp_m": {"values": [2]}, - "warp_n": {"values": [2]}, - "warp_k": {"values": [1]}, - "warp_tile_m": {"values": [16]}, - "warp_tile_n": {"values": [16]}, - "warp_tile_k": {"values": [16]} - }, - "trait_config": { - "pipeline": {"values": ["mem", "compv3", "compv4"]}, - "epilogue": {"values": ["default", "cshuffle"]}, - "scheduler": {"values": ["intrawave", "interwave"]}, - "pad_m": {"values": [false]}, - "pad_n": {"values": [false]}, - "pad_k": {"values": [false]}, - "persistent": {"values": [false]} - }, - "k_block_per_cu": 1, - "permute_n": false -} diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json b/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json deleted file mode 100644 index 498ef9fa33a1..000000000000 --- a/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "problem": { - "description": "Basic functionality validation with moderate problem sizes" - }, - "test_params": { - "problem_sizes": [ - {"m": 256, "n": 256, "k": 128, "split_k": 1}, - {"m": 512, "n": 256, "k": 256, "split_k": 1}, - {"m": 256, "n": 512, "k": 256, "split_k": 1} - ] - }, - "tile_config": { - "tile_m": {"values": [128]}, - "tile_n": {"values": [128]}, - "tile_k": {"values": [64]}, - "warp_m": {"values": [2]}, - "warp_n": {"values": [2]}, - "warp_k": {"values": [1]}, - "warp_tile_m": {"values": [16]}, - "warp_tile_n": {"values": [16]}, - "warp_tile_k": {"values": [16]} - }, - "trait_config": { - "pipeline": {"values": ["compv3", "compv4"]}, - "epilogue": {"values": ["default"]}, - "scheduler": {"values": ["intrawave"]}, - "pad_m": {"values": [false]}, - "pad_n": {"values": [false]}, - "pad_k": {"values": [false]}, - "persistent": {"values": [false]} - }, - "k_block_per_cu": 1, - "permute_n": false -} diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/small_datatype_config.json b/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/small_datatype_config.json deleted file mode 100644 index d0d9f99a0cc7..000000000000 --- a/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/small_datatype_config.json +++ /dev/null @@ -1,35 +0,0 @@ -{ - "problem": { - "description": "Configuration optimized for small data types (fp8, fp16, bf16) with larger warp tiles" - }, - "test_params": { - "problem_sizes": [ - {"m": 512, "n": 512, "k": 256, "split_k": 1}, - {"m": 1024, "n": 512, "k": 512, "split_k": 1}, - {"m": 512, "n": 1024, "k": 512, "split_k": 1}, - {"m": 1024, "n": 1024, "k": 256, "split_k": 1} - ] - }, - "tile_config": { - "tile_m": {"values": [128]}, - "tile_n": {"values": [128]}, - "tile_k": {"values": [32]}, - "warp_m": {"values": [2]}, - "warp_n": {"values": [2]}, - "warp_k": {"values": [1]}, - "warp_tile_m": {"values": [32]}, - "warp_tile_n": {"values": [32]}, - "warp_tile_k": {"values": [16]} - }, - "trait_config": { - "pipeline": {"values": ["compv3"]}, - "epilogue": {"values": ["default"]}, - "scheduler": {"values": ["intrawave"]}, - "pad_m": {"values": [false]}, - "pad_n": {"values": [false]}, - "pad_k": {"values": [false]}, - "persistent": {"values": [false]} - }, - "k_block_per_cu": 1, - "permute_n": false -} diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/extract_test_params.py b/projects/composablekernel/test/ck_tile/gemm_tile_engine/extract_test_params.py deleted file mode 100644 index 48ec8dba8352..000000000000 --- a/projects/composablekernel/test/ck_tile/gemm_tile_engine/extract_test_params.py +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - - -import json -import argparse -import os -from pathlib import Path - - -def extract_test_params(config_file, output_file): - """Extract test parameters from config JSON and write to output file""" - - # Read config file - with open(config_file, "r") as f: - config = json.load(f) - - # Extract test parameters - test_params = [] - if "test_params" in config and "problem_sizes" in config["test_params"]: - test_params = config["test_params"]["problem_sizes"] - else: - # Default test parameters if none specified - test_params = [ - {"m": 256, "n": 256, "k": 128, "split_k": 1}, - {"m": 256, "n": 256, "k": 1024, "split_k": 1}, - {"m": 256, "n": 512, "k": 512, "split_k": 1}, - {"m": 512, "n": 256, "k": 512, "split_k": 1}, - ] - - # Write to output file in C++ format - output_dir = Path(output_file).parent - output_dir.mkdir(parents=True, exist_ok=True) - - with open(output_file, "w") as f: - f.write("// Generated test parameters for this configuration\n") - f.write("// This file is auto-generated during CMake configuration\n\n") - f.write("static const std::vector CONFIG_TEST_PARAMS = {\n") - - for i, params in enumerate(test_params): - comma = "," if i < len(test_params) - 1 else "" - f.write( - f" {{{params['m']}, {params['n']}, {params['k']}, {params['split_k']}}}{comma}\n" - ) - - f.write("};\n") - - print( - f"Extracted {len(test_params)} test parameters from {config_file} -> {output_file}" - ) - - -def main(): - parser = argparse.ArgumentParser( - description="Extract test parameters from config JSON" - ) - parser.add_argument("--config_file", required=True, help="Input config JSON file") - parser.add_argument( - "--output_file", required=True, help="Output test parameters file" - ) - - args = parser.parse_args() - - if not os.path.exists(args.config_file): - print(f"Error: Config file not found: {args.config_file}") - return 1 - - extract_test_params(args.config_file, args.output_file) - return 0 - - -if __name__ == "__main__": - exit(main()) diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp b/projects/composablekernel/test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp deleted file mode 100644 index e44e8c4182ac..000000000000 --- a/projects/composablekernel/test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp +++ /dev/null @@ -1,241 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -/** - * @file test_gemm_simple.cpp - * @brief Unit tests for GEMM kernels generated by gemm_instance_builder - * - * This test includes kernels generated during CMake configuration by - * gemm_instance_builder.py and tests them with problem sizes extracted - * from the corresponding JSON configuration files. - */ - -#include -#include - -#include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" -#include "tile_engine/ops/gemm/gemm_common.hpp" - -// The kernel header is included via compile command line with -include flag -// It defines SelectedKernel struct, KERNEL_NAME, and tensor data types - -// Adaptive error threshold calculation matching tile_engine's implementation -template -auto calculate_rtol_atol(const ck_tile::index_t K, - const ck_tile::index_t kbatch, - const float max_accumulated_value) -{ - using ComputeType = - std::conditional_t; - // Calculate thresholds - const auto rtol = ck_tile::get_relative_threshold( - ck_tile::integer_divide_ceil(K, kbatch)); - const auto atol = ck_tile::get_absolute_threshold( - max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); - // Calculate error due to split_k accumulation - const auto rtol_split_k = - ck_tile::get_relative_threshold(kbatch); - const auto atol_split_k = ck_tile::get_absolute_threshold( - max_accumulated_value, kbatch); - // Use higher threshold - return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); -} - -/// @brief Function to compare the results of the device and host computations (from tile_engine) -template -bool compare_results(std::string instanceName, - ck_tile::index_t K, - ck_tile::index_t kbatch, - ck_tile::HostTensor& c_m_n_dev_result, - ck_tile::HostTensor& c_m_n_host_result) -{ - const float max_accumulated_value = - *std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end()); - const auto rtol_atol = calculate_rtol_atol( - K, kbatch, max_accumulated_value); - bool pass = ck_tile::check_err(c_m_n_dev_result, - c_m_n_host_result, - "Error: Incorrect results!", - rtol_atol.at(ck_tile::number<0>{}), - rtol_atol.at(ck_tile::number<1>{})); - - std::cout << "For " << instanceName << " Relative error threshold is " - << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is " - << rtol_atol.at(ck_tile::number<1>{}) << std::endl; - std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl; - - return pass; -} - -// Test parameter structure for matrix dimensions and split_k values -struct GemmTestParams -{ - int m, n, k, split_k; -}; - -// Include config-specific test parameters (after GemmTestParams struct is defined) -#ifdef GEMM_TEST_PARAMS_HPP -#include GEMM_TEST_PARAMS_HPP -#endif - -class GemmTileEngineTest : public ::testing::TestWithParam -{ - protected: - void SetUp() override - { - auto params = GetParam(); - m_ = params.m; - n_ = params.n; - k_ = params.k; - split_k_ = params.split_k; - - // Calculate strides (following tile_engine pattern) - if constexpr(std::is_same_v) - { - stride_a_ = k_; - } - else - { - stride_a_ = m_; - } - - if constexpr(std::is_same_v) - { - stride_b_ = n_; - } - else - { - stride_b_ = k_; - } - - if constexpr(std::is_same_v) - { - stride_c_ = n_; - } - else - { - stride_c_ = m_; - } - } - - // Test dimensions - int m_, n_, k_, split_k_; - int stride_a_, stride_b_, stride_c_; -}; - -TEST_P(GemmTileEngineTest, BasicFunctionality) -{ - // Get tensor layouts from generated kernel - const ALayout layout_a = ALayout{}; - const BLayout layout_b = BLayout{}; - const CLayout layout_c = CLayout{}; - - // Use split_k from test parameters - int split_k = split_k_; - int stride_a_calc = ck_tile::get_default_stride(m_, k_, 0, is_row_major(layout_a)); - int stride_b_calc = ck_tile::get_default_stride(k_, n_, 0, is_row_major(layout_b)); - int stride_c_calc = ck_tile::get_default_stride(m_, n_, 0, is_row_major(layout_c)); - - // Create host tensors with proper descriptors - ck_tile::HostTensor a_m_k( - ck_tile::host_tensor_descriptor(m_, k_, stride_a_calc, is_row_major(layout_a))); - ck_tile::HostTensor b_k_n( - ck_tile::host_tensor_descriptor(k_, n_, stride_b_calc, is_row_major(layout_b))); - ck_tile::HostTensor c_m_n_dev_result( - ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c))); - ck_tile::HostTensor c_m_n_host_result( - ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c))); - - // Initialize input tensors with uniform random distribution [-1.0, 1.0] (matches tile_engine) - ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); - ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); - - // Allocate GPU device memory - ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); - ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); - ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); - - // Copy data to device and zero output buffer - a_m_k_dev_buf.ToDevice(a_m_k.data()); - b_k_n_dev_buf.ToDevice(b_k_n.data()); - c_m_n_dev_buf.SetZero(); - c_m_n_dev_result.SetZero(); - - // Calculate reference result on host for verification - ck_tile::reference_gemm( - a_m_k, b_k_n, c_m_n_host_result); - - // Create GEMM kernel arguments - ck_tile::GemmHostArgs gemm_args(a_m_k_dev_buf.GetDeviceBuffer(), - b_k_n_dev_buf.GetDeviceBuffer(), - c_m_n_dev_buf.GetDeviceBuffer(), - split_k, - m_, - n_, - k_, - stride_a_calc, - stride_b_calc, - stride_c_calc); - - // Configure kernel execution for maximum speed (no timing, no debug output) - ck_tile::stream_config stream_config{nullptr, // stream - false, // time_kernel (disable timing for speed) - 0, // log_level (disable debug output) - 0, // n_warmup - 1, // n_repeat - false, // is_gpu_timer (unused when time_kernel=false) - false, // flush_cache - 1}; // rotating_count - - // Launch the generated kernel (no timing overhead for fastest execution) - try - { - SelectedKernel::launch(gemm_args, stream_config); - // Kernel launched successfully if no exception thrown - } - catch(const std::exception& e) - { - std::string error_msg(e.what()); - // If arguments not supported, skip the test (configuration validation failure, not a bug) - if(error_msg.find("Arguments not supported") != std::string::npos) - { - GTEST_SKIP() << "Configuration not supported: " << e.what(); - } - else - { - FAIL() << "Kernel launch failed: " << e.what(); - } - } - - // Copy result back from device - c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); - - // Verify results using tile_engine's adaptive error thresholds - bool verification_passed = compare_results( - KERNEL_NAME, k_, split_k, c_m_n_dev_result, c_m_n_host_result); - - EXPECT_TRUE(verification_passed) << "GEMM result verification failed"; -} - -TEST_P(GemmTileEngineTest, KernelInfo) -{ - // Simple test to verify kernel information is available - EXPECT_TRUE(strlen(KERNEL_NAME) > 0) << "Kernel name should not be empty"; - - std::cout << "Testing kernel: " << KERNEL_NAME << std::endl; - std::cout << "Problem size: " << m_ << "x" << n_ << "x" << k_ << " with split_k=" << split_k_ - << std::endl; -} - -// Use config-specific test parameters (included via compile flags) -// CONFIG_TEST_PARAMS is defined in the auto-generated test_params.hpp file -INSTANTIATE_TEST_SUITE_P(GemmVerification, - GemmTileEngineTest, - ::testing::ValuesIn(CONFIG_TEST_PARAMS), - [](const ::testing::TestParamInfo& param_info) { - return std::to_string(param_info.param.m) + "x" + - std::to_string(param_info.param.n) + "x" + - std::to_string(param_info.param.k) + "_splitk" + - std::to_string(param_info.param.split_k); - }); diff --git a/projects/composablekernel/tile_engine/ops/gemm/CMakeLists.txt b/projects/composablekernel/tile_engine/ops/gemm/CMakeLists.txt index b50a6790105a..c7f6e48930ad 100644 --- a/projects/composablekernel/tile_engine/ops/gemm/CMakeLists.txt +++ b/projects/composablekernel/tile_engine/ops/gemm/CMakeLists.txt @@ -15,7 +15,7 @@ if(NOT "${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "") if(_te_budget GREATER 0) # Detect active ops from their DATATYPE variables set(_active_ops "") - foreach(_op gemm_universal gemm_multi_d gemm_preshuffle grouped_gemm gemm_streamk batched_contraction batched_gemm gemm_multi_abd mx_gemm gemm_rowcolquant gemm_tensor_quant grouped_gemm_rowcolquant grouped_gemm_tensorquant) + foreach(_op gemm_multi_d gemm_preshuffle grouped_gemm gemm_streamk batched_contraction batched_gemm gemm_multi_abd mx_gemm gemm_rowcolquant gemm_tensor_quant grouped_gemm_rowcolquant grouped_gemm_tensorquant) string(TOUPPER ${_op} _OP_UPPER) if(NOT "${${_OP_UPPER}_DATATYPE}" STREQUAL "") list(APPEND _active_ops ${_op}) @@ -45,7 +45,7 @@ if(NOT "${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "") message(STATUS "Sampling budget allocation:\n${_alloc_output}") # Read per-op allocations (only if not already overridden) - foreach(_op gemm_universal gemm_multi_d gemm_preshuffle grouped_gemm gemm_streamk batched_contraction batched_gemm gemm_multi_abd mx_gemm gemm_rowcolquant gemm_tensor_quant grouped_gemm_rowcolquant grouped_gemm_tensorquant) + foreach(_op gemm_multi_d gemm_preshuffle grouped_gemm gemm_streamk batched_contraction batched_gemm gemm_multi_abd mx_gemm gemm_rowcolquant gemm_tensor_quant grouped_gemm_rowcolquant grouped_gemm_tensorquant) string(TOUPPER ${_op} _OP_UPPER) if("${${_OP_UPPER}_MAX_INSTANCES}" STREQUAL "") if(EXISTS "${_alloc_dir}/${_op}_budget.txt") @@ -73,7 +73,6 @@ if(NOT "${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "") endif() endif() -add_subdirectory(gemm_universal EXCLUDE_FROM_ALL) add_subdirectory(gemm_multi_d EXCLUDE_FROM_ALL) add_subdirectory(gemm_preshuffle EXCLUDE_FROM_ALL) add_subdirectory(grouped_gemm EXCLUDE_FROM_ALL) diff --git a/projects/composablekernel/tile_engine/ops/gemm/README.md b/projects/composablekernel/tile_engine/ops/gemm/README.md index 5e0bae70806d..3c497da7cef0 100644 --- a/projects/composablekernel/tile_engine/ops/gemm/README.md +++ b/projects/composablekernel/tile_engine/ops/gemm/README.md @@ -6,6 +6,7 @@ The CK Tile Engine GEMM module provides a comprehensive system for generating, b ## Table of Contents +0. [Dispatcher Bridge Workflow](#dispatcher-bridge-workflow) 1. [Build System Architecture](#build-system-architecture) 2. [Build Instructions](#build-instructions) 3. [Running Benchmarks](#running-benchmarks) @@ -16,6 +17,145 @@ The CK Tile Engine GEMM module provides a comprehensive system for generating, b 8. [Troubleshooting](#troubleshooting) 9. [Performance Tips](#performance-tips) +## Dispatcher Bridge Workflow + +The **Dispatcher bridge** is the recommended path for sweeping and benchmarking +GEMM kernels. Instead of building monolithic or per-kernel executables through +CMake, Tile Engine expands a sweep config into shared `GemmKernelConfig` objects +and hands them to the Dispatcher, which codegens and compiles each into its own +`.so`. The kernel name produced by the bridge is byte-for-byte identical to the +codegen `KERNEL_NAME`, so the bridge runs exactly the same kernels the native +Tile Engine does — it only swaps the harness. + +### Scripts + +| Script | Role | +|---|---| +| `gemm_full_benchmark.py` | Driver: compile (Phase 1) → load problems (Phase 2) → benchmark across all visible GPUs (Phase 3). | +| `run_one_gemm_kernel.py` | Disposable worker: loads one `.so` in an isolated subprocess and times it. A GPU fault kills only the worker. | + +### Folder layout + +The bridged regular-GEMM path follows the same op-root convention as the merged +`fmha/` and `grouped_conv/` bridges — driver + worker + a flat `configs/` at the +op root: + +``` +gemm/ +├── gemm_full_benchmark.py # bridge driver (op root) +├── run_one_gemm_kernel.py # disposable per-kernel worker (op root) +├── configs/ # bridged gemm_universal sweep configs (flat) +├── gemm_instance_builder.py # shared generator for the non-bridged variants +├── gemm_benchmark.{py,hpp}, gemm_common.hpp, gemm_profiler.hpp # shared harness +├── gemm_multi_d/ gemm_preshuffle/ grouped_gemm/ # legacy variants +└── README.md +``` + +`configs/` ships example sweep configs: + +- `default_ci_config.json` — small CI-sized sweep (the driver's default when no + config is passed). +- `default_config.json` — full sweep. +- `user_provided_config.json` — scratch space for custom sweeps. +- `example_problems.json` — example M/N/K problem set (used when `--problems` + is omitted). + +> The JSON used by **nightly** tests is intended to drop into the same +> `configs/` directory and be selected with a positional config — no driver +> changes needed. + +The not-yet-bridged variants (`gemm_multi_d/`, `gemm_preshuffle/`, +`grouped_gemm/`) keep their own per-variant `configs/` directories; the driver +selects them with `--variant`. + +### Running + +```bash +cd tile_engine/ops/gemm + +# Default: gemm_universal variant, its CI sweep + example problems, +# auto-detect and use all visible GPUs. +python gemm_full_benchmark.py + +# Full sweep, fp16/rcr, restricted to 4 GPUs, custom output: +python gemm_full_benchmark.py --variant gemm_universal \ + configs/default_config.json \ + --dtype fp16 --layout rcr --devices 4 --csv gemm_results.csv + +# Specific GPU ids and a custom problem file: +python gemm_full_benchmark.py --devices 0,2,5 \ + --problems configs/example_problems.json + +# Correctness mode: check every kernel against an fp32 numpy reference. +python gemm_full_benchmark.py --verify --max-kernels 8 +``` + +### Liveness vs correctness (`--verify`) + +By default a measurement is reported `OK` purely on **liveness** — the kernel +ran and produced a non-zero output (`ZERO` otherwise). It is *not* a correctness +check: a numerically wrong but non-zero result still reads `OK`. Pass `--verify` +to have each worker compare its output against an fp32 numpy reference +(`A @ B`) using the global relative metric `max|out - ref| / max|ref|`. With +`--verify`, results read `VERIFY` (within `--verify-tol`, default `2e-2`) or +`MISMATCH` (counted as a failure), and the `max_rel` / `verified` columns are +populated in the CSV. This gives self-contained per-kernel confidence; the +broader numeric parity against native Tile Engine remains a separate task. + +### Multi-GPU parallelism + +Phase 3 fans the `(kernel × problem)` work out across **every visible GPU** in +parallel. One worker thread per device pulls batches from a shared queue and +spawns a disposable subprocess pinned with `HIP_VISIBLE_DEVICES`, so an N-GPU box +benchmarks roughly N× faster while keeping per-batch fault isolation. Devices are +auto-detected (`HIP_VISIBLE_DEVICES`, then `rocm-smi`/`amd-smi`); override with +`--devices`. This supersedes the serial-GPU design inherited from grouped_conv. + +### Supported surface + +| Axis | Supported | +|---|---| +| dtype | `fp16`, `bf16` | +| layout | `rcr`, `rrr`, `crr`, `ccr` (row-major C only — ck_tile rejects column-major C at build) | + +### Variant scope + +The bridge is **one shared, variant-aware driver** (`gemm_full_benchmark.py` + +`run_one_gemm_kernel.py`), not a per-variant copy of the driver. The bridged +regular-GEMM path (`gemm_universal`) uses the op-root `configs/`; `--variant` +selects a not-yet-bridged variant's own `configs/` subdirectory. + +What that means for this PR: + +- **Only `gemm_universal` is wired and validated through the bridge here.** It is + the foundation variant; the dispatcher codegen path is exercised and parity- + checked for it alone. +- The `gemm_multi_d/`, `gemm_preshuffle/`, and `grouped_gemm/` `configs/` + directories are **scaffolding** that follows the per-variant convention so the + layout is ready. `--variant` will select them, but the bridge does **not** yet + produce correct kernels for those variants on this PR — do not treat their + presence as working support. +- Grouped GEMM and stream-K go through **separate bridge efforts** (stream-K in + #8136, grouped GEMM on its own branch), not this PR. + +### Removal note + +The legacy regular-GEMM standalone build path has been **removed**, and the +`gemm_universal/` folder is gone entirely. The per-config benchmark generator and +driver (`gemm_universal_instance_builder.py`, `gemm_universal_benchmark.py`, +`gemm_universal_benchmark*.{cpp,hpp}`, and `gemm_universal/CMakeLists.txt`) no +longer exist; its sweep configs were promoted to the op-root `configs/` directory +(matching the `fmha/` and `grouped_conv/` bridge convention) and are consumed by +the bridge. Regular GEMM now runs exclusively through the Dispatcher bridge +workflow above (`gemm_full_benchmark.py` / `run_one_gemm_kernel.py`). The other +variants (`gemm_multi_d/`, `gemm_preshuffle/`, `grouped_gemm/`) still use the +shared `gemm_instance_builder.py` generator. + +The build-system, build-instruction, and benchmark-execution sections below +describe that removed standalone path and are retained only as historical +reference for the non-bridged variants; the `benchmark_gemm_universal_*` targets +they mention are no longer produced. + ## Build System Architecture ### Individual Kernel Compilation (New Approach) @@ -171,8 +311,13 @@ The system uses JSON configuration files to specify kernel parameters: ### Python Scripts -#### gemm_universal_instance_builder.py -**Purpose**: Main kernel instance generation script that creates C++ kernel implementations based on configuration files. +#### gemm_instance_builder.py +**Purpose**: Shared kernel instance generator used by the non-bridged variants +(`gemm_multi_d`, `gemm_preshuffle`, `grouped_gemm`). Creates C++ kernel +implementations based on configuration files. + +> The regular-GEMM subclass `gemm_universal/gemm_universal_instance_builder.py` +> has been removed; regular GEMM now goes through the Dispatcher bridge. **Key Features**: - Generates individual kernel header files for separate compilation @@ -180,16 +325,6 @@ The system uses JSON configuration files to specify kernel parameters: - Validates tile configurations for correctness - Creates CMake integration files -**Usage**: -```bash -python gemm_universal_instance_builder.py \ - --working_path ./generated \ - --datatype fp16 \ - --layout rcr \ - --config_json configs/user_provided_config.json \ - --gen_all_individual -``` - #### gemm_instance_builder_parallel.py **Purpose**: Parallel version of the instance builder for faster generation of multiple kernel configurations. @@ -225,14 +360,6 @@ python test_validation.py - Trait combination validation - Full tile configuration validation -#### gemm_universal_benchmark.py -**Purpose**: Python script for running and analyzing GEMM benchmarks. - -**Features**: -- Automated benchmark execution -- Performance data collection -- Result analysis and reporting - #### json_config.py **Purpose**: Configuration file parsing and management. diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/configs/default_ci_config.json b/projects/composablekernel/tile_engine/ops/gemm/configs/default_ci_config.json similarity index 98% rename from projects/composablekernel/tile_engine/ops/gemm/gemm_universal/configs/default_ci_config.json rename to projects/composablekernel/tile_engine/ops/gemm/configs/default_ci_config.json index 38376a410b01..a2b83334245e 100644 --- a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/configs/default_ci_config.json +++ b/projects/composablekernel/tile_engine/ops/gemm/configs/default_ci_config.json @@ -32,17 +32,17 @@ }, "warp_tile_m": { "values": [ - 16 + 32 ] }, "warp_tile_n": { "values": [ - 16 + 32 ] }, "warp_tile_k": { "values": [ - 32 + 16 ] } }, diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/configs/default_config.json b/projects/composablekernel/tile_engine/ops/gemm/configs/default_config.json similarity index 100% rename from projects/composablekernel/tile_engine/ops/gemm/gemm_universal/configs/default_config.json rename to projects/composablekernel/tile_engine/ops/gemm/configs/default_config.json diff --git a/projects/composablekernel/tile_engine/ops/gemm/configs/example_problems.json b/projects/composablekernel/tile_engine/ops/gemm/configs/example_problems.json new file mode 100644 index 000000000000..4be0c5a82379 --- /dev/null +++ b/projects/composablekernel/tile_engine/ops/gemm/configs/example_problems.json @@ -0,0 +1,9 @@ +{ + "problems": [ + {"M": 512, "N": 512, "K": 512}, + {"M": 1024, "N": 1024, "K": 1024}, + {"M": 2048, "N": 2048, "K": 2048}, + {"M": 1024, "N": 512, "K": 256}, + {"M": 4096, "N": 4096, "K": 4096} + ] +} diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/configs/user_provided_config.json b/projects/composablekernel/tile_engine/ops/gemm/configs/user_provided_config.json similarity index 100% rename from projects/composablekernel/tile_engine/ops/gemm/gemm_universal/configs/user_provided_config.json rename to projects/composablekernel/tile_engine/ops/gemm/configs/user_provided_config.json diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_full_benchmark.py b/projects/composablekernel/tile_engine/ops/gemm/gemm_full_benchmark.py new file mode 100644 index 000000000000..10f97328081f --- /dev/null +++ b/projects/composablekernel/tile_engine/ops/gemm/gemm_full_benchmark.py @@ -0,0 +1,504 @@ +#!/usr/bin/env python3 +"""Full GEMM benchmark sweep driven through the Dispatcher bridge. + +Phases: + Phase 1: Compile all kernels (parallel, returns .so paths only -- no GPU) + Phase 2: Load problems (M, N, K shapes) + Phase 3: Benchmark via subprocess isolation, distributed across all visible + GPUs (one device-pinned worker per GPU, batched, fault-isolated) + +Tile Engine generates NO binaries here: it expands its sweep config into shared +``GemmKernelConfig`` objects and hands them to the dispatcher, which codegens + +compiles each into a .so. Each kernel runs in a disposable worker subprocess so +a GPU fault (or ctypes' inability to unload a .so) takes down only one worker. + +Unlike the serial-GPU design inherited from grouped_conv, Phase 3 here fans the +work out across every visible GPU in parallel: each device runs its own stream of +disposable worker subprocesses pinned with ``HIP_VISIBLE_DEVICES``, so an N-GPU +box benchmarks roughly N times faster while keeping per-batch fault isolation. + +Examples: + # Default: gemm_universal variant, its CI sweep config + example problems, + # auto-detect and use all visible GPUs. + python gemm_full_benchmark.py + + # Explicit variant + full sweep config on 4 GPUs: + python gemm_full_benchmark.py --variant gemm_universal \ + configs/default_config.json --devices 4 --csv out.csv + +When no config is given the driver uses the chosen variant's +``configs/default_ci_config.json`` (a small CI-sized sweep); +``configs/default_config.json`` is the full sweep, and the JSON used by nightly +tests is intended to drop into the same ``configs/`` directory. +""" + +import argparse +import csv +import json +import os +import queue +import re +import subprocess +import sys +import threading +import time +from pathlib import Path + +_THIS_DIR = Path(__file__).resolve().parent +_DISPATCHER_ROOT = _THIS_DIR.parents[2] / "dispatcher" +sys.path.insert(0, str(_DISPATCHER_ROOT / "python")) +sys.path.insert(0, str(_THIS_DIR)) + +from gemm_utils import setup_multiple_gemm_dispatchers, expand_sweep # noqa: E402 + +# Config layout. The bridged regular-GEMM path (gemm_universal) keeps its sweep +# configs in this op's flat ``configs/`` directory (matching the fmha/grouped_conv +# bridge convention): default_ci_config.json (small CI sweep), default_config.json +# (full sweep), user_provided_config.json, example_problems.json. The other, +# not-yet-bridged variants still live in their own per-variant ``configs/`` dirs; +# they are registered so ``--variant`` can select them once their bridge lands. +VARIANT_CONFIGS = { + "gemm_universal": "configs", + "gemm_multi_d": "gemm_multi_d/configs", + "gemm_preshuffle": "gemm_preshuffle/configs", + "grouped_gemm": "grouped_gemm/configs", +} +DEFAULT_VARIANT = "gemm_universal" +CI_CONFIG_NAME = "default_ci_config.json" +EXAMPLE_PROBLEMS_NAME = "example_problems.json" + +# Fallback problem set if a variant ships no example_problems.json. +DEFAULT_PROBLEMS = [ + {"M": 1024, "N": 1024, "K": 1024}, + {"M": 2048, "N": 2048, "K": 2048}, + {"M": 4096, "N": 4096, "K": 4096}, + {"M": 257, "N": 257, "K": 257}, +] + +SUPPORTED_DTYPES = ("fp16", "bf16") +# Row-major C only: ck_tile's universal GEMM rejects column-major C at build. +SUPPORTED_LAYOUTS = ("rcr", "rrr", "crr", "ccr") + + +def detect_devices(): + """Return a list of visible GPU id strings (best-effort).""" + env = os.environ.get("HIP_VISIBLE_DEVICES") or os.environ.get( + "CUDA_VISIBLE_DEVICES" + ) + if env: + ids = [d.strip() for d in env.split(",") if d.strip() != ""] + if ids: + return ids + try: + out = subprocess.check_output( + ["rocm-smi", "--showid"], stderr=subprocess.DEVNULL, text=True + ) + ids = sorted(set(re.findall(r"GPU\[(\d+)\]", out)), key=int) + if ids: + return ids + except Exception: + pass + try: + out = subprocess.check_output( + ["amd-smi", "list"], stderr=subprocess.DEVNULL, text=True + ) + ids = re.findall(r"^GPU:\s*(\d+)", out, re.MULTILINE) + if ids: + return ids + except Exception: + pass + return ["0"] + + +def resolve_devices(spec): + """Resolve --devices into a concrete list of device id strings. + + spec is None (auto: all visible), an int count, or a comma-list of ids. + A bare digit is a *count*, not an id; to target one specific id use the + comma form, e.g. "5,". + """ + detected = detect_devices() + if spec is None: + return detected + spec = str(spec).strip() + if "," in spec: + return [s.strip() for s in spec.split(",") if s.strip() != ""] + if spec.isdigit(): + n = int(spec) + if n <= 0: + return detected + # Treat a bare integer as a device *count*: take the first n detected + # ids, falling back to a plain 0..n-1 range if detection under-reports. + # To target one specific device id, use the comma form (e.g. "5,"). + return detected[:n] if len(detected) >= n else [str(i) for i in range(n)] + return [spec] + + +def resolve_configs(args): + """Resolve positional configs -> concrete list of config paths.""" + if args.configs: + return args.configs + cfg = _THIS_DIR / VARIANT_CONFIGS[args.variant] / CI_CONFIG_NAME + return [str(cfg)] + + +def load_problems(path, variant): + if path: + with open(path) as f: + data = json.load(f) + return data["problems"] if isinstance(data, dict) else data + example = _THIS_DIR / VARIANT_CONFIGS[variant] / EXAMPLE_PROBLEMS_NAME + if example.exists(): + with open(example) as f: + data = json.load(f) + return data["problems"] if isinstance(data, dict) else data + return DEFAULT_PROBLEMS + + +def _run_batch_on_device(device_id, unit, args, worker_path, base_env): + """Run one (problem, kernel-batch) unit in a device-pinned subprocess. + + Returns (rows, lines, n_fail) where rows are dicts ready for the CSV writer, + lines are formatted strings to print, and n_fail counts failures. + """ + prob_idx, prob_dict, batch = unit + M, N, K = prob_dict["M"], prob_dict["N"], prob_dict["K"] + + items = [ + {"so_path": str(lib), "problem": prob_dict, "kernel_name": cfg.name} + for _, cfg, lib in batch + ] + payload = json.dumps( + {"items": items, "verify": args.verify, "verify_tol": args.verify_tol} + ) + + env = base_env.copy() + env["HIP_VISIBLE_DEVICES"] = str(device_id) + + rows, lines, n_fail = [], [], 0 + proc = None + try: + proc = subprocess.Popen( + [sys.executable, str(worker_path)], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + env=env, + ) + stdout_bytes, _ = proc.communicate( + input=payload.encode("utf-8"), + timeout=args.kernel_timeout * len(batch), + ) + + reported = set() + for line in stdout_bytes.decode("utf-8").strip().split("\n"): + if not line: + continue + try: + result = json.loads(line) + except json.JSONDecodeError: + lines.append(f" [gpu{device_id}] Warning: bad result line: {line[:50]}") + n_fail += 1 + continue + bidx = result.get("idx", 0) + _, cfg, _ = batch[bidx] + reported.add(bidx) + if result.get("ok", False): + status = "OK" if result.get("non_zero", 0) > 0 else "ZERO" + mismatch = False + if args.verify and "verified" in result: + if result["verified"]: + status = "VERIFY" + else: + status = "MISMATCH" + mismatch = True + extra = ( + f" rel={result['max_rel']:.2e}" if "max_rel" in result else "" + ) + lines.append( + f" [gpu{device_id}] {cfg.name:<58} {result['ms']:>10.3f} " + f"{result['tflops']:>10.2f} {status:>8}{extra}" + ) + rows.append( + { + "kernel": cfg.name, + "problem_idx": prob_idx, + "M": M, + "N": N, + "K": K, + "device": device_id, + "latency_ms": result["ms"], + "tflops": result["tflops"], + "non_zero": result.get("non_zero", 0), + "max_rel": result.get("max_rel", ""), + "verified": result.get("verified", ""), + } + ) + if mismatch: + n_fail += 1 + else: + lines.append(f" [gpu{device_id}] {cfg.name:<58} FAILED") + lines.append(f" Error: {result.get('error', 'unknown')[:100]}") + n_fail += 1 + + missing = set(range(len(batch))) - reported + if missing or proc.returncode != 0: + if proc.returncode != 0: + lines.append(f" [gpu{device_id}] worker exited code {proc.returncode}") + for idx in sorted(missing): + _, cfg, _ = batch[idx] + lines.append(f" [gpu{device_id}] {cfg.name:<58} MISSING (crash)") + n_fail += len(missing) + + except subprocess.TimeoutExpired: + lines.append(f" [gpu{device_id}] batch timeout ({len(batch)} kernels)") + try: + proc.kill() + proc.communicate(timeout=5) + except Exception: + pass + n_fail += len(batch) + except Exception as e: + lines.append(f" [gpu{device_id}] batch error: {e}") + try: + if proc and proc.poll() is None: + proc.kill() + except Exception: + pass + n_fail += len(batch) + + return rows, lines, n_fail + + +def main(): + parser = argparse.ArgumentParser(description="GEMM Benchmark Sweep (via Dispatcher)") + parser.add_argument( + "configs", + nargs="*", + help="TE sweep config JSON files (default: variant's default_ci_config.json)", + ) + parser.add_argument( + "--variant", + default=DEFAULT_VARIANT, + choices=tuple(VARIANT_CONFIGS), + help="GEMM variant (selects the configs/ directory)", + ) + parser.add_argument("--arch", default="gfx942") + parser.add_argument( + "--dtype", + default="fp16", + choices=SUPPORTED_DTYPES, + help=f"Input dtype (supported: {', '.join(SUPPORTED_DTYPES)})", + ) + parser.add_argument( + "--layout", + default="rcr", + choices=SUPPORTED_LAYOUTS, + help=f"A/B/C layout (supported: {', '.join(SUPPORTED_LAYOUTS)})", + ) + parser.add_argument("--problems", default=None, help="JSON file of M,N,K problems") + parser.add_argument("--csv", type=str, default="gemm_results.csv") + parser.add_argument("--workers", type=int, default=8, help="Parallel build workers") + parser.add_argument( + "--devices", + default=None, + help="GPUs to use: int count (e.g. 4) or comma-list of ids (e.g. 0,2,5); " + "for one specific id use the comma form (e.g. 5,) since a bare digit is " + "a count; default auto-detects all visible", + ) + parser.add_argument( + "--batch-size", + type=int, + default=20, + help="Kernels per subprocess (overhead vs fault isolation)", + ) + parser.add_argument( + "--kernel-timeout", type=int, default=30, help="Per-kernel timeout (s)" + ) + parser.add_argument( + "--max-kernels", type=int, default=0, help="Limit to first N kernels (0=all)" + ) + parser.add_argument( + "--verify", + action="store_true", + help="Check each kernel's output against an fp32 numpy reference " + "(global max|out-ref|/max|ref|); a mismatch counts as a failure", + ) + parser.add_argument( + "--verify-tol", + type=float, + default=2e-2, + help="Relative tolerance for --verify (default 2e-2, suits fp16)", + ) + args = parser.parse_args() + + config_paths = resolve_configs(args) + devices = resolve_devices(args.devices) + + # ======================================================================== + # Phase 1: Compile kernels (parallel, no GPU) + # ======================================================================== + print(f"\n{'=' * 80}") + print("Phase 1: Compile kernels") + print(f"{'=' * 80}") + print(f" Variant: {args.variant}") + print(f" Configs: {', '.join(config_paths)}") + + all_configs = [] + for cfg_path in config_paths: + all_configs.extend( + expand_sweep(cfg_path, args.arch, dtype=args.dtype, layout=args.layout) + ) + + if args.max_kernels > 0: + all_configs = all_configs[: args.max_kernels] + + print(f" Expanded configs: {len(all_configs)}") + print(f" Build workers: {args.workers}") + + t0 = time.perf_counter() + # CRITICAL: returns Path objects only, does NOT load any .so. + lib_paths = setup_multiple_gemm_dispatchers( + all_configs, verbose=True, max_workers=args.workers + ) + build_time = time.perf_counter() - t0 + + built_kernels = [ + (cfg, lib) for cfg, lib in zip(all_configs, lib_paths) if lib is not None + ] + + # Dedupe by .so path (distinct configs can map to the same physical kernel). + seen_libs = set() + unique_kernels = [] + duplicate_count = 0 + for cfg, lib in built_kernels: + lib_key = str(lib.resolve()) + if lib_key not in seen_libs: + seen_libs.add(lib_key) + unique_kernels.append((cfg, lib)) + else: + duplicate_count += 1 + built_kernels = unique_kernels + + print( + f"\n Built {len(all_configs)} configs -> {len(built_kernels)} unique kernels " + f"({duplicate_count} duplicates filtered) in {build_time:.0f}s" + ) + + if not built_kernels: + print(" ERROR: No kernels built successfully") + return 1 + + # ======================================================================== + # Phase 2: Load problems + # ======================================================================== + print(f"\n{'=' * 80}") + print("Phase 2: Load test problems") + print(f"{'=' * 80}") + + problems = load_problems(args.problems, args.variant) + print(f" Problems: {len(problems)}") + print( + f" Total measurements: {len(built_kernels)} x {len(problems)} = " + f"{len(built_kernels) * len(problems)}" + ) + + # ======================================================================== + # Phase 3: Benchmark across all visible GPUs (subprocess isolation, batched) + # ======================================================================== + print(f"\n{'=' * 80}") + print("Phase 3: Benchmark (multi-GPU, subprocess isolation, batched)") + print(f"{'=' * 80}") + print(f" Devices: {len(devices)} -> {', '.join(devices)}") + print(f" Batch size: {args.batch_size} kernels per subprocess") + print(f" Timeout: {args.kernel_timeout}s per kernel\n") + + csv_path = Path(args.csv) + csv_fields = [ + "kernel", + "problem_idx", + "M", + "N", + "K", + "device", + "latency_ms", + "tflops", + "non_zero", + "max_rel", + "verified", + ] + csv_file = open(csv_path, "w", newline="") + writer = csv.DictWriter(csv_file, fieldnames=csv_fields) + writer.writeheader() + + worker_path = _THIS_DIR / "run_one_gemm_kernel.py" + base_env = os.environ.copy() + base_env["GEMM_PYPATH"] = os.pathsep.join( + [str(_DISPATCHER_ROOT / "python"), str(_THIS_DIR)] + ) + + # Build a single work queue of (prob_idx, prob_dict, kernel-batch) units and + # fan them out across device-pinned worker threads. + work_q = queue.Queue() + for prob_idx, prob in enumerate(problems): + prob_dict = {"M": int(prob["M"]), "N": int(prob["N"]), "K": int(prob["K"])} + for start in range(0, len(built_kernels), args.batch_size): + end = min(start + args.batch_size, len(built_kernels)) + batch = [ + (start + j, cfg, lib) + for j, (cfg, lib) in enumerate(built_kernels[start:end]) + ] + work_q.put((prob_idx, prob_dict, batch)) + + io_lock = threading.Lock() + stats = {"measurements": 0, "failures": 0} + bench_t0 = time.perf_counter() + + def device_thread(device_id): + while True: + try: + unit = work_q.get_nowait() + except queue.Empty: + return + rows, lines, n_fail = _run_batch_on_device( + device_id, unit, args, worker_path, base_env + ) + with io_lock: + for ln in lines: + print(ln) + for row in rows: + writer.writerow(row) + csv_file.flush() + stats["measurements"] += len(rows) + stats["failures"] += n_fail + work_q.task_done() + + threads = [ + threading.Thread(target=device_thread, args=(d,), daemon=True) for d in devices + ] + for t in threads: + t.start() + for t in threads: + t.join() + + bench_time = time.perf_counter() - bench_t0 + csv_file.close() + + # ======================================================================== + # Summary + # ======================================================================== + print(f"\n{'=' * 80}") + print("BENCHMARK COMPLETE") + print(f"{'=' * 80}") + print(f" Build time: {build_time:.0f}s") + print(f" Benchmark time: {bench_time:.0f}s") + print(f" Total time: {build_time + bench_time:.0f}s") + print(f" Devices used: {len(devices)}") + print(f" Successful measurements: {stats['measurements']}") + print(f" Failed measurements: {stats['failures']}") + print(f" Output: {csv_path}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/CMakeLists.txt b/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/CMakeLists.txt deleted file mode 100644 index e0624b7067b2..000000000000 --- a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/CMakeLists.txt +++ /dev/null @@ -1,338 +0,0 @@ -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - -set(GEMM_UNIVERSAL_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM Universal (semicolon-separated)") -set(GEMM_UNIVERSAL_LAYOUT "rcr;rrr;crr;ccr" CACHE STRING "List of layout for GEMM Universal (semicolon-separated)") -set(GEMM_UNIVERSAL_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)") -set(GEMM_UNIVERSAL_MAX_INSTANCES "" CACHE STRING "Max kernel instances per (dtype, layout) combo (empty = no cap)") -option(ENABLE_CCACHE_GEMM_UNIVERSAL "Enable ccache for GEMM Universal ops compilation" OFF) - -# Store the directory path for use in functions -set(GEMM_UNIVERSAL_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}) - -# Function to create individual GEMM Universal targets -function(create_individual_gemm_universal_target datatype layout trait tile_config config_json) - # Use the parent scope GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL variable - if(NOT GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL) - message(WARNING "Skipping individual GEMM Universal target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets") - return() - endif() - - # Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k - # First split by underscore to get three groups - string(REPLACE "_" ";" config_groups ${tile_config}) - list(GET config_groups 0 tile_dims) # e.g., 256x256x32 - list(GET config_groups 1 warp_dims) # e.g., 4x1x1 - list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16 - - # Parse tile dimensions - string(REPLACE "x" ";" tile_parts ${tile_dims}) - list(GET tile_parts 0 tile_m) - list(GET tile_parts 1 tile_n) - list(GET tile_parts 2 tile_k) - - # Parse warp dimensions - string(REPLACE "x" ";" warp_parts ${warp_dims}) - list(GET warp_parts 0 warp_m) - list(GET warp_parts 1 warp_n) - list(GET warp_parts 2 warp_k) - - # Parse warp tile dimensions - string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims}) - list(GET warp_tile_parts 0 warp_tile_m) - list(GET warp_tile_parts 1 warp_tile_n) - list(GET warp_tile_parts 2 warp_tile_k) - - set(target_name "benchmark_gemm_universal_${datatype}_${layout}_${trait}_${tile_config}") - set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") - - # Generate the single instance header for this kernel - set(instance_header "${working_path}/gemm_universal_single_${datatype}_${layout}_${trait}_${tile_config}.hpp") - - # Add custom command to generate the header file at build time - add_custom_command( - OUTPUT ${instance_header} - COMMAND ${Python3_EXECUTABLE} ${GEMM_UNIVERSAL_SOURCE_DIR}/gemm_universal_instance_builder.py - --working_path ${working_path} - --datatype ${datatype} - --layout ${layout} - --config_json ${config_json} - --gen_single - --kernel_name "gemm_universal_${datatype}_${layout}_${trait}_${tile_config}" - --tile_config "${tile_config}" - --trait_combo "${trait}" - --gpu_target "${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}" - DEPENDS ${GEMM_UNIVERSAL_SOURCE_DIR}/gemm_universal_instance_builder.py ${config_json} - COMMENT "Generating ${instance_header}" - ) - - # Create the executable - add_executable(${target_name} - EXCLUDE_FROM_ALL - ${GEMM_UNIVERSAL_SOURCE_DIR}/gemm_universal_benchmark_single.cpp - ${instance_header} - ) - - # Set GPU architectures - set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}) - - # Set compile definitions - target_compile_definitions(${target_name} PRIVATE - GEMM_UNIVERSAL_SINGLE_INSTANCE_HPP="${instance_header}" - ) - - # Include directories - target_include_directories(${target_name} PRIVATE - ${GEMM_UNIVERSAL_SOURCE_DIR} - ${working_path} - ) - - # Compile options - target_compile_options(${target_name} PRIVATE - -Wno-undefined-func-template - -Wno-float-equal - --offload-compress - -include ${instance_header} - ) - - # Add to collection targets - add_dependencies(benchmark_gemm_universal_all ${target_name}) - add_dependencies(benchmark_gemm_universal_${datatype} ${target_name}) - add_dependencies(benchmark_gemm_universal_${layout} ${target_name}) - add_dependencies(benchmark_gemm_universal_${datatype}_${layout} ${target_name}) - - # Add to trait-specific targets - string(REPLACE "_" ";" trait_parts ${trait}) - list(GET trait_parts 0 pipeline) - list(GET trait_parts 1 epilogue) - list(GET trait_parts 2 scheduler) - - add_dependencies(benchmark_gemm_universal_${pipeline}_pipeline ${target_name}) - add_dependencies(benchmark_gemm_universal_${epilogue}_epilogue ${target_name}) - add_dependencies(benchmark_gemm_universal_${scheduler}_scheduler ${target_name}) -endfunction() - -# Function to build individual GEMM Universal targets -function(build_individual_gemm_universal_targets datatype layout) - set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") - - # Choose config file - # Priority order: - # 1. Environment variable GEMM_UNIVERSAL_CONFIG_FILE - # 2. CMake variable GEMM_UNIVERSAL_CONFIG_FILE - # 3. Default based on layout - - # Check environment variable first - if(DEFINED ENV{GEMM_UNIVERSAL_CONFIG_FILE} AND NOT "$ENV{GEMM_UNIVERSAL_CONFIG_FILE}" STREQUAL "") - set(config_filename "$ENV{GEMM_UNIVERSAL_CONFIG_FILE}") - set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}") - message(VERBOSE " Using config from environment variable: ${config_filename}") - elseif(NOT "${GEMM_UNIVERSAL_CONFIG_FILE}" STREQUAL "") - # Use CMake variable if set - set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_UNIVERSAL_CONFIG_FILE}") - message(VERBOSE " Using custom config: ${GEMM_UNIVERSAL_CONFIG_FILE}") - else() - # Use default config for all layouts - set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") - message(VERBOSE " Using default config for layout ${layout}") - endif() - - # Check if config file exists - if(NOT EXISTS ${json_blob}) - message(FATAL_ERROR "Config file not found: ${json_blob}") - endif() - - # Determine number of workers for parallel generation - if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL}) - set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL}) - else() - # Use processor count but limit to avoid memory issues - cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES) - math(EXPR num_workers "${num_cores}") - if(num_workers GREATER 8) - set(num_workers 8) - endif() - endif() - - # Generate individual kernel files using parallel version - message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...") - message(VERBOSE " Working path: ${working_path}") - message(VERBOSE " Config file: ${json_blob}") - message(VERBOSE " Python executable: ${Python3_EXECUTABLE}") - message(VERBOSE " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_universal_instance_builder.py") - - # Create working directory first - file(MAKE_DIRECTORY ${working_path}) - - message(VERBOSE "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_universal_instance_builder.py - --working_path ${working_path} - --datatype ${datatype} - --layout ${layout} - --config_json ${json_blob} - --gpu_target ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL} - --list_kernels ") - - # Build optional args for instance builder - set(extra_list_args "") - if(NOT "${GEMM_UNIVERSAL_MAX_INSTANCES}" STREQUAL "") - list(APPEND extra_list_args --max-instances ${GEMM_UNIVERSAL_MAX_INSTANCES}) - endif() - if(NOT "${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "") - list(APPEND extra_list_args --tier ${TILE_ENGINE_SAMPLING_TIER}) - list(APPEND extra_list_args --manifest-path ${working_path}) - endif() - if(NOT "${TILE_ENGINE_SAMPLING_SEED}" STREQUAL "") - list(APPEND extra_list_args --seed ${TILE_ENGINE_SAMPLING_SEED}) - endif() - - # First, just list the kernels (fast operation) - message(VERBOSE " Listing kernel configurations...") - execute_process( - COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_universal_instance_builder.py - --working_path ${working_path} - --datatype ${datatype} - --layout ${layout} - --config_json ${json_blob} - --gpu_target ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL} - --list_kernels - ${extra_list_args} - WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR} - RESULT_VARIABLE ret - OUTPUT_VARIABLE list_output - ERROR_VARIABLE list_error - ) - - if(NOT ret EQUAL 0) - message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}") - endif() - - # Read kernel count - if(EXISTS ${working_path}/gemm_universal_kernel_count.txt) - file(READ ${working_path}/gemm_universal_kernel_count.txt kernel_count) - string(STRIP "${kernel_count}" kernel_count) - message(VERBOSE " Found ${kernel_count} kernel configurations") - else() - message(FATAL_ERROR "Kernel count file not found") - endif() - - # Read kernel list and create targets - if(EXISTS ${working_path}/gemm_universal_kernel_list.txt) - file(STRINGS ${working_path}/gemm_universal_kernel_list.txt kernel_lines) - foreach(line IN LISTS kernel_lines) - # Parse line: kernel_name|tile_config|trait_combo - string(REPLACE "|" ";" parts "${line}") - list(GET parts 0 kernel_name) - list(GET parts 1 tile_config) - list(GET parts 2 trait_combo) - - # Create individual target - create_individual_gemm_universal_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}") - endforeach() - else() - message(FATAL_ERROR "Kernel list file not found") - endif() -endfunction() - -# Main build logic - Only individual builds supported -message(VERBOSE "=== Starting Tile Engine GEMM Universal Configuration ===") -message(VERBOSE "GEMM_UNIVERSAL_DATATYPE: ${GEMM_UNIVERSAL_DATATYPE}") -message(VERBOSE "GEMM_UNIVERSAL_LAYOUT: ${GEMM_UNIVERSAL_LAYOUT}") -message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") - -# Filter GPU targets to only gfx90a, gfx942, gfx950, gfx1201 -set(GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL "") -set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201;gfx12-generic") - -foreach(target IN LISTS SUPPORTED_GPU_TARGETS) - if(target IN_LIST DESIRED_TARGETS) - list(APPEND GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL ${target}) - message(VERBOSE " Adding GPU target: ${target}") - endif() -endforeach() - -# Skip build if no matching targets found -if(NOT GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL) - message(WARNING "Skipping Tile Engine GEMM Universal build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") -else() - message(VERBOSE "Building individual GEMM Universal targets for GPU targets: ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}") - - # Enable parallel compilation optimizations - # Set up job pools for better parallel compilation control - set_property(GLOBAL PROPERTY JOB_POOLS - compile_heavy=4 # Limit heavy compilations to prevent OOM - compile_normal=16 # Allow more parallel normal compilations - ) - - # Enable compiler cache if available and explicitly requested - # Disabled by default due to permission issues in CI environments - if(ENABLE_CCACHE_GEMM_UNIVERSAL) - find_program(CCACHE_PROGRAM ccache) - if(CCACHE_PROGRAM) - set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM}) - message(VERBOSE "Using ccache for faster compilation") - else() - message(WARNING "ccache requested but not found") - endif() - else() - message(VERBOSE "ccache disabled for GEMM Universal ops (use -DENABLE_CCACHE_GEMM_UNIVERSAL=ON to enable)") - endif() - - # Create master collection targets - add_custom_target(benchmark_gemm_universal_all) - - # Create datatype collection targets - foreach(dt IN LISTS GEMM_UNIVERSAL_DATATYPE) - add_custom_target(benchmark_gemm_universal_${dt}) - endforeach() - - # Create layout collection targets - foreach(l IN LISTS GEMM_UNIVERSAL_LAYOUT) - add_custom_target(benchmark_gemm_universal_${l}) - endforeach() - - # Create combined collection targets - foreach(dt IN LISTS GEMM_UNIVERSAL_DATATYPE) - foreach(l IN LISTS GEMM_UNIVERSAL_LAYOUT) - add_custom_target(benchmark_gemm_universal_${dt}_${l}) - endforeach() - endforeach() - - # Create trait-based collection targets - # These are common trait components used across all GEMM Universal kernels - set(GEMM_UNIVERSAL_PIPELINES "mem;compv3;compv4") - set(GEMM_UNIVERSAL_EPILOGUES "default;cshuffle") - set(GEMM_UNIVERSAL_SCHEDULERS "intrawave;interwave") - - foreach(pipeline IN LISTS GEMM_UNIVERSAL_PIPELINES) - add_custom_target(benchmark_gemm_universal_${pipeline}_pipeline) - endforeach() - - foreach(epilogue IN LISTS GEMM_UNIVERSAL_EPILOGUES) - add_custom_target(benchmark_gemm_universal_${epilogue}_epilogue) - endforeach() - - foreach(scheduler IN LISTS GEMM_UNIVERSAL_SCHEDULERS) - add_custom_target(benchmark_gemm_universal_${scheduler}_scheduler) - endforeach() - - # Divide MAX_INSTANCES budget across all active (dtype, layout) combos so that - # sampling fires per-combo rather than being a single cap larger than any combo's - # feasible set (which would make sampling a no-op for most combos). - if(NOT "${GEMM_UNIVERSAL_MAX_INSTANCES}" STREQUAL "") - list(LENGTH GEMM_UNIVERSAL_DATATYPE _gu_n_dt) - list(LENGTH GEMM_UNIVERSAL_LAYOUT _gu_n_lay) - math(EXPR _gu_n_combos "${_gu_n_dt} * ${_gu_n_lay}") - if(_gu_n_combos GREATER 0) - math(EXPR GEMM_UNIVERSAL_MAX_INSTANCES - "${GEMM_UNIVERSAL_MAX_INSTANCES} / ${_gu_n_combos}") - message(STATUS " gemm_universal: per-combo budget = ${GEMM_UNIVERSAL_MAX_INSTANCES} (${_gu_n_combos} combos)") - endif() - endif() - - # Build individual targets for each datatype/layout combination - foreach(dt IN LISTS GEMM_UNIVERSAL_DATATYPE) - foreach(l IN LISTS GEMM_UNIVERSAL_LAYOUT) - build_individual_gemm_universal_targets(${dt} ${l}) - endforeach() - endforeach() -endif() diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.hpp b/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.hpp deleted file mode 100644 index 23338a6cd008..000000000000 --- a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.hpp +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include -#include -#include -#include - -#include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" -#include "gemm/gemm_benchmark.hpp" - -#if __clang_major__ >= 23 -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" -#endif -// Data types and Layouts are defined by the generated kernel headers -// No hardcoded type definitions here to avoid conflicts - -/// @brief Function to get the kernel output with reference implementation on CPU/GPU -void gemm_host_reference(int verify, - ck_tile::HostTensor& a_m_k, - ck_tile::HostTensor& b_k_n, - ck_tile::HostTensor& c_m_n_host_result, - ck_tile::DeviceMem& a_m_k_dev_buf, - ck_tile::DeviceMem& b_k_n_dev_buf, - ck_tile::index_t M, - ck_tile::index_t N, - ck_tile::index_t K, - ck_tile::index_t stride_A, - ck_tile::index_t stride_B, - ck_tile::index_t stride_C) -{ - if(verify == 1) - { - c_m_n_host_result.SetZero(); - - ck_tile::reference_gemm( - a_m_k, b_k_n, c_m_n_host_result); - } - else if(verify == 2) - { - if constexpr(std::is_same_v) - { - // Restore input for B for gpu reference - b_k_n_dev_buf.ToDevice(b_k_n.data()); - } - - ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_host_result.get_element_space_size_in_bytes()); - c_m_n_host_result.SetZero(); - c_m_n_gpu_buf_ref.SetZero(); - - ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); - BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); - CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); - - ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); - - c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data()); - } -} -#if __clang_major__ >= 23 -#pragma clang diagnostic pop -#endif diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py b/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py deleted file mode 100755 index 73ba1261a849..000000000000 --- a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py +++ /dev/null @@ -1,149 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - -import os -import sys -import argparse -import time -import importlib.util - - -def _import_gemm_benchmark(): - """Import gemm benchmark from parent directory.""" - current_dir = os.path.dirname(os.path.abspath(__file__)) - parent_dir = os.path.dirname(current_dir) - - # Load the module dynamically - spec = importlib.util.spec_from_file_location( - "gemm_benchmark", - os.path.join(parent_dir, "gemm_benchmark.py"), - ) - gemm_benchmark_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(gemm_benchmark_module) - - return gemm_benchmark_module.GemmBenchmark - - -def _import_benchmark_utils(): - """Import benchmark utilities from commons directory.""" - current_dir = os.path.dirname(os.path.abspath(__file__)) - parent_dir = os.path.dirname(os.path.dirname(current_dir)) - - # Load the module dynamically - spec = importlib.util.spec_from_file_location( - "benchmark_utils", - os.path.join(parent_dir, "common", "benchmark_utils.py"), - ) - benchmark_utils = importlib.util.module_from_spec(spec) - spec.loader.exec_module(benchmark_utils) - - return benchmark_utils - - -GemmBenchmark = _import_gemm_benchmark() -benchmark_utils = _import_benchmark_utils() - - -class GemmUniversalBenchmark(GemmBenchmark): - def __init__(self, build_dir: str, verbose: bool = False): - super().__init__(build_dir, verbose, name="benchmark_gemm_universal_") - - -def main(): - parser = argparse.ArgumentParser( - description="Universal GEMM Kernel Benchmarking Tool" - ) - parser.add_argument( - "build_dir", help="Build directory containing kernel executables" - ) - parser.add_argument( - "--problem-sizes", - nargs="+", - default=["1024,1024,1024", "2048,2048,2048", "4096,4096,4096"], - help="Problem sizes as M,N,K tuples", - ) - parser.add_argument( - "--split-k", nargs="+", type=int, default=[1], help="Split-K values to test" - ) - parser.add_argument("--verify", action="store_true", help="Enable verification") - parser.add_argument( - "--csv", - default="gemm_universal_benchmark_results.csv", - help="CSV output filename", - ) - parser.add_argument( - "--best", default="best_kernels.txt", help="Best kernels output filename" - ) - parser.add_argument("--verbose", action="store_true", help="Verbose output") - parser.add_argument( - "--warmup", - type=int, - default=50, - help="Number of warmup iterations (default: 50)", - ) - parser.add_argument( - "--repeat", - type=int, - default=100, - help="Number of benchmark iterations (default: 100)", - ) - parser.add_argument( - "--flush-cache", - action="store_true", - default=True, - help="Enable cache flushing (default: True)", - ) - parser.add_argument( - "--rotating-count", - type=int, - default=1000, - help="Number of iterations to rotate cache (default: 1000)", - ) - parser.add_argument("--json", help="JSON output filename (optional)") - - args = parser.parse_args() - - # Parse problem sizes - problem_sizes = [] - for size_str in args.problem_sizes: - try: - m, n, k = map(int, size_str.split(",")) - problem_sizes.append((m, n, k)) - except ValueError: - print(f"Invalid problem size: {size_str}") - return 1 - - # Create benchmark instance - benchmark = GemmUniversalBenchmark(args.build_dir, verbose=args.verbose) - - # Run benchmark sweep - print("Starting Universal GEMM kernel benchmark sweep...") - start_time = time.time() - - best_kernels = benchmark.benchmark_sweep( - problem_sizes=problem_sizes, - split_k_values=args.split_k, - verify=args.verify, - warmup=args.warmup, - repeat=args.repeat, - flush_cache=args.flush_cache, - rotating_count=args.rotating_count, - ) - - elapsed_time = time.time() - start_time - print(f"\nBenchmark completed in {elapsed_time:.2f} seconds") - - # Export results - benchmark_utils.export_csv(benchmark.results, args.csv) - benchmark_utils.export_best_kernels(best_kernels, args.best) - - # Export JSON if requested - if args.json: - benchmark_utils.export_json(benchmark.results, args.json, best_kernels) - - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark_single.cpp b/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark_single.cpp deleted file mode 100644 index 9e73077e2895..000000000000 --- a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark_single.cpp +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include -#include -#include -#include -#include -#include -#include - -#include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" -#include "gemm/gemm_common.hpp" -#include "gemm_universal_profiler.hpp" - -// The kernel header is included via the compile command line with -include flag -// It defines SelectedKernel struct and KERNEL_NAME - -void benchmark_single(const ck_tile::ArgParser& arg_parser) -{ - // Use DataTypeTraits to get the actual type names from the generated header - // The generated header defines ADataType, BDataType, AccDataType, CDataType - std::string dtype_a = ck_tile::DataTypeTraits::name; - std::string dtype_b = ck_tile::DataTypeTraits::name; - std::string dtype_acc = ck_tile::DataTypeTraits::name; - std::string dtype_c = ck_tile::DataTypeTraits::name; - - // Layout names from the layout types - std::string layout_a = ALayout::name; - std::string layout_b = BLayout::name; - std::string layout_c = CLayout::name; - - // Create GemmProblem struct - GemmProblem gemm_problem{arg_parser.get_int("split_k"), - arg_parser.get_int("m"), - arg_parser.get_int("n"), - arg_parser.get_int("k"), - arg_parser.get_int("stride_a"), - arg_parser.get_int("stride_b"), - arg_parser.get_int("stride_c"), - dtype_a, - dtype_b, - dtype_acc, - dtype_c, - layout_a, - layout_b, - layout_c, - arg_parser.get_bool("structured_sparsity")}; - - // Create Settings struct - Settings setting{arg_parser.get_int("warmup"), - arg_parser.get_int("repeat"), - arg_parser.get_bool("timer"), - arg_parser.get_int("verify"), - arg_parser.get_int("init"), - arg_parser.get_bool("log"), - arg_parser.get_str("csv_filename"), - arg_parser.get_bool("flush_cache"), - arg_parser.get_int("rotating_count"), - arg_parser.get_bool("json_output")}; - - // Get the profiler instance - auto& profiler = UniversalGemmProfiler::GemmProfiler::instance(setting); - - try - { - // Create a lambda that wraps the kernel launch - auto kernel_func = [](const ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& stream) { - return SelectedKernel::launch(args, stream); - }; - - // Benchmark the kernel - profiler.benchmark(gemm_problem, kernel_func); - - // Select best instance based on metric - profiler.select_best_instance(static_cast(arg_parser.get_int("metric"))); - } - catch(const std::exception& e) - { - std::cerr << "Benchmark failed: " << e.what() << std::endl; - } -} - -int main(int argc, char* argv[]) -{ - try - { - auto [result, parser] = create_args(argc, argv); - if(!result) - return EXIT_FAILURE; - - benchmark_single(parser); - return 0; - } - catch(const std::exception& e) - { - std::cerr << "Error: " << e.what() << "\n"; - return EXIT_FAILURE; - } -} diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_instance_builder.py b/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_instance_builder.py deleted file mode 100644 index 0d13584ca065..000000000000 --- a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_instance_builder.py +++ /dev/null @@ -1,344 +0,0 @@ -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - -import os -import argparse -import importlib.util -import multiprocessing -import concurrent.futures - - -def _import_gemm_kernel_builder(): - """Import validation utilities from commons directory.""" - current_dir = os.path.dirname(os.path.abspath(__file__)) - parent_dir = os.path.dirname(current_dir) - - # Load the module dynamically - spec = importlib.util.spec_from_file_location( - "gemm_instance_builder", - os.path.join(parent_dir, "gemm_instance_builder.py"), - ) - gemm_builder_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(gemm_builder_module) - - return gemm_builder_module.GemmKernelBuilder - - -GemmKernelBuilder = _import_gemm_kernel_builder() - - -class GemmUniversalKernelBuilder(GemmKernelBuilder): - def __init__( - self, - kernel_name_prefix, - working_path, - gpu_target, - datatype, - layout, - config_json=None, - max_instances=None, - seed=None, - tier=None, - manifest_path=None, - ): - super().__init__( - kernel_name_prefix, - working_path, - gpu_target, - datatype, - layout, - config_json, - max_instances=max_instances, - seed=seed, - tier=tier, - manifest_path=manifest_path, - ) - - def _generate_all_individual(self, num_workers=None): - """Generate individual kernel files for separate compilation with parallel processing""" - if num_workers is None: - num_workers = min( - multiprocessing.cpu_count(), 8 - ) # Limit to avoid memory issues - - tile_configs = self._get_tile_configs() - trait_combos = self._generate_trait_combinations() - - # Prepare work items for parallel processing - work_items = [] - for tile_config in tile_configs: - for trait_combo in trait_combos: - work_items.append( - ( - tile_config, - trait_combo, - self.kernel_name_prefix, - self.working_path, - self.gpu_target, - self.datatype, - self.layout, - self.config_json, - ) - ) - - # Apply RFC-compliant sampling (Sobol + LHS + maximin) - if self.max_instances is not None and len(work_items) > self.max_instances: - kernel_dicts = [ - {"tile_config": item[0], "trait_combo": item[1], "_work_item": item} - for item in work_items - ] - sampled = self._apply_sampling(kernel_dicts) - work_items = [k["_work_item"] for k in sampled] - - print( - f"Generating {len(work_items)} individual kernel files using {num_workers} workers..." - ) - print(f" Tile configs: {len(tile_configs)}") - print(f" Trait combinations: {len(trait_combos)}") - print(f" Total kernels: {len(work_items)}") - - # Show first few work items for debugging - if work_items: - print(" First work item example:") - tile_config, trait_combo = work_items[0][:2] - print(f" Tile config: {tile_config}") - print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits - - # Process work items in parallel - kernel_list = [] - completed = 0 - - with concurrent.futures.ProcessPoolExecutor( - max_workers=num_workers - ) as executor: - # Submit all work items - print(f" Submitting {len(work_items)} tasks to executor...") - future_to_item = { - executor.submit(_generate_single_kernel_individual, item): item - for item in work_items - } - print(" All tasks submitted, waiting for completion...") - - # Collect results with progress reporting - for future in concurrent.futures.as_completed(future_to_item): - completed += 1 - if completed % 100 == 0 or completed == len(work_items): - print( - f" Progress: {completed}/{len(work_items)} kernels generated" - ) - try: - result = future.result() - if result: - kernel_list.append(result) - except Exception as exc: - item = future_to_item[future] - print(f"Kernel generation failed for {item}: {exc}") - - # Sort kernel list for consistent ordering - kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name - - # Generate CMake include file for individual targets - self._generate_cmake_individual_targets(kernel_list) - - print( - f"Generated {len(kernel_list)} individual kernel files in {self.working_path}" - ) - - -def _generate_single_kernel_individual(work_item): - """Worker function to generate a single individual kernel file""" - ( - tile_config, - trait_combo, - kernel_name_prefix, - working_path, - gpu_target, - datatype, - layout, - config_json, - ) = work_item - - # Create a temporary builder instance for this worker - builder = GemmUniversalKernelBuilder( - kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json - ) - - try: - kernel_name, instance_code = builder._generate_kernel_instance( - tile_config, trait_combo - ) - - # Create simplified filename without the "gemm_universal_" prefix - # Remove "gemm_universal_" from the beginning of kernel_name for the filename - simplified_name = kernel_name - if simplified_name.startswith("gemm_universal_"): - simplified_name = simplified_name[ - len(kernel_name_prefix) + 1 : - ] # Remove "gemm_universal" prefix - - # Write individual header file - header_file = working_path / f"gemm_universal_single_{simplified_name}.hpp" - with open(header_file, "w") as f: - f.write(instance_code) - - return (kernel_name, trait_combo, tile_config) - except Exception as e: - print(f"Error generating individual kernel: {e}") - return None - - -def main(): - parser = argparse.ArgumentParser( - description="GEMM Universal kernel instance builder with parallel support" - ) - parser.add_argument("--working_path", required=True, help="Working directory path") - parser.add_argument( - "--gpu_target", - required=True, - help="GPU target architecture", - ) - parser.add_argument( - "--datatype", - required=True, - choices=["fp16", "fp8", "bf16", "bf8"], - help="Data type", - ) - parser.add_argument( - "--layout", - required=True, - choices=["rcr", "rrr", "ccr", "crr"], - help="Matrix layout", - ) - parser.add_argument("--config_json", help="Configuration JSON file") - parser.add_argument( - "--num_workers", type=int, help="Number of parallel workers (default: auto)" - ) - parser.add_argument( - "--gen_all_individual", - action="store_true", - help="Generate individual kernel files", - ) - parser.add_argument( - "--gen_single", action="store_true", help="Generate a single kernel file" - ) - parser.add_argument("--kernel_name", help="Kernel name for single generation") - parser.add_argument( - "--tile_config", help="Tile configuration string for single generation" - ) - parser.add_argument( - "--trait_combo", help="Trait combination string for single generation" - ) - parser.add_argument( - "--list_kernels", - action="store_true", - help="List kernel configurations without generating files", - ) - parser.add_argument( - "--max-instances", - type=int, - default=None, - help="Cap on number of kernel instances per (dtype, layout) combo", - ) - parser.add_argument( - "--seed", - type=int, - default=None, - help="RNG seed for deterministic sampling; if omitted, derived from today's date", - ) - parser.add_argument( - "--tier", - default=None, - help="Sampling tier (daily/weekly)", - ) - parser.add_argument( - "--manifest-path", - default=None, - help="Directory for chosen_instances.json", - ) - - args = parser.parse_args() - - assert args.datatype in ["fp16", "bf16", "fp8", "bf8"], ( - f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])" - ) - - layout_parts = args.layout.lower() - assert len(layout_parts) == 3, ( - f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)" - ) - assert layout_parts[0] in ["r", "c"] and layout_parts[1] in ["r", "c"], ( - f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)" - ) - assert layout_parts[2] == "r", ( - f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)" - ) - - kernel_name_prefix = "gemm_universal" - builder = GemmUniversalKernelBuilder( - kernel_name_prefix, - args.working_path, - args.gpu_target, - args.datatype, - args.layout, - args.config_json, - max_instances=args.max_instances, - seed=args.seed, - tier=args.tier, - manifest_path=args.manifest_path, - ) - - if args.list_kernels: - builder._list_kernels() - elif args.gen_single: - # Generate a single kernel file input validation - if not args.kernel_name or not args.tile_config or not args.trait_combo: - parser.error( - "--gen_single requires --kernel_name, --tile_config, and --trait_combo" - ) - - # Parse tile config - tile_parts = args.tile_config.split("_") - tile_dims = tile_parts[0].split("x") - warp_dims = tile_parts[1].split("x") - warp_tile_dims = tile_parts[2].split("x") - - tile_config = { - "tile_m": int(tile_dims[0]), - "tile_n": int(tile_dims[1]), - "tile_k": int(tile_dims[2]), - "warp_m": int(warp_dims[0]), - "warp_n": int(warp_dims[1]), - "warp_k": int(warp_dims[2]), - "warp_tile_m": int(warp_tile_dims[0]), - "warp_tile_n": int(warp_tile_dims[1]), - "warp_tile_k": int(warp_tile_dims[2]), - } - - # Parse trait combo - trait_parts = args.trait_combo.split("_") - trait_combo = ( - trait_parts[0], # pipeline - trait_parts[1], # epilogue - trait_parts[2], # scheduler - trait_parts[3] == "True", # pad_m - trait_parts[4] == "True", # pad_n - trait_parts[5] == "True", # pad_k - trait_parts[6] == "True", # persistent - ) - - # Generate the kernel - builder._generate_kernel_instance( - tile_config, - trait_combo, - ) - elif args.gen_all_individual: - # Generate all individual kernel files - builder._generate_all_individual(args.num_workers) - else: - parser.error( - "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single" - ) - - -if __name__ == "__main__": - main() diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_profiler.hpp b/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_profiler.hpp deleted file mode 100644 index 6eb4266aae88..000000000000 --- a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_profiler.hpp +++ /dev/null @@ -1,147 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include -#include - -#include "ck_tile/host/device_prop.hpp" -#include "ck_tile/ops/gemm.hpp" -#include "gemm/gemm_benchmark.hpp" -#include "gemm/gemm_profiler.hpp" -#include "gemm_universal_benchmark.hpp" - -class UniversalGemmProfiler - : public GemmProfiler -{ - public: - using BaseGemm = GemmProfiler; - using BaseGemm::benchmark; - - UniversalGemmProfiler(Settings setting) - : GemmProfiler(setting) - { - } - - void benchmark(GemmProblem& gemm_problem, - std::vector( - ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables) override - { - const ALayout layout_a = ALayout{}; - const BLayout layout_b = BLayout{}; - const CLayout layout_c = CLayout{}; - - gemm_problem.stride_a_ = ck_tile::get_default_stride( - gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a)); - gemm_problem.stride_b_ = ck_tile::get_default_stride( - gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b)); - gemm_problem.stride_c_ = ck_tile::get_default_stride( - gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)); - - ck_tile::HostTensor a_m_k(ck_tile::host_tensor_descriptor( - gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a))); - ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( - gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b))); - ck_tile::HostTensor c_m_n_dev_result(ck_tile::host_tensor_descriptor( - gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c))); - - if(setting_.init_method == 0) - { - ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); - ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); - } - else if(setting_.init_method == 1) - { - ck_tile::FillMonotonicSeq{}(a_m_k); - ck_tile::FillMonotonicSeq{}(b_k_n); - } - else if(setting_.init_method == 2) - { - ck_tile::FillConstant{static_cast(1)}(a_m_k); - ck_tile::FillConstant{static_cast(1)}(b_k_n); - } - else - { - a_m_k.SetZero(); - b_k_n.SetZero(); - } - - if(gemm_problem.structured_sparsity_) - { - ck_tile::AdjustToStructuredSparsity{}(a_m_k); - } - - ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); - ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); - ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); - - if constexpr(std::is_same_v) - { - // Permute vector pk_i4x4 data for device implementation - ck_tile::HostTensor b_k_n_dev = b_k_n; - // permute_tensor_b(b_k_n_dev); - ck_tile::permute_vectors_i4x4_b(b_k_n_dev); - b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); - } - else - { - b_k_n_dev_buf.ToDevice(b_k_n.data()); - } - - a_m_k_dev_buf.ToDevice(a_m_k.data()); - c_m_n_dev_buf.SetZero(); - c_m_n_dev_result.SetZero(); - - ck_tile::GemmHostArgs gemm_args = { - a_m_k_dev_buf.GetDeviceBuffer(), - b_k_n_dev_buf.GetDeviceBuffer(), - c_m_n_dev_buf.GetDeviceBuffer(), - gemm_problem.split_k_, - gemm_problem.m_, - gemm_problem.n_, - gemm_problem.k_, - gemm_problem.stride_a_, - gemm_problem.stride_b_, - gemm_problem.stride_c_, - }; - - ck_tile::HostTensor c_m_n_host_result(ck_tile::host_tensor_descriptor( - gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c))); - - if(setting_.verify) - { - gemm_host_reference(setting_.verify, - a_m_k, - b_k_n, - c_m_n_host_result, - a_m_k_dev_buf, - b_k_n_dev_buf, - gemm_problem.m_, - gemm_problem.n_, - gemm_problem.k_, - gemm_problem.stride_a_, - gemm_problem.stride_b_, - gemm_problem.stride_c_); - } - - for(auto& callable : callables) - { - auto kernel_run_result = callable(gemm_args, - ck_tile::stream_config{nullptr, - true, - setting_.log, - setting_.n_warmup, - setting_.n_repeat, - setting_.is_gpu_timer, - setting_.flush_cache, - setting_.rotating_count}); - process_result(gemm_problem, - c_m_n_dev_buf, - c_m_n_host_result, - c_m_n_dev_result, - kernel_run_result); - } - } -}; diff --git a/projects/composablekernel/tile_engine/ops/gemm/run_one_gemm_kernel.py b/projects/composablekernel/tile_engine/ops/gemm/run_one_gemm_kernel.py new file mode 100644 index 000000000000..a47d9a8e3108 --- /dev/null +++ b/projects/composablekernel/tile_engine/ops/gemm/run_one_gemm_kernel.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +"""Worker script for running GEMM kernels in an isolated subprocess. + +Mirrors grouped_conv's run_one_grouped_conv_kernel.py: +- Receives kernel config + problem via stdin as JSON +- Loads the .so library ONLY inside this subprocess +- Outputs timing results as JSON to stdout (one line per kernel, flushed) +- A GPU fault kills only this process; the parent driver can continue + +Input JSON format: + Single: {"so_path": "...", "problem": {"M":.., "N":.., "K":..}, "kernel_name": "..."} + Batch: {"items": [{"so_path": "...", "problem": {...}, "kernel_name": "..."}, ...]} + +Optional top-level keys ``verify`` (bool) and ``verify_tol`` (float) enable an +fp32 numpy reference check; when set, each OK result also carries ``verified`` +and ``max_rel``. + +Output JSON format (one line per kernel): + {"idx": 0, "ok": true, "ms": 0.123, "tflops": 456.7, "non_zero": 1, "kernel": "..."} + {"idx": 0, "ok": true, ..., "verified": true, "max_rel": 3.1e-4} # with --verify + {"idx": 1, "ok": false, "error": "...", "kernel": "..."} +""" + +import json +import os +import sys + +# Add dispatcher python paths from environment (os.pathsep-separated). +gemm_pypath = os.environ.get("GEMM_PYPATH", "") +if gemm_pypath: + for p in gemm_pypath.split(os.pathsep): + if p and p not in sys.path: + sys.path.insert(0, p) + +from gemm_utils import GemmProblem, GpuGemmRunner # noqa: E402 +import numpy as np # noqa: E402 + + +def _run_one(idx, so_path, prob_dict, kernel_name, verify=False, verify_tol=2e-2): + """Run a single kernel and emit its result as one JSON line. + + When ``verify`` is set, the kernel output is checked against an fp32 numpy + reference (``A @ B``) using the global relative metric + ``max|out - ref| / max|ref|``; the emitted ``verified`` field then reflects + correctness, not just liveness (``non_zero``). + """ + try: + problem = GemmProblem.from_dict(prob_dict) + + np.random.seed(42) + # Generate fp32 source; the runner encodes to the kernel's real dtype + # (fp16 or bf16) based on the compiled kernel name. + A = (np.random.randn(problem.M, problem.K) * 0.1).astype(np.float32) + B = (np.random.randn(problem.K, problem.N) * 0.1).astype(np.float32) + + # CRITICAL: load the library ONLY inside this subprocess. + runner = GpuGemmRunner(lib_path=so_path) + result = runner.run(A, B, problem) + + if result.success: + non_zero = ( + int(np.count_nonzero(result.output)) + if result.output is not None + else 0 + ) + out = { + "idx": idx, + "ok": True, + "ms": result.time_ms, + "tflops": result.tflops, + "non_zero": non_zero, + "kernel": kernel_name, + } + if verify: + ref = A.astype(np.float32) @ B.astype(np.float32) + got = result.output.astype(np.float32) + denom = float(np.max(np.abs(ref))) or 1.0 + max_rel = float(np.max(np.abs(got - ref)) / denom) + out["max_rel"] = max_rel + out["verified"] = bool(max_rel <= verify_tol) + print(json.dumps(out), flush=True) + else: + print( + json.dumps( + { + "idx": idx, + "ok": False, + "error": f"kernel returned status {result.status}", + "kernel": kernel_name, + } + ), + flush=True, + ) + + except Exception as e: + print( + json.dumps( + {"idx": idx, "ok": False, "error": str(e), "kernel": kernel_name} + ), + flush=True, + ) + + +def main(): + """Read JSON from stdin, run kernel(s), output results.""" + try: + d = json.loads(sys.stdin.buffer.read()) + except Exception as e: + print( + json.dumps({"idx": 0, "ok": False, "error": f"JSON parse error: {e}"}), + flush=True, + ) + sys.exit(1) + + verify = bool(d.get("verify", False)) + verify_tol = float(d.get("verify_tol", 2e-2)) + + if "items" in d: + for i, item in enumerate(d["items"]): + _run_one( + i, + item["so_path"], + item["problem"], + item.get("kernel_name", "unknown"), + verify=verify, + verify_tol=verify_tol, + ) + else: + _run_one( + 0, + d["so_path"], + d["problem"], + d.get("kernel_name", "unknown"), + verify=verify, + verify_tol=verify_tol, + ) + + +if __name__ == "__main__": + main() From 22f0f3ae8195caa70b823aed532fc20f7f9c96a0 Mon Sep 17 00:00:00 2001 From: ozturkosu Date: Wed, 17 Jun 2026 16:30:09 -0400 Subject: [PATCH 02/16] [CK_TILE] GEMM bridge: layout-aware supports() to match Old-TE parity The bridge dispatcher's tile-divisibility gate rejected any problem where M % TileM != 0 for every layout, returning status -2 ("No suitable kernel") at runtime even though the .so built fine. This wrongly excluded bf16 rcr/rrr kernels with a non-power-of-two TileM (e.g. 192) on standard shapes like 1024^3 -- cases Old-TE compiles, runs, and verifies as correct. Root cause: supports() was layout-blind, while the underlying ck_tile::GemmKernel::IsSupportedArgument only constrains a dimension when an operand whose inner axis is that dimension participates without padding: RowMajor A -> K, ColMajor A -> M RowMajor B -> N, ColMajor B -> K RowMajor C -> N, ColMajor C -> M So for rcr (RowMajor A & C) M is never gated, which is why Old-TE runs M=192 tiles on M-indivisible problems. Make supports() compute require_m/n/k from the kernel key's A/B/C layouts so it mirrors IsSupportedArgument exactly (also honoring k_batch in the K grain). Anything it now lets through is still validated by the kernel's own IsSupportedArgument inside launch(), so the bridge stays a strict functional equivalent of Old-TE. Applied to both generated_tile_backend.hpp (the GEMM .so path) and the sibling tile_backend.hpp. Validated on gfx942 (MI300X): 85 previously status-2 rcr/rrr bf16 192-tile .so now run at 1024^3 (Old-TE runs the same, verification correct); the 8 remaining rejects are tile N=192 cases that Old-TE also reports "Arguments not supported" at N=1024 -- parity preserved in both directions. --- .../backends/generated_tile_backend.hpp | 40 ++++++++++++++----- .../dispatcher/backends/tile_backend.hpp | 32 ++++++++++----- 2 files changed, 51 insertions(+), 21 deletions(-) diff --git a/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp b/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp index 0b072d761742..c724977f9613 100644 --- a/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp +++ b/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp @@ -52,26 +52,46 @@ class GeneratedTileKernelInstance : public KernelInstance bool supports(const Problem& problem) const override { - // Check dimension divisibility if padding not enabled + // Tile-divisibility gate, mirroring ck_tile::GemmKernel::IsSupportedArgument + // exactly. A dimension only needs to be a multiple of its tile size when an + // operand whose contiguous (inner) axis is that dimension participates AND + // padding for it is disabled. This is layout-dependent: + // + // layout RowMajor A -> inner axis K | layout ColMajor A -> inner axis M + // layout RowMajor B -> inner axis N | layout ColMajor B -> inner axis K + // layout RowMajor C -> inner axis N | layout ColMajor C -> inner axis M + // + // The old check blindly required M % TileM == 0 for every layout, which + // wrongly rejected e.g. rcr kernels (RowMajor A & C never gate M) on + // M-indivisible problems that Old-TE runs fine. Anything this lets through + // is still validated by the kernel's own IsSupportedArgument inside launch(), + // so the bridge stays a strict functional equivalent of Old-TE. constexpr bool pad_m = SelectedKernel::kPadM; constexpr bool pad_n = SelectedKernel::kPadN; constexpr bool pad_k = SelectedKernel::kPadK; - if(pad_m && pad_n && pad_k) - { - return true; // Padding enabled - supports any size - } - - // Check divisibility constexpr int tile_m = SelectedKernel::TileM; constexpr int tile_n = SelectedKernel::TileN; constexpr int tile_k = SelectedKernel::TileK; - if(!pad_m && problem.M % tile_m != 0) + const auto is_row = [](LayoutTag l) { return l == LayoutTag::RowMajor; }; + const bool row_a = is_row(key_.signature.layout_a); + const bool row_b = is_row(key_.signature.layout_b); + const bool row_c = is_row(key_.signature.layout_c); + + // Which problem dimensions are actually constrained for this layout combo. + const bool require_m = (!row_a) || (!row_c); // ColMajor A or C gate M + const bool require_n = row_b || row_c; // RowMajor B or C gate N + const bool require_k = row_a || (!row_b); // RowMajor A or ColMajor B gate K + + const std::int64_t k_grain = + static_cast(tile_k) * (problem.k_batch > 0 ? problem.k_batch : 1); + + if(require_m && !pad_m && problem.M % tile_m != 0) return false; - if(!pad_n && problem.N % tile_n != 0) + if(require_n && !pad_n && problem.N % tile_n != 0) return false; - if(!pad_k && problem.K % tile_k != 0) + if(require_k && !pad_k && problem.K % k_grain != 0) return false; return true; diff --git a/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/tile_backend.hpp b/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/tile_backend.hpp index a3a0b0468562..e709c00e153f 100644 --- a/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/tile_backend.hpp +++ b/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/tile_backend.hpp @@ -29,27 +29,37 @@ class TileKernelInstance : public KernelInstance bool supports(const Problem& problem) const override { - // Check dimension divisibility if padding not enabled + // Tile-divisibility gate, layout-aware to match + // ck_tile::GemmKernel::IsSupportedArgument (see generated_tile_backend.hpp + // for the full rationale). A dimension is only constrained when an operand + // whose inner axis is that dimension participates and its padding is off: + // RowMajor A->K, ColMajor A->M; RowMajor B->N, ColMajor B->K; + // RowMajor C->N, ColMajor C->M. constexpr bool pad_m = SelectedKernel::kPadM; constexpr bool pad_n = SelectedKernel::kPadN; constexpr bool pad_k = SelectedKernel::kPadK; - if(pad_m && pad_n && pad_k) - { - // Padding enabled - supports any size - return true; - } - - // Check divisibility constexpr int tile_m = SelectedKernel::TileM; constexpr int tile_n = SelectedKernel::TileN; constexpr int tile_k = SelectedKernel::TileK; - if(!pad_m && problem.M % tile_m != 0) + const auto is_row = [](LayoutTag l) { return l == LayoutTag::RowMajor; }; + const bool row_a = is_row(key_.signature.layout_a); + const bool row_b = is_row(key_.signature.layout_b); + const bool row_c = is_row(key_.signature.layout_c); + + const bool require_m = (!row_a) || (!row_c); + const bool require_n = row_b || row_c; + const bool require_k = row_a || (!row_b); + + const std::int64_t k_grain = + static_cast(tile_k) * (problem.k_batch > 0 ? problem.k_batch : 1); + + if(require_m && !pad_m && problem.M % tile_m != 0) return false; - if(!pad_n && problem.N % tile_n != 0) + if(require_n && !pad_n && problem.N % tile_n != 0) return false; - if(!pad_k && problem.K % tile_k != 0) + if(require_k && !pad_k && problem.K % k_grain != 0) return false; // Check shared memory budget if specified From 0bed6d07fba338c9c4b48bfebb9878966f8353d1 Mon Sep 17 00:00:00 2001 From: ozturkosu Date: Wed, 17 Jun 2026 18:22:32 -0400 Subject: [PATCH 03/16] [CK_TILE] GEMM bridge: derive key layout from kernel instead of hardcoding rcr dispatcher_initialize() in gemm_ctypes_lib.cpp hardcoded the KernelKey layout to rcr (RowMajor/ColMajor/RowMajor) for every kernel. Now that supports() is layout-aware, that wrong key layout makes the dispatcher reject valid problems: a crr kernel does not gate K (neither A=ColMajor nor B=RowMajor has K as its inner axis), but with a hardcoded rcr key supports() applies rcr's K-gate and returns status -2 for TileK=192 problems (e.g. crr 64x64x192 at 1024^3) that Old-TE compiles, runs, and verifies (~87 TFLOPS). Derive signature.layout_a/b/c from the force-included kernel's own ALayout/BLayout/CLayout types via std::is_same_v with tensor_layout::gemm::RowMajor. The key now matches the kernel, so the layout-aware gate is correct for all four layouts. Execution was already layout-correct (the kernel uses its own compile-time layouts); only the host-side selection metadata was wrong. Validated on gfx942 (MI300X): crr 64x64x192 now runs on the bridge (93 TFLOPS), restoring parity with Old-TE. --- .../bindings/ctypes/gemm_ctypes_lib.cpp | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/projects/composablekernel/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp b/projects/composablekernel/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp index 57b98a8df135..c415ae6c0cef 100644 --- a/projects/composablekernel/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp +++ b/projects/composablekernel/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include "ck_tile/dispatcher/dispatcher.hpp" #include "ck_tile/dispatcher/registry.hpp" @@ -119,9 +120,20 @@ int dispatcher_initialize() key.signature.dtype_b = DataType::FP16; key.signature.dtype_c = DataType::FP16; key.signature.dtype_acc = DataType::FP32; - key.signature.layout_a = LayoutTag::RowMajor; - key.signature.layout_b = LayoutTag::ColMajor; - key.signature.layout_c = LayoutTag::RowMajor; + // Derive A/B/C layouts from the force-included kernel's own layout types + // instead of hardcoding rcr. The dispatcher's supports() gate is layout-aware + // (it only constrains a dimension that an operand's inner axis maps to), so a + // wrong key layout makes it reject valid problems -- e.g. a crr kernel does not + // gate K, but with a hardcoded rcr key supports() would apply rcr's K-gate and + // reject TileK=192 problems that Old-TE runs. ALayout/BLayout/CLayout are the + // global aliases exported by the kernel header under CK_TILE_SINGLE_KERNEL_INCLUDE. + using RowMajorLayout = ck_tile::tensor_layout::gemm::RowMajor; + key.signature.layout_a = + std::is_same_v ? LayoutTag::RowMajor : LayoutTag::ColMajor; + key.signature.layout_b = + std::is_same_v ? LayoutTag::RowMajor : LayoutTag::ColMajor; + key.signature.layout_c = + std::is_same_v ? LayoutTag::RowMajor : LayoutTag::ColMajor; key.signature.transpose_a = false; key.signature.transpose_b = false; key.signature.grouped = false; From 005dd06b70f27634f612c92002690c7fc206daf9 Mon Sep 17 00:00:00 2001 From: ozturkosu Date: Wed, 17 Jun 2026 22:19:36 -0400 Subject: [PATCH 04/16] [CK_TILE] GEMM bridge: make same-harness A/B cover all layouts + bf16 The >=20% bridge-vs-old-TE perf gaps in the parity sweep are a harness artifact: the sweep timed the bridge in-process but timed old-TE via its separate standalone benchmark binary, which runs the byte-identical kernel at a lower sustained SCLK. Measured through one harness the gap is <1%. ab_same_harness.py removed that artifact but hardcoded the old-TE header dir to fp16/rcr. Derive it per stem as // so one run covers rcr/rrr/ccr/crr and fp16+bf16, add a --stems-file/--csv resume-aware sweep mode, and use the median (not max) per point. --- .../parity_diag/regression/ab_same_harness.py | 115 ++++++++++++++++-- 1 file changed, 103 insertions(+), 12 deletions(-) diff --git a/projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py b/projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py index 04e89e84d08b..d8b94dc4509e 100644 --- a/projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py +++ b/projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py @@ -18,12 +18,21 @@ kernel through the SAME worker (run_one_gemm_kernel.py). Measured this way the gap collapses to ~1%, which is the honest result. +The old-TE generated-header directory is derived per stem as +``///`` (e.g. fp16/rcr, bf16/crr), so a single +run covers every dtype/layout. Set OLD_TE_GEN to pin one explicit leaf dir for +all stems (legacy behavior); set OLD_TE_GEN_BASE to relocate the base. + Usage: - python3 ab_same_harness.py # default kernel list + shapes - python3 ab_same_harness.py [...] + python3 ab_same_harness.py # default kernel list + shapes + python3 ab_same_harness.py [...] # explicit stems + python3 ab_same_harness.py --stems-file F [--csv OUT] # sweep a stems file """ +import argparse +import csv import json import os +import statistics import subprocess import sys from pathlib import Path @@ -36,13 +45,16 @@ STATIC = DISP / "build" / "libck_tile_dispatcher.a" BR_SO_DIR = DISP / "build" / "examples" WORKER = ROOT / "tile_engine/ops/gemm/run_one_gemm_kernel.py" -# old-TE generated single-kernel headers. Override with OLD_TE_GEN; the default -# points at a sibling develop-parity worktree under the rocm-libraries root. -OLD_GEN = Path(os.environ.get( - "OLD_TE_GEN", +# Base dir of old-TE generated single-kernel headers; the per-stem leaf +# (/) is appended in old_gen_dir(). Points at a sibling +# develop-parity worktree under the rocm-libraries root by default. +OLD_GEN_BASE = Path(os.environ.get( + "OLD_TE_GEN_BASE", str(ROOT.parents[1] / ".claude/worktrees/develop-parity" - "/projects/composablekernel/build/tile_engine/ops/gemm/gemm_universal/fp16/rcr"), + "/projects/composablekernel/build/tile_engine/ops/gemm/gemm_universal"), )) +# Legacy explicit override: when set, this exact leaf dir is used for ALL stems. +OLD_GEN_PIN = os.environ.get("OLD_TE_GEN") OUT = DISP / "parity_diag" / "regression" / "_ab_same_harness_build" ARCH = os.environ.get("GFX_ARCH", "gfx942") DEVICE = os.environ.get("PARITY_DEVICE", "0") @@ -60,9 +72,23 @@ PYPATH = os.pathsep.join([str(DISP / "python"), str(ROOT / "tile_engine/ops/gemm")]) +def old_gen_dir(stem: str) -> Path: + """Old-TE header dir for a stem: // (or the pinned dir). + + Stems are named ``__...`` (e.g. fp16_rcr_..., bf16_crr_...), + which is exactly the develop-parity gen-tree layout, so the leaf is derived + from the stem itself -- no per-layout hardcoding. + """ + if OLD_GEN_PIN: + return Path(OLD_GEN_PIN) + parts = stem.split("_") + dtype, layout = parts[0], parts[1] + return OLD_GEN_BASE / dtype / layout + + def build_old_so(stem: str) -> Path | None: """Compile old TE's generated kernel header into a bridge-loadable .so.""" - hdr = OLD_GEN / f"gemm_universal_single_{stem}.hpp" + hdr = old_gen_dir(stem) / f"gemm_universal_single_{stem}.hpp" if not hdr.exists(): return None OUT.mkdir(parents=True, exist_ok=True) @@ -86,6 +112,9 @@ def build_old_so(stem: str) -> Path | None: def meas(so: Path, M: int, N: int, K: int) -> float | None: + """Median TFLOPS over REPEATS worker calls (each call does its own + warmup=50/repeat=100 internally). Median, not max, to match the sweep + methodology and stay robust to the occasional clock-warmup outlier.""" if not so or not Path(so).exists(): return None payload = json.dumps({"so_path": str(so), "problem": {"M": M, "N": N, "K": K}, @@ -93,7 +122,7 @@ def meas(so: Path, M: int, N: int, K: int) -> float | None: env = os.environ.copy() env["HIP_VISIBLE_DEVICES"] = DEVICE env["GEMM_PYPATH"] = PYPATH - best = None + samples = [] for _ in range(REPEATS): p = subprocess.run([sys.executable, str(WORKER)], input=payload.encode(), stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, env=env) @@ -103,12 +132,74 @@ def meas(so: Path, M: int, N: int, K: int) -> float | None: except json.JSONDecodeError: continue if d.get("ok"): - best = d["tflops"] if best is None else max(best, d["tflops"]) - return best + samples.append(d["tflops"]) + return statistics.median(samples) if samples else None + + +def pipeline_of(stem: str) -> str: + for p in ("compv3", "compv4", "mem"): + if f"_{p}_" in stem: + return p + return "other" def main(): - stems = sys.argv[1:] or DEFAULT_STEMS + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("stems", nargs="*", help="kernel stems to A/B") + ap.add_argument("--stems-file", help="file with one stem per line") + ap.add_argument("--csv", help="write results to CSV (resume-aware)") + args = ap.parse_args() + + stems = list(args.stems) + if args.stems_file: + stems += [l.strip() for l in Path(args.stems_file).read_text().splitlines() + if l.strip()] + stems = stems or DEFAULT_STEMS + + # CSV sweep mode: same columns as the (now-corrected) sweep, resume-aware. + if args.csv: + fields = ["stem", "pipeline", "dtype", "layout", "shape", + "bridge_tflops", "old_tflops", "gap_pct", "oldte_built"] + out = Path(args.csv) + done = set() + if out.exists(): + with open(out) as f: + for row in csv.DictReader(f): + done.add((row["stem"], row["shape"])) + mode = "a" if done else "w" + print(f"stems={len(stems)} shapes={len(SHAPES)} resume={len(done)} -> {out}", + flush=True) + with open(out, mode, newline="") as fh: + w = csv.DictWriter(fh, fieldnames=fields) + if mode == "w": + w.writeheader() + for stem in stems: + todo = [(M, N, K) for (M, N, K) in SHAPES + if (stem, f"{M}x{N}x{K}") not in done] + if not todo: + continue + parts = stem.split("_") + dtype, layout = parts[0], parts[1] + old_so = build_old_so(stem) + br_so = BR_SO_DIR / f"libgemm_{stem}.so" + for (M, N, K) in todo: + shape = f"{M}x{N}x{K}" + b = meas(br_so, M, N, K) + o = meas(old_so, M, N, K) if old_so else None + gap = (b - o) / o * 100 if (b and o) else float("nan") + w.writerow(dict( + stem=stem, pipeline=pipeline_of(stem), dtype=dtype, + layout=layout, shape=shape, + bridge_tflops=f"{b:.4f}" if b is not None else "nan", + old_tflops=f"{o:.4f}" if o is not None else "nan", + gap_pct=f"{gap:.4f}" if gap == gap else "nan", + oldte_built=str(old_so is not None))) + fh.flush() + print(f" done {stem[:60]}", flush=True) + print(f"DONE -> {out}", flush=True) + return + + # Pretty-print mode. print(f"{'shape':>14} {'bridge':>9} {'oldTE':>9} {'gap%':>7} kernel") for stem in stems: old_so = build_old_so(stem) From 929f9e3c88d508fcf8028d977ccd3aab1fa26e23 Mon Sep 17 00:00:00 2001 From: ozturkosu Date: Wed, 17 Jun 2026 22:22:51 -0400 Subject: [PATCH 05/16] [CK_TILE] GEMM bridge: speed up same-harness sweep for full runs For a full ~2000-stem sweep on a single GPU: batch all shapes into one worker call per side (5x fewer process startups), cache the compiled old-TE .so, and add a parallel --build-only pre-pass so hipcc compilation uses all CPU cores while GPU measurement stays serial. --- .../parity_diag/regression/ab_same_harness.py | 82 ++++++++++++++++++- 1 file changed, 79 insertions(+), 3 deletions(-) diff --git a/projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py b/projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py index d8b94dc4509e..48098bb58fb2 100644 --- a/projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py +++ b/projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py @@ -87,13 +87,19 @@ def old_gen_dir(stem: str) -> Path: def build_old_so(stem: str) -> Path | None: - """Compile old TE's generated kernel header into a bridge-loadable .so.""" + """Compile old TE's generated kernel header into a bridge-loadable .so. + + Cached: if the .so already exists it is reused, so a parallel --build-only + pre-pass (CPU-bound hipcc) can be separated from the serial GPU measurement. + """ hdr = old_gen_dir(stem) / f"gemm_universal_single_{stem}.hpp" if not hdr.exists(): return None OUT.mkdir(parents=True, exist_ok=True) obj = OUT / f"{stem}.o" lib = OUT / f"libold_{stem}.so" + if lib.exists(): + return lib common = [ "-fPIC", "-O3", f"-I{DISP / 'include'}", f"-I{ROOT / 'include'}", f"-I{ROOT}", f"-I{GEN}", @@ -136,6 +142,41 @@ def meas(so: Path, M: int, N: int, K: int) -> float | None: return statistics.median(samples) if samples else None +def meas_all(so: Path) -> dict: + """Median TFLOPS per shape from REPEATS *batched* worker calls. + + One worker call measures ALL shapes (5x fewer python+numpy+CDLL startups + than per-shape meas()), which is the throughput lever for a full sweep on a + single GPU. Returns {shape_str: tflops|None}.""" + out = {f"{M}x{N}x{K}": None for (M, N, K) in SHAPES} + if not so or not Path(so).exists(): + return out + items = [{"so_path": str(so), "problem": {"M": M, "N": N, "K": K}, + "kernel_name": "x"} for (M, N, K) in SHAPES] + payload = json.dumps({"items": items, "verify": False}) + env = os.environ.copy() + env["HIP_VISIBLE_DEVICES"] = DEVICE + env["GEMM_PYPATH"] = PYPATH + samples = {s: [] for s in out} + for _ in range(REPEATS): + p = subprocess.run([sys.executable, str(WORKER)], input=payload.encode(), + stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, + env=env, timeout=900) + for line in p.stdout.decode().splitlines(): + try: + d = json.loads(line) + except json.JSONDecodeError: + continue + idx = d.get("idx") + if isinstance(idx, int) and 0 <= idx < len(SHAPES) and d.get("ok"): + M, N, K = SHAPES[idx] + samples[f"{M}x{N}x{K}"].append(d["tflops"]) + for s, xs in samples.items(): + if xs: + out[s] = statistics.median(xs) + return out + + def pipeline_of(stem: str) -> str: for p in ("compv3", "compv4", "mem"): if f"_{p}_" in stem: @@ -148,6 +189,11 @@ def main(): ap.add_argument("stems", nargs="*", help="kernel stems to A/B") ap.add_argument("--stems-file", help="file with one stem per line") ap.add_argument("--csv", help="write results to CSV (resume-aware)") + ap.add_argument("--build-only", action="store_true", + help="parallel-compile old-TE .so for all stems, then exit " + "(CPU pre-pass; GPU measurement reuses the cache)") + ap.add_argument("--jobs", type=int, default=min(os.cpu_count() or 8, 16), + help="parallel compile jobs for --build-only") args = ap.parse_args() stems = list(args.stems) @@ -156,6 +202,33 @@ def main(): if l.strip()] stems = stems or DEFAULT_STEMS + # Parallel CPU pre-compile of every old-TE .so (no GPU touched). + if args.build_only: + from concurrent.futures import ProcessPoolExecutor, as_completed + ok = miss = fail = 0 + print(f"build-only: {len(stems)} stems, jobs={args.jobs}", flush=True) + with ProcessPoolExecutor(max_workers=args.jobs) as ex: + futs = {ex.submit(build_old_so, s): s for s in stems} + for i, fut in enumerate(as_completed(futs), 1): + try: + r = fut.result() + except Exception: + r = None + s = futs[fut] + if r is None: + # distinguish "no header" from "compile failed" + if (old_gen_dir(s) / f"gemm_universal_single_{s}.hpp").exists(): + fail += 1 + else: + miss += 1 + else: + ok += 1 + if i % 100 == 0: + print(f" [{i}/{len(stems)}] ok={ok} no_header={miss} fail={fail}", + flush=True) + print(f"build-only DONE: ok={ok} no_header={miss} fail={fail}", flush=True) + return + # CSV sweep mode: same columns as the (now-corrected) sweep, resume-aware. if args.csv: fields = ["stem", "pipeline", "dtype", "layout", "shape", @@ -182,10 +255,13 @@ def main(): dtype, layout = parts[0], parts[1] old_so = build_old_so(stem) br_so = BR_SO_DIR / f"libgemm_{stem}.so" + # Batched: one worker call per side covers all shapes. + bridge = meas_all(br_so) + old = meas_all(old_so) if old_so else {} for (M, N, K) in todo: shape = f"{M}x{N}x{K}" - b = meas(br_so, M, N, K) - o = meas(old_so, M, N, K) if old_so else None + b = bridge.get(shape) + o = old.get(shape) gap = (b - o) / o * 100 if (b and o) else float("nan") w.writerow(dict( stem=stem, pipeline=pipeline_of(stem), dtype=dtype, From e76c838c690faaf53c9852ad3ef6aa1359d13cb9 Mon Sep 17 00:00:00 2001 From: ozturkosu Date: Fri, 19 Jun 2026 05:22:11 -0400 Subject: [PATCH 06/16] [CK_TILE] GEMM bridge: fix A/B parity harness (fair flags + stale-.so guard) The bridge-vs-old-TE A/B reported phantom regressions from two MEASUREMENT bugs, not real codegen gaps: - ab_same_harness.py built the old-TE side WITHOUT the TE codegen flags the bridge (and real old-TE's own CMake) use, so -enable-post-misched defaulted back on and old-TE ran ~10-40% faster -> the bridge looked regressed when it is at parity. Now both sides build with identical flags. - ab_efficient_sweep.py measured whatever libgemm_.so existed with no freshness check, so 3-day-old binaries built from an obsolete codegen showed up as -78%/+703% gaps. Added a guard: skip any .so older than its generated header (treated as missing) instead of reporting a phantom gap. With both fixes the 41 former >15% outlier stems measure within +/-10% (median +0.01%); no bridge codegen regression exists. Note: a separate, deliberately UNCOMMITTED perf change in gemm_utils.py (gate -enable-post-misched=0 on persistent) gives non-persistent large tiles ~9-40%; held back pending a broader persistent-kernel no-regression sweep. --- .../regression/ab_efficient_sweep.py | 163 ++++++++++++++++++ .../parity_diag/regression/ab_same_harness.py | 13 ++ 2 files changed, 176 insertions(+) create mode 100644 projects/composablekernel/dispatcher/parity_diag/regression/ab_efficient_sweep.py diff --git a/projects/composablekernel/dispatcher/parity_diag/regression/ab_efficient_sweep.py b/projects/composablekernel/dispatcher/parity_diag/regression/ab_efficient_sweep.py new file mode 100644 index 000000000000..c7a332e9b938 --- /dev/null +++ b/projects/composablekernel/dispatcher/parity_diag/regression/ab_efficient_sweep.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +"""Efficient A/B sweep: bridge .so vs Old-TE binary, all layouts + fp16/bf16. + +Faster successor to run_alllayout_sweep.py: the bridge side batches all shapes +for a stem into ONE run_one_gemm_kernel.py worker call (one Python+numpy+CDLL +startup per stem instead of one per measurement). Old-TE binaries are run once +per shape; their internal warmup=50/repeat=100 already yields a stable median, +matching the prior methodology. + +- Bridge .so : main worktree dispatcher/build/examples (built from the FIXED source). +- Old-TE bin : develop-parity worktree build/bin (develop branch), per user instruction. + +Writes allresult_fp16_bf16.csv with resume support (keyed on stem,shape). + +CSV fields: stem,pipeline,dtype,layout,shape,bridge_tflops,old_tflops,gap_pct, + bridge_verified,oldte_built +""" +import csv, json, os, re, subprocess, sys, time +from pathlib import Path + +ROOT = Path("/home/AMD/muozturk/New_project/rocm-libraries/projects/composablekernel") +DISP = ROOT / "dispatcher" +WORKER = ROOT / "tile_engine/ops/gemm/run_one_gemm_kernel.py" +SO_DIR = DISP / "build" / "examples" +GEN_DIR = DISP / "build" / "generated_kernels" +OLD_BIN_DIR = Path( + "/home/AMD/muozturk/New_project/rocm-libraries/.claude/worktrees" + "/develop-parity/projects/composablekernel/build/bin" +) +REG = DISP / "parity_diag" / "regression" +STEMS_FILE = REG / "stems_selected.txt" +CSV_OUT = REG / "allresult_fp16_bf16.csv" + +PYPATH = os.pathsep.join([str(DISP / "python"), str(ROOT / "tile_engine/ops/gemm")]) +DEVICE = os.environ.get("PARITY_DEVICE", "0") + +SHAPES = [(512, 512, 512), (1024, 1024, 1024), (2048, 2048, 2048), + (1024, 512, 256), (4096, 4096, 4096)] + +FIELDS = ["stem", "pipeline", "dtype", "layout", "shape", + "bridge_tflops", "old_tflops", "gap_pct", + "bridge_verified", "oldte_built"] + +_TFLOPS_RE = re.compile(r'"tflops\(TFlops\)":\s*([0-9.]+)') + + +def pipeline_of(stem): + for p in ("compv3", "compv4", "mem"): + if f"_{p}_" in stem: + return p + return "other" + + +def base_env(): + env = os.environ.copy() + env["HIP_VISIBLE_DEVICES"] = DEVICE + env["GEMM_PYPATH"] = PYPATH + env["LD_LIBRARY_PATH"] = "/opt/rocm/lib:" + env.get("LD_LIBRARY_PATH", "") + return env + + +def run_bridge_all(stem): + """One batched worker call over all SHAPES. Returns {shape_str: tflops|None}.""" + so = SO_DIR / f"libgemm_{stem}.so" + out = {f"{M}x{N}x{K}": None for (M, N, K) in SHAPES} + if not so.exists(): + return out + # Staleness guard: a .so older than its generated header was built from an + # obsolete codegen and must NOT be measured -- doing so reports phantom + # regressions (the big 256-tile gaps in allresult_fp16_bf16_2.csv were all + # stale binaries that recovered to parity on rebuild). Treat stale as missing. + hdr = GEN_DIR / f"gemm_{stem}.hpp" + if hdr.exists() and so.stat().st_mtime < hdr.stat().st_mtime: + print(f" STALE .so (older than header), skipping: {stem}", file=sys.stderr, flush=True) + return out + items = [{"so_path": str(so), "problem": {"M": M, "N": N, "K": K}, + "kernel_name": f"gemm_{stem}"} for (M, N, K) in SHAPES] + payload = json.dumps({"items": items, "verify": False}) + try: + p = subprocess.run([sys.executable, str(WORKER)], input=payload.encode(), + stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, + env=base_env(), timeout=900) + except subprocess.TimeoutExpired: + return out + for line in p.stdout.decode().strip().splitlines(): + try: + d = json.loads(line) + except json.JSONDecodeError: + continue + idx = d.get("idx") + if isinstance(idx, int) and 0 <= idx < len(SHAPES) and d.get("ok"): + M, N, K = SHAPES[idx] + out[f"{M}x{N}x{K}"] = d.get("tflops") + return out + + +def run_oldte(stem, M, N, K): + binp = OLD_BIN_DIR / f"benchmark_gemm_universal_{stem}" + if not binp.exists(): + return None + try: + p = subprocess.run([str(binp), f"-m={M}", f"-n={N}", f"-k={K}", + "-warmup=50", "-repeat=100"], + stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, + env=base_env(), timeout=300) + except subprocess.TimeoutExpired: + return None + m = _TFLOPS_RE.search(p.stdout.decode()) + return float(m.group(1)) if m else None + + +def main(): + stems = [l.strip() for l in STEMS_FILE.read_text().splitlines() if l.strip()] + total = len(stems) * len(SHAPES) + done = set() + if CSV_OUT.exists(): + with open(CSV_OUT) as f: + for row in csv.DictReader(f): + done.add((row["stem"], row["shape"])) + mode = "a" if done else "w" + print(f"stems={len(stems)} shapes={len(SHAPES)} total={total} resume={len(done)}", flush=True) + + t0 = time.time(); n = len(done) + with open(CSV_OUT, mode, newline="") as fh: + w = csv.DictWriter(fh, fieldnames=FIELDS) + if mode == "w": + w.writeheader() + for stem in stems: + shapes_todo = [(M, N, K) for (M, N, K) in SHAPES + if (stem, f"{M}x{N}x{K}") not in done] + if not shapes_todo: + continue + parts = stem.split("_") + dtype, layout = parts[0], parts[1] + pipeline = pipeline_of(stem) + oldte_built = (OLD_BIN_DIR / f"benchmark_gemm_universal_{stem}").exists() + + bridge = run_bridge_all(stem) + for (M, N, K) in shapes_todo: + shape = f"{M}x{N}x{K}" + bt = bridge.get(shape) + ot = run_oldte(stem, M, N, K) + if bt is not None and ot not in (None, 0): + gap = (bt - ot) / ot * 100.0 + else: + gap = float("nan") + w.writerow(dict( + stem=stem, pipeline=pipeline, dtype=dtype, layout=layout, shape=shape, + bridge_tflops=f"{bt:.4f}" if bt is not None else "nan", + old_tflops=f"{ot:.4f}" if ot is not None else "nan", + gap_pct=f"{gap:.4f}" if gap == gap else "nan", + bridge_verified="None", oldte_built=str(oldte_built))) + fh.flush() + n += 1 + el = time.time() - t0 + rate = (n - len(done)) / el if el > 0 else 0 + eta = (total - n) / rate / 3600 if rate > 0 else 0 + print(f"[{n}/{total}] {stem[:48]:48} rate={rate:.1f}/s ETA={eta:.1f}h", flush=True) + print(f"DONE rows={n} -> {CSV_OUT}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py b/projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py index 48098bb58fb2..83e8253b1cec 100644 --- a/projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py +++ b/projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py @@ -105,6 +105,19 @@ def build_old_so(stem: str) -> Path | None: f"-I{DISP / 'include'}", f"-I{ROOT / 'include'}", f"-I{ROOT}", f"-I{GEN}", "-DCK_TILE_SINGLE_KERNEL_INCLUDE", f"-include{hdr}", "-D__HIP_PLATFORM_AMD__", f"--offload-arch={ARCH}", f'-DGFX_ARCH="{ARCH}"', + # Match the bridge build's AMDGPU codegen flags (gemm_utils.py + # _build_compile_jobs / _TILE_ENGINE_CODEGEN_FLAGS), which are also what + # Tile Engine's own CMake passes. Without these the old-TE side is built + # with a *different* instruction schedule (notably -enable-post-misched + # defaults back on) and runs ~10-40% faster than real old-TE, making the + # bridge look regressed when it is actually at parity. Build BOTH sides + # identically so the A/B measures the kernel, not a flag asymmetry. + "-mllvm", "-enable-noalias-to-md-conversion=0", + "-mllvm", "--lsr-drop-solution=1", + "-mllvm", "-enable-post-misched=0", + "-mllvm", "-amdgpu-early-inline-all=true", + "-mllvm", "-amdgpu-function-calls=false", + "-fno-offload-uniform-block", "-Wno-undefined-func-template", "-Wno-float-equal", ] cc = subprocess.run(["/opt/rocm/bin/hipcc", "-c", *common, str(SRC), "-o", str(obj)], From a741d14e169f56375259771d65b86115204f1b1a Mon Sep 17 00:00:00 2001 From: Muhammed Emin Ozturk Date: Tue, 23 Jun 2026 08:56:57 -0700 Subject: [PATCH 07/16] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- projects/composablekernel/dispatcher/python/gemm_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/projects/composablekernel/dispatcher/python/gemm_utils.py b/projects/composablekernel/dispatcher/python/gemm_utils.py index 242620f2a9eb..6dbbbb6c18d1 100644 --- a/projects/composablekernel/dispatcher/python/gemm_utils.py +++ b/projects/composablekernel/dispatcher/python/gemm_utils.py @@ -724,7 +724,8 @@ def expand_sweep( one GemmKernelConfig. Invalid combinations are dropped via the dispatcher's own validator, and duplicates (by .name) are collapsed. - For Phase 1 the signature is fixed to fp16 / rcr. + The signature is controlled by the `dtype` and `layout` arguments (defaults + to fp16 / rcr). """ with open(config_path) as f: cfg = json.load(f) From 1af77dbfa97035868d6411cdc026f4d6475aa0d9 Mon Sep 17 00:00:00 2001 From: Muhammed Emin Ozturk Date: Tue, 23 Jun 2026 08:57:24 -0700 Subject: [PATCH 08/16] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- .../tile_engine/ops/gemm/gemm_full_benchmark.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_full_benchmark.py b/projects/composablekernel/tile_engine/ops/gemm/gemm_full_benchmark.py index 10f97328081f..0228d17903e0 100644 --- a/projects/composablekernel/tile_engine/ops/gemm/gemm_full_benchmark.py +++ b/projects/composablekernel/tile_engine/ops/gemm/gemm_full_benchmark.py @@ -127,10 +127,14 @@ def resolve_devices(spec): n = int(spec) if n <= 0: return detected - # Treat a bare integer as a device *count*: take the first n detected - # ids, falling back to a plain 0..n-1 range if detection under-reports. - # To target one specific device id, use the comma form (e.g. "5,"). - return detected[:n] if len(detected) >= n else [str(i) for i in range(n)] + # Treat a bare integer as a device *count*: take the first n detected ids. + # If the environment explicitly restricts visibility (HIP/CUDA_VISIBLE_DEVICES), + # do not invent additional ids beyond what's visible. + if len(detected) >= n: + return detected[:n] + if os.environ.get("HIP_VISIBLE_DEVICES") or os.environ.get("CUDA_VISIBLE_DEVICES"): + return detected + return [str(i) for i in range(n)] return [spec] From fa7aa1b71f4afe2fe84ab1c74dbb4d7b1b2f5587 Mon Sep 17 00:00:00 2001 From: Muhammed Emin Ozturk Date: Tue, 23 Jun 2026 08:57:52 -0700 Subject: [PATCH 09/16] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- .../tile_engine/ops/gemm/run_one_gemm_kernel.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/projects/composablekernel/tile_engine/ops/gemm/run_one_gemm_kernel.py b/projects/composablekernel/tile_engine/ops/gemm/run_one_gemm_kernel.py index a47d9a8e3108..e1638297b82f 100644 --- a/projects/composablekernel/tile_engine/ops/gemm/run_one_gemm_kernel.py +++ b/projects/composablekernel/tile_engine/ops/gemm/run_one_gemm_kernel.py @@ -47,11 +47,17 @@ def _run_one(idx, so_path, prob_dict, kernel_name, verify=False, verify_tol=2e-2 try: problem = GemmProblem.from_dict(prob_dict) - np.random.seed(42) - # Generate fp32 source; the runner encodes to the kernel's real dtype - # (fp16 or bf16) based on the compiled kernel name. - A = (np.random.randn(problem.M, problem.K) * 0.1).astype(np.float32) - B = (np.random.randn(problem.K, problem.N) * 0.1).astype(np.float32) + # Cache host matrices per shape so batch mode doesn't regenerate huge inputs per kernel. + cache = getattr(_run_one, "_ab_cache", {}) + key = (problem.M, problem.N, problem.K) + if key not in cache: + rng = np.random.RandomState(42) + cache[key] = ( + (rng.randn(problem.M, problem.K) * 0.1).astype(np.float32), + (rng.randn(problem.K, problem.N) * 0.1).astype(np.float32), + ) + _run_one._ab_cache = cache + A, B = cache[key] # CRITICAL: load the library ONLY inside this subprocess. runner = GpuGemmRunner(lib_path=so_path) From 29c8cd5412d13f182bb5b2854fec93b485559ef0 Mon Sep 17 00:00:00 2001 From: Muhammed Emin Ozturk Date: Tue, 23 Jun 2026 09:09:09 -0700 Subject: [PATCH 10/16] polish comments --- .../backends/generated_tile_backend.hpp | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp b/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp index c724977f9613..ff354f5523b2 100644 --- a/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp +++ b/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp @@ -123,20 +123,7 @@ class GeneratedTileKernelInstance : public KernelInstance problem.N // stride_E/C (row-major C: stride = N) ); - // Benchmark parameters. Defaults mirror old Tile Engine's - // gemm_common.hpp (warmup=50, repeat=100, flush_cache=true, - // rotating_count=1000), and a generous warmup keeps the GPU clock - // ramped. NOTE: matching these knobs does NOT by itself make - // bridge-vs-old-TE numbers comparable -- the byte-identical kernel - // measures ~18-20% faster here than through old TE's *standalone - // benchmark binary* at e.g. 1024^3/compv4, purely because that - // separate process runs the kernel at a lower sustained SCLK (+ more - // memory-stall cycles), not because of any bench knob, compiler, or - // kernel difference (rocprof-confirmed). For an honest A/B, measure - // BOTH kernels through the SAME harness (build the old-TE kernel into a - // .so and run it via run_one_gemm_kernel.py) -- the gap then collapses - // to ~1%. Each knob is env-overridable so a caller can match another - // harness without recompiling. + const bool bench = this->benchmarking_; ck_tile::stream_config stream_cfg; stream_cfg.stream_id_ = reinterpret_cast(stream); From 4d7ab76f6f1decf20c3cb8312c29a20beb0145b1 Mon Sep 17 00:00:00 2001 From: ozturkosu Date: Fri, 26 Jun 2026 15:01:46 -0400 Subject: [PATCH 11/16] [CK_TILE] gemm bridge: match Tile Engine GEMM codegen flags exactly The bridge compiles each kernel .so with a hand-maintained hipcc flag list (dispatcher/python/gemm_utils.py) that had drifted from Tile Engine's CMake flags, so the bridge .so and the TE benchmark were not compiled apples-to-apples: * MISSING -mllvm -amdgpu-coerce-illegal-types=1 (TE's CMakeLists.txt adds it when the compiler accepts it; the bridge build never did) * EXTRA -mllvm -enable-noalias-to-md-conversion=0 (not a TE GEMM flag; it only appears in standalone CK examples/tests, never the TE gemm path) Align the bridge's backend codegen flags with the exact set the TE gemm_universal benchmark TU is built with. The coerce flag is added through a cached hipcc probe that mirrors TE's check_cxx_compiler_flag, so the bridge stays matched to TE on every toolchain (present where TE has it, skipped where TE's CMake would skip it too). The generated kernel source was already identical between the two engines; this makes their compilation identical as well. --- .../dispatcher/python/gemm_utils.py | 77 ++++++++++++++++--- 1 file changed, 65 insertions(+), 12 deletions(-) diff --git a/projects/composablekernel/dispatcher/python/gemm_utils.py b/projects/composablekernel/dispatcher/python/gemm_utils.py index 6dbbbb6c18d1..e02635e1c183 100644 --- a/projects/composablekernel/dispatcher/python/gemm_utils.py +++ b/projects/composablekernel/dispatcher/python/gemm_utils.py @@ -28,9 +28,14 @@ from __future__ import annotations import ctypes +import functools import itertools import json import multiprocessing +import os +import shutil +import subprocess +import tempfile from concurrent.futures import ProcessPoolExecutor, as_completed from dataclasses import dataclass, field from pathlib import Path @@ -486,11 +491,23 @@ def run( # Build API: codegen + hipcc -> .so paths (no GPU) # ============================================================================ -# AMDGPU codegen flags Tile Engine passes to hipcc for GEMM kernels (see -# tile_engine/ops/gemm/gemm_universal CMake flags). They steer inlining and -# register allocation; omitting them changes occupancy and, because persistent -# kernels size their grid by occupancy, produces large perf gaps vs Tile Engine. -# Matching them keeps the bridge byte-for-byte performance-equivalent. +# AMDGPU codegen flags Tile Engine passes to hipcc for GEMM kernels. These MUST +# match, flag-for-flag, the set the Tile Engine gemm_universal benchmark TU is +# compiled with (projects/composablekernel/CMakeLists.txt) -- they steer inlining +# and register allocation, and because persistent kernels size their grid by +# occupancy, any mismatch produces large perf gaps vs Tile Engine and makes the +# parity comparison no longer apples-to-apples. +# +# Tile Engine's actual GEMM benchmark flags (verified from its compile_commands): +# -fno-offload-uniform-block +# -mllvm --lsr-drop-solution=1 +# -mllvm -enable-post-misched=0 +# -mllvm -amdgpu-early-inline-all=true +# -mllvm -amdgpu-function-calls=false +# -mllvm -amdgpu-coerce-illegal-types=1 (CMake adds this only when the +# compiler accepts it; see below) +# NOTE: -enable-noalias-to-md-conversion=0 is NOT a Tile Engine GEMM flag (it only +# appears in the standalone CK examples/tests), so it deliberately is NOT here. _TILE_ENGINE_CODEGEN_FLAGS = ( "-mllvm", "--lsr-drop-solution=1", "-mllvm", "-enable-post-misched=0", @@ -499,6 +516,43 @@ def run( "-fno-offload-uniform-block", ) +# Flags Tile Engine's CMake only adds when ``check_cxx_compiler_flag`` passes +# (newer -mllvm options that some clang builds reject). We mirror that probe so +# the bridge stays matched to Tile Engine on every toolchain: the flag is present +# exactly where TE would have it, and absent where TE's CMake would also skip it. +_PROBED_CODEGEN_FLAGS = ( + ("-mllvm", "-amdgpu-coerce-illegal-types=1"), +) + + +@functools.lru_cache(maxsize=None) +def _hipcc_accepts(flag_tuple: Tuple[str, ...]) -> bool: + """Mirror CMake check_cxx_compiler_flag: does hipcc compile a trivial TU with + these flags? Cached so the probe runs at most once per distinct flag set.""" + hipcc = os.environ.get("HIPCC") or shutil.which("hipcc") or "/opt/rocm/bin/hipcc" + try: + with tempfile.TemporaryDirectory() as d: + src = Path(d) / "probe.cpp" + src.write_text("int main(){}\n") + r = subprocess.run( + [hipcc, *flag_tuple, "-c", str(src), "-o", str(Path(d) / "probe.o")], + capture_output=True, timeout=120, + ) + return r.returncode == 0 + except Exception: + return False + + +@functools.lru_cache(maxsize=1) +def _tile_engine_codegen_flags() -> Tuple[str, ...]: + """Tile Engine's GEMM codegen flags plus any probe-gated flags the compiler + accepts -- the exact backend flag set the TE benchmark is built with.""" + flags = list(_TILE_ENGINE_CODEGEN_FLAGS) + for pair in _PROBED_CODEGEN_FLAGS: + if _hipcc_accepts(pair): + flags = list(pair) + flags + return tuple(flags) + def _build_compile_jobs( config: GemmKernelConfig, header: Path @@ -528,14 +582,13 @@ def _build_compile_jobs( "-D__HIP_PLATFORM_AMD__", f"--offload-arch={config.gfx_arch}", f'-DGFX_ARCH="{config.gfx_arch}"', - "-mllvm", - "-enable-noalias-to-md-conversion=0", - # Match Tile Engine's AMDGPU codegen flags. Without them the kernel is - # compiled with different inlining/register allocation, which changes - # occupancy; persistent kernels size their grid by occupancy - # (UniversalGemmKernel::MaxOccupancyGridSize = #CUs x occupancy), so the + # Match Tile Engine's AMDGPU codegen flags exactly (see + # _tile_engine_codegen_flags). Without them the kernel is compiled with + # different inlining/register allocation, which changes occupancy; + # persistent kernels size their grid by occupancy + # (UniversalGemmKernel::MaxOccupancyGridSize = #CUs x occupancy), so a # mismatch shows up as large perf gaps vs Tile Engine on persistent tiles. - *_TILE_ENGINE_CODEGEN_FLAGS, + *_tile_engine_codegen_flags(), "-Wno-undefined-func-template", "-Wno-float-equal", str(ctypes_source), From 33a45fb98e5aabdb476a2174d7a5b5095c0e4170 Mon Sep 17 00:00:00 2001 From: ozturkosu Date: Fri, 26 Jun 2026 17:52:57 -0400 Subject: [PATCH 12/16] [CK_TILE] gemm bridge: keep Old-TE (do not deprecate yet); drop parity_diag/regression Old-TE must remain until the dispatcher bridge implements every datatype Old-TE supports, so revert the Old-TE removal from the bridge commit and re-wire its build: * restore test/ck_tile/gemm_tile_engine/* (10 files) * restore tile_engine/ops/gemm/gemm_universal/* (6 files: benchmark / instance builder / profiler / single-bench / CMakeLists) * re-add `add_subdirectory(gemm_universal EXCLUDE_FROM_ALL)` in tile_engine/ops/gemm/CMakeLists.txt; restore test/ck_tile/CMakeLists.txt to the develop state (gemm_tile_engine entry kept commented, as in develop) Also drop the parity_diag/regression dev scripts that should not ship in the PR: * dispatcher/parity_diag/regression/ab_efficient_sweep.py * dispatcher/parity_diag/regression/ab_same_harness.py --- .../regression/ab_efficient_sweep.py | 163 -------- .../parity_diag/regression/ab_same_harness.py | 308 ---------------- .../test/ck_tile/CMakeLists.txt | 3 + .../ck_tile/gemm_tile_engine/CMakeLists.txt | 348 ++++++++++++++++++ .../test/ck_tile/gemm_tile_engine/README.md | 85 +++++ .../comprehensive_coverage_config.json | 37 ++ .../configs/large_datatype_config.json | 34 ++ .../configs/padding_coverage_config.json | 34 ++ .../configs/quick_coverage_config.json | 34 ++ .../configs/simple_test_config.json | 34 ++ .../configs/small_datatype_config.json | 35 ++ .../gemm_tile_engine/extract_test_params.py | 74 ++++ .../gemm_tile_engine/test_gemm_simple.cpp | 241 ++++++++++++ .../tile_engine/ops/gemm/CMakeLists.txt | 5 +- .../ops/gemm/gemm_universal/CMakeLists.txt | 338 +++++++++++++++++ .../gemm_universal_benchmark.hpp | 73 ++++ .../gemm_universal_benchmark.py | 149 ++++++++ .../gemm_universal_benchmark_single.cpp | 102 +++++ .../gemm_universal_instance_builder.py | 344 +++++++++++++++++ .../gemm_universal_profiler.hpp | 147 ++++++++ 20 files changed, 2115 insertions(+), 473 deletions(-) delete mode 100644 projects/composablekernel/dispatcher/parity_diag/regression/ab_efficient_sweep.py delete mode 100644 projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py create mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/CMakeLists.txt create mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/README.md create mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/comprehensive_coverage_config.json create mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/large_datatype_config.json create mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/padding_coverage_config.json create mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/quick_coverage_config.json create mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json create mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/small_datatype_config.json create mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/extract_test_params.py create mode 100644 projects/composablekernel/test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp create mode 100644 projects/composablekernel/tile_engine/ops/gemm/gemm_universal/CMakeLists.txt create mode 100644 projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.hpp create mode 100755 projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py create mode 100644 projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark_single.cpp create mode 100644 projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_instance_builder.py create mode 100644 projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_profiler.hpp diff --git a/projects/composablekernel/dispatcher/parity_diag/regression/ab_efficient_sweep.py b/projects/composablekernel/dispatcher/parity_diag/regression/ab_efficient_sweep.py deleted file mode 100644 index c7a332e9b938..000000000000 --- a/projects/composablekernel/dispatcher/parity_diag/regression/ab_efficient_sweep.py +++ /dev/null @@ -1,163 +0,0 @@ -#!/usr/bin/env python3 -"""Efficient A/B sweep: bridge .so vs Old-TE binary, all layouts + fp16/bf16. - -Faster successor to run_alllayout_sweep.py: the bridge side batches all shapes -for a stem into ONE run_one_gemm_kernel.py worker call (one Python+numpy+CDLL -startup per stem instead of one per measurement). Old-TE binaries are run once -per shape; their internal warmup=50/repeat=100 already yields a stable median, -matching the prior methodology. - -- Bridge .so : main worktree dispatcher/build/examples (built from the FIXED source). -- Old-TE bin : develop-parity worktree build/bin (develop branch), per user instruction. - -Writes allresult_fp16_bf16.csv with resume support (keyed on stem,shape). - -CSV fields: stem,pipeline,dtype,layout,shape,bridge_tflops,old_tflops,gap_pct, - bridge_verified,oldte_built -""" -import csv, json, os, re, subprocess, sys, time -from pathlib import Path - -ROOT = Path("/home/AMD/muozturk/New_project/rocm-libraries/projects/composablekernel") -DISP = ROOT / "dispatcher" -WORKER = ROOT / "tile_engine/ops/gemm/run_one_gemm_kernel.py" -SO_DIR = DISP / "build" / "examples" -GEN_DIR = DISP / "build" / "generated_kernels" -OLD_BIN_DIR = Path( - "/home/AMD/muozturk/New_project/rocm-libraries/.claude/worktrees" - "/develop-parity/projects/composablekernel/build/bin" -) -REG = DISP / "parity_diag" / "regression" -STEMS_FILE = REG / "stems_selected.txt" -CSV_OUT = REG / "allresult_fp16_bf16.csv" - -PYPATH = os.pathsep.join([str(DISP / "python"), str(ROOT / "tile_engine/ops/gemm")]) -DEVICE = os.environ.get("PARITY_DEVICE", "0") - -SHAPES = [(512, 512, 512), (1024, 1024, 1024), (2048, 2048, 2048), - (1024, 512, 256), (4096, 4096, 4096)] - -FIELDS = ["stem", "pipeline", "dtype", "layout", "shape", - "bridge_tflops", "old_tflops", "gap_pct", - "bridge_verified", "oldte_built"] - -_TFLOPS_RE = re.compile(r'"tflops\(TFlops\)":\s*([0-9.]+)') - - -def pipeline_of(stem): - for p in ("compv3", "compv4", "mem"): - if f"_{p}_" in stem: - return p - return "other" - - -def base_env(): - env = os.environ.copy() - env["HIP_VISIBLE_DEVICES"] = DEVICE - env["GEMM_PYPATH"] = PYPATH - env["LD_LIBRARY_PATH"] = "/opt/rocm/lib:" + env.get("LD_LIBRARY_PATH", "") - return env - - -def run_bridge_all(stem): - """One batched worker call over all SHAPES. Returns {shape_str: tflops|None}.""" - so = SO_DIR / f"libgemm_{stem}.so" - out = {f"{M}x{N}x{K}": None for (M, N, K) in SHAPES} - if not so.exists(): - return out - # Staleness guard: a .so older than its generated header was built from an - # obsolete codegen and must NOT be measured -- doing so reports phantom - # regressions (the big 256-tile gaps in allresult_fp16_bf16_2.csv were all - # stale binaries that recovered to parity on rebuild). Treat stale as missing. - hdr = GEN_DIR / f"gemm_{stem}.hpp" - if hdr.exists() and so.stat().st_mtime < hdr.stat().st_mtime: - print(f" STALE .so (older than header), skipping: {stem}", file=sys.stderr, flush=True) - return out - items = [{"so_path": str(so), "problem": {"M": M, "N": N, "K": K}, - "kernel_name": f"gemm_{stem}"} for (M, N, K) in SHAPES] - payload = json.dumps({"items": items, "verify": False}) - try: - p = subprocess.run([sys.executable, str(WORKER)], input=payload.encode(), - stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, - env=base_env(), timeout=900) - except subprocess.TimeoutExpired: - return out - for line in p.stdout.decode().strip().splitlines(): - try: - d = json.loads(line) - except json.JSONDecodeError: - continue - idx = d.get("idx") - if isinstance(idx, int) and 0 <= idx < len(SHAPES) and d.get("ok"): - M, N, K = SHAPES[idx] - out[f"{M}x{N}x{K}"] = d.get("tflops") - return out - - -def run_oldte(stem, M, N, K): - binp = OLD_BIN_DIR / f"benchmark_gemm_universal_{stem}" - if not binp.exists(): - return None - try: - p = subprocess.run([str(binp), f"-m={M}", f"-n={N}", f"-k={K}", - "-warmup=50", "-repeat=100"], - stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, - env=base_env(), timeout=300) - except subprocess.TimeoutExpired: - return None - m = _TFLOPS_RE.search(p.stdout.decode()) - return float(m.group(1)) if m else None - - -def main(): - stems = [l.strip() for l in STEMS_FILE.read_text().splitlines() if l.strip()] - total = len(stems) * len(SHAPES) - done = set() - if CSV_OUT.exists(): - with open(CSV_OUT) as f: - for row in csv.DictReader(f): - done.add((row["stem"], row["shape"])) - mode = "a" if done else "w" - print(f"stems={len(stems)} shapes={len(SHAPES)} total={total} resume={len(done)}", flush=True) - - t0 = time.time(); n = len(done) - with open(CSV_OUT, mode, newline="") as fh: - w = csv.DictWriter(fh, fieldnames=FIELDS) - if mode == "w": - w.writeheader() - for stem in stems: - shapes_todo = [(M, N, K) for (M, N, K) in SHAPES - if (stem, f"{M}x{N}x{K}") not in done] - if not shapes_todo: - continue - parts = stem.split("_") - dtype, layout = parts[0], parts[1] - pipeline = pipeline_of(stem) - oldte_built = (OLD_BIN_DIR / f"benchmark_gemm_universal_{stem}").exists() - - bridge = run_bridge_all(stem) - for (M, N, K) in shapes_todo: - shape = f"{M}x{N}x{K}" - bt = bridge.get(shape) - ot = run_oldte(stem, M, N, K) - if bt is not None and ot not in (None, 0): - gap = (bt - ot) / ot * 100.0 - else: - gap = float("nan") - w.writerow(dict( - stem=stem, pipeline=pipeline, dtype=dtype, layout=layout, shape=shape, - bridge_tflops=f"{bt:.4f}" if bt is not None else "nan", - old_tflops=f"{ot:.4f}" if ot is not None else "nan", - gap_pct=f"{gap:.4f}" if gap == gap else "nan", - bridge_verified="None", oldte_built=str(oldte_built))) - fh.flush() - n += 1 - el = time.time() - t0 - rate = (n - len(done)) / el if el > 0 else 0 - eta = (total - n) / rate / 3600 if rate > 0 else 0 - print(f"[{n}/{total}] {stem[:48]:48} rate={rate:.1f}/s ETA={eta:.1f}h", flush=True) - print(f"DONE rows={n} -> {CSV_OUT}", flush=True) - - -if __name__ == "__main__": - main() diff --git a/projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py b/projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py deleted file mode 100644 index 83e8253b1cec..000000000000 --- a/projects/composablekernel/dispatcher/parity_diag/regression/ab_same_harness.py +++ /dev/null @@ -1,308 +0,0 @@ -#!/usr/bin/env python3 -"""Apples-to-apples GEMM A/B: bridge kernel vs old-TE kernel, ONE harness. - -Why this exists ---------------- -The earlier sweep (allsweep6144rcrfp16.py) compared the bridge's dispatcher -measurement against old TE's *standalone benchmark binary* -(benchmark_gemm_universal_). That comparison is NOT apples-to-apples: -the device kernel is byte-identical, yet old TE's standalone binary reports -~18-20% lower TFLOPS at e.g. 1024^3 / compv4. rocprof shows the identical -kernel genuinely runs longer in that process -- ~+8% cycles plus a lower -sustained SCLK -- a power/clock + execution-environment artifact of that -binary, NOT a bridge speedup, compiler difference, or kernel difference. -(See diagnose.md sec.4.) - -This harness removes the artifact: it builds the OLD-TE kernel into a .so from -old TE's own generated header and runs BOTH the bridge kernel and the old-TE -kernel through the SAME worker (run_one_gemm_kernel.py). Measured this way the -gap collapses to ~1%, which is the honest result. - -The old-TE generated-header directory is derived per stem as -``///`` (e.g. fp16/rcr, bf16/crr), so a single -run covers every dtype/layout. Set OLD_TE_GEN to pin one explicit leaf dir for -all stems (legacy behavior); set OLD_TE_GEN_BASE to relocate the base. - -Usage: - python3 ab_same_harness.py # default kernel list + shapes - python3 ab_same_harness.py [...] # explicit stems - python3 ab_same_harness.py --stems-file F [--csv OUT] # sweep a stems file -""" -import argparse -import csv -import json -import os -import statistics -import subprocess -import sys -from pathlib import Path - -# composablekernel root: .../composablekernel/dispatcher/parity_diag/regression/ -ROOT = Path(__file__).resolve().parents[3] -DISP = ROOT / "dispatcher" -GEN = DISP / "build" / "generated_kernels" -SRC = DISP / "bindings" / "ctypes" / "gemm_ctypes_lib.cpp" -STATIC = DISP / "build" / "libck_tile_dispatcher.a" -BR_SO_DIR = DISP / "build" / "examples" -WORKER = ROOT / "tile_engine/ops/gemm/run_one_gemm_kernel.py" -# Base dir of old-TE generated single-kernel headers; the per-stem leaf -# (/) is appended in old_gen_dir(). Points at a sibling -# develop-parity worktree under the rocm-libraries root by default. -OLD_GEN_BASE = Path(os.environ.get( - "OLD_TE_GEN_BASE", - str(ROOT.parents[1] / ".claude/worktrees/develop-parity" - "/projects/composablekernel/build/tile_engine/ops/gemm/gemm_universal"), -)) -# Legacy explicit override: when set, this exact leaf dir is used for ALL stems. -OLD_GEN_PIN = os.environ.get("OLD_TE_GEN") -OUT = DISP / "parity_diag" / "regression" / "_ab_same_harness_build" -ARCH = os.environ.get("GFX_ARCH", "gfx942") -DEVICE = os.environ.get("PARITY_DEVICE", "0") -REPEATS = int(os.environ.get("AB_REPEATS", "3")) - -SHAPES = [(512, 512, 512), (1024, 1024, 1024), (2048, 2048, 2048), - (1024, 512, 256), (4096, 4096, 4096)] - -DEFAULT_STEMS = [ - "fp16_rcr_compv4_default_intrawave_False_False_False_False_64x128x64_2x2x1_32x32x16", - "fp16_rcr_compv4_cshuffle_intrawave_False_False_False_False_64x128x64_1x4x1_32x32x16", - "fp16_rcr_compv4_default_intrawave_False_False_False_False_128x128x64_4x1x1_32x32x16", -] - -PYPATH = os.pathsep.join([str(DISP / "python"), str(ROOT / "tile_engine/ops/gemm")]) - - -def old_gen_dir(stem: str) -> Path: - """Old-TE header dir for a stem: // (or the pinned dir). - - Stems are named ``__...`` (e.g. fp16_rcr_..., bf16_crr_...), - which is exactly the develop-parity gen-tree layout, so the leaf is derived - from the stem itself -- no per-layout hardcoding. - """ - if OLD_GEN_PIN: - return Path(OLD_GEN_PIN) - parts = stem.split("_") - dtype, layout = parts[0], parts[1] - return OLD_GEN_BASE / dtype / layout - - -def build_old_so(stem: str) -> Path | None: - """Compile old TE's generated kernel header into a bridge-loadable .so. - - Cached: if the .so already exists it is reused, so a parallel --build-only - pre-pass (CPU-bound hipcc) can be separated from the serial GPU measurement. - """ - hdr = old_gen_dir(stem) / f"gemm_universal_single_{stem}.hpp" - if not hdr.exists(): - return None - OUT.mkdir(parents=True, exist_ok=True) - obj = OUT / f"{stem}.o" - lib = OUT / f"libold_{stem}.so" - if lib.exists(): - return lib - common = [ - "-fPIC", "-O3", - f"-I{DISP / 'include'}", f"-I{ROOT / 'include'}", f"-I{ROOT}", f"-I{GEN}", - "-DCK_TILE_SINGLE_KERNEL_INCLUDE", f"-include{hdr}", "-D__HIP_PLATFORM_AMD__", - f"--offload-arch={ARCH}", f'-DGFX_ARCH="{ARCH}"', - # Match the bridge build's AMDGPU codegen flags (gemm_utils.py - # _build_compile_jobs / _TILE_ENGINE_CODEGEN_FLAGS), which are also what - # Tile Engine's own CMake passes. Without these the old-TE side is built - # with a *different* instruction schedule (notably -enable-post-misched - # defaults back on) and runs ~10-40% faster than real old-TE, making the - # bridge look regressed when it is actually at parity. Build BOTH sides - # identically so the A/B measures the kernel, not a flag asymmetry. - "-mllvm", "-enable-noalias-to-md-conversion=0", - "-mllvm", "--lsr-drop-solution=1", - "-mllvm", "-enable-post-misched=0", - "-mllvm", "-amdgpu-early-inline-all=true", - "-mllvm", "-amdgpu-function-calls=false", - "-fno-offload-uniform-block", - "-Wno-undefined-func-template", "-Wno-float-equal", - ] - cc = subprocess.run(["/opt/rocm/bin/hipcc", "-c", *common, str(SRC), "-o", str(obj)], - capture_output=True) - if cc.returncode != 0: - return None - ln = subprocess.run(["/opt/rocm/bin/hipcc", "-shared", "-fPIC", - f"--offload-arch={ARCH}", "--hip-link", - str(obj), str(STATIC), "-o", str(lib)], capture_output=True) - return lib if ln.returncode == 0 else None - - -def meas(so: Path, M: int, N: int, K: int) -> float | None: - """Median TFLOPS over REPEATS worker calls (each call does its own - warmup=50/repeat=100 internally). Median, not max, to match the sweep - methodology and stay robust to the occasional clock-warmup outlier.""" - if not so or not Path(so).exists(): - return None - payload = json.dumps({"so_path": str(so), "problem": {"M": M, "N": N, "K": K}, - "kernel_name": "x"}) - env = os.environ.copy() - env["HIP_VISIBLE_DEVICES"] = DEVICE - env["GEMM_PYPATH"] = PYPATH - samples = [] - for _ in range(REPEATS): - p = subprocess.run([sys.executable, str(WORKER)], input=payload.encode(), - stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, env=env) - for line in p.stdout.decode().splitlines(): - try: - d = json.loads(line) - except json.JSONDecodeError: - continue - if d.get("ok"): - samples.append(d["tflops"]) - return statistics.median(samples) if samples else None - - -def meas_all(so: Path) -> dict: - """Median TFLOPS per shape from REPEATS *batched* worker calls. - - One worker call measures ALL shapes (5x fewer python+numpy+CDLL startups - than per-shape meas()), which is the throughput lever for a full sweep on a - single GPU. Returns {shape_str: tflops|None}.""" - out = {f"{M}x{N}x{K}": None for (M, N, K) in SHAPES} - if not so or not Path(so).exists(): - return out - items = [{"so_path": str(so), "problem": {"M": M, "N": N, "K": K}, - "kernel_name": "x"} for (M, N, K) in SHAPES] - payload = json.dumps({"items": items, "verify": False}) - env = os.environ.copy() - env["HIP_VISIBLE_DEVICES"] = DEVICE - env["GEMM_PYPATH"] = PYPATH - samples = {s: [] for s in out} - for _ in range(REPEATS): - p = subprocess.run([sys.executable, str(WORKER)], input=payload.encode(), - stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, - env=env, timeout=900) - for line in p.stdout.decode().splitlines(): - try: - d = json.loads(line) - except json.JSONDecodeError: - continue - idx = d.get("idx") - if isinstance(idx, int) and 0 <= idx < len(SHAPES) and d.get("ok"): - M, N, K = SHAPES[idx] - samples[f"{M}x{N}x{K}"].append(d["tflops"]) - for s, xs in samples.items(): - if xs: - out[s] = statistics.median(xs) - return out - - -def pipeline_of(stem: str) -> str: - for p in ("compv3", "compv4", "mem"): - if f"_{p}_" in stem: - return p - return "other" - - -def main(): - ap = argparse.ArgumentParser(description=__doc__) - ap.add_argument("stems", nargs="*", help="kernel stems to A/B") - ap.add_argument("--stems-file", help="file with one stem per line") - ap.add_argument("--csv", help="write results to CSV (resume-aware)") - ap.add_argument("--build-only", action="store_true", - help="parallel-compile old-TE .so for all stems, then exit " - "(CPU pre-pass; GPU measurement reuses the cache)") - ap.add_argument("--jobs", type=int, default=min(os.cpu_count() or 8, 16), - help="parallel compile jobs for --build-only") - args = ap.parse_args() - - stems = list(args.stems) - if args.stems_file: - stems += [l.strip() for l in Path(args.stems_file).read_text().splitlines() - if l.strip()] - stems = stems or DEFAULT_STEMS - - # Parallel CPU pre-compile of every old-TE .so (no GPU touched). - if args.build_only: - from concurrent.futures import ProcessPoolExecutor, as_completed - ok = miss = fail = 0 - print(f"build-only: {len(stems)} stems, jobs={args.jobs}", flush=True) - with ProcessPoolExecutor(max_workers=args.jobs) as ex: - futs = {ex.submit(build_old_so, s): s for s in stems} - for i, fut in enumerate(as_completed(futs), 1): - try: - r = fut.result() - except Exception: - r = None - s = futs[fut] - if r is None: - # distinguish "no header" from "compile failed" - if (old_gen_dir(s) / f"gemm_universal_single_{s}.hpp").exists(): - fail += 1 - else: - miss += 1 - else: - ok += 1 - if i % 100 == 0: - print(f" [{i}/{len(stems)}] ok={ok} no_header={miss} fail={fail}", - flush=True) - print(f"build-only DONE: ok={ok} no_header={miss} fail={fail}", flush=True) - return - - # CSV sweep mode: same columns as the (now-corrected) sweep, resume-aware. - if args.csv: - fields = ["stem", "pipeline", "dtype", "layout", "shape", - "bridge_tflops", "old_tflops", "gap_pct", "oldte_built"] - out = Path(args.csv) - done = set() - if out.exists(): - with open(out) as f: - for row in csv.DictReader(f): - done.add((row["stem"], row["shape"])) - mode = "a" if done else "w" - print(f"stems={len(stems)} shapes={len(SHAPES)} resume={len(done)} -> {out}", - flush=True) - with open(out, mode, newline="") as fh: - w = csv.DictWriter(fh, fieldnames=fields) - if mode == "w": - w.writeheader() - for stem in stems: - todo = [(M, N, K) for (M, N, K) in SHAPES - if (stem, f"{M}x{N}x{K}") not in done] - if not todo: - continue - parts = stem.split("_") - dtype, layout = parts[0], parts[1] - old_so = build_old_so(stem) - br_so = BR_SO_DIR / f"libgemm_{stem}.so" - # Batched: one worker call per side covers all shapes. - bridge = meas_all(br_so) - old = meas_all(old_so) if old_so else {} - for (M, N, K) in todo: - shape = f"{M}x{N}x{K}" - b = bridge.get(shape) - o = old.get(shape) - gap = (b - o) / o * 100 if (b and o) else float("nan") - w.writerow(dict( - stem=stem, pipeline=pipeline_of(stem), dtype=dtype, - layout=layout, shape=shape, - bridge_tflops=f"{b:.4f}" if b is not None else "nan", - old_tflops=f"{o:.4f}" if o is not None else "nan", - gap_pct=f"{gap:.4f}" if gap == gap else "nan", - oldte_built=str(old_so is not None))) - fh.flush() - print(f" done {stem[:60]}", flush=True) - print(f"DONE -> {out}", flush=True) - return - - # Pretty-print mode. - print(f"{'shape':>14} {'bridge':>9} {'oldTE':>9} {'gap%':>7} kernel") - for stem in stems: - old_so = build_old_so(stem) - br_so = BR_SO_DIR / f"libgemm_{stem}.so" - if old_so is None: - print(f" [skip: no old-TE header] {stem}") - continue - for (M, N, K) in SHAPES: - b = meas(br_so, M, N, K) - o = meas(old_so, M, N, K) - gap = (b - o) / o * 100 if (b and o) else float("nan") - print(f"{f'{M}x{N}x{K}':>14} {b or float('nan'):9.2f} " - f"{o or float('nan'):9.2f} {gap:7.2f} {stem[:40]}") - - -if __name__ == "__main__": - main() diff --git a/projects/composablekernel/test/ck_tile/CMakeLists.txt b/projects/composablekernel/test/ck_tile/CMakeLists.txt index d55d3f609e86..52552d8711ab 100644 --- a/projects/composablekernel/test/ck_tile/CMakeLists.txt +++ b/projects/composablekernel/test/ck_tile/CMakeLists.txt @@ -71,6 +71,9 @@ if(BUILD_CK_TILE_FMHA_TESTS) add_subdirectory(fmha) endif() if(BUILD_CK_TILE_ENGINE_TESTS) +# TODO: The Universal GEMM tile engine test will be either removed +# or moved to the appropriate location in future work. +# add_subdirectory(gemm_tile_engine) add_subdirectory(pooling_tile_engine) endif() add_subdirectory(pooling) diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/CMakeLists.txt b/projects/composablekernel/test/ck_tile/gemm_tile_engine/CMakeLists.txt new file mode 100644 index 000000000000..374370f57076 --- /dev/null +++ b/projects/composablekernel/test/ck_tile/gemm_tile_engine/CMakeLists.txt @@ -0,0 +1,348 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# ============================================================================ +# GEMM Tile Engine Unit Tests +# +# This CMake file creates unit tests for tile_engine generated GEMM kernels. +# It follows the exact same build patterns as tile_engine for consistency +# and reliability. Each kernel configuration gets its own test executable. +# ============================================================================ + +# Locate tile_engine GEMM scripts directory +set(TILE_ENGINE_GEMM_DIR "${PROJECT_SOURCE_DIR}/tile_engine/ops/gemm/gemm_universal") + +if(NOT EXISTS ${TILE_ENGINE_GEMM_DIR}) + message(WARNING "Tile engine directory not found: ${TILE_ENGINE_GEMM_DIR}") + return() +endif() + +# ============================================================================ +# create_individual_gemm_test_target +# +# Creates a single test executable for a specific kernel configuration. +# Mirrors tile_engine's create_individual_gemm_target function for consistency. +# +# Parameters: +# datatype - Data type (fp16, bf16, fp32, etc.) +# layout - Matrix layout (rcr, rrr, ccr, crr) +# config_name - Configuration file name without .json extension +# trait - Kernel trait combination string +# tile_config - Tile configuration parameters +# config_json - Full path to JSON configuration file +# ============================================================================ +function(create_individual_gemm_test_target datatype layout config_name trait tile_config config_json) + set(target_name "test_gemm_universal_tile_engine_${datatype}_${layout}_${config_name}_${trait}_${tile_config}") + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}") + + # Generated header path (already created during cmake configuration) + set(test_header "${working_path}/gemm_universal_single_${datatype}_${layout}_${trait}_${tile_config}.hpp") + set(test_params_header "${working_path}/test_params.hpp") + + # Verify header exists (should have been generated during cmake configuration) + if(NOT EXISTS ${test_header}) + message(WARNING "Generated header not found: ${test_header}") + return() + endif() + + # Verify test parameters header exists + if(NOT EXISTS ${test_params_header}) + message(WARNING "Test parameters header not found: ${test_params_header}") + return() + endif() + + + # Create GTest executable for this kernel configuration + add_gtest_executable(${target_name} + ${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_simple.cpp + ) + + # Configure GPU architectures for HIP compilation + set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_TEST_GPU_TARGETS}) + + # Define preprocessor macros for generated header location and test parameters + target_compile_definitions(${target_name} PRIVATE + GEMM_SINGLE_INSTANCE_HPP="${test_header}" + GEMM_TEST_PARAMS_HPP="${test_params_header}" + ) + + # Include directories for headers and dependencies + target_include_directories(${target_name} PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_BINARY_DIR}/include + ${PROJECT_SOURCE_DIR} # Root directory for tile_engine access + ${GTEST_INCLUDE_DIRS} + ) + + # Compiler options matching tile_engine requirements + target_compile_options(${target_name} PRIVATE + -Wno-undefined-func-template # Suppress template warnings + -Wno-float-equal # Allow floating point comparisons + --offload-compress # Enable GPU code compression + -include ${test_header} # Auto-include generated header + ) + + # Add FP8 format definitions for proper data type interpretation + if(CK_USE_OCP_FP8) + target_compile_options(${target_name} PRIVATE -DCK_TILE_USE_OCP_FP8) + endif() + + message(DEBUG " Created test target: ${target_name}") +endfunction() + +# ============================================================================ +# build_gemm_test_targets +# +# Builds all test targets for a specific datatype/layout/config combination. +# Uses tile_engine's two-step process: list kernels, then generate tests. +# +# Parameters: +# datatype - Data type (fp16, bf16, fp32, etc.) +# layout - Matrix layout (rcr, rrr, ccr, crr) +# config_name - Configuration file name without .json extension +# ============================================================================ +function(build_gemm_test_targets datatype layout config_name) + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}") + + # Locate and validate configuration file + set(config_filename "${config_name}.json") + set(json_blob "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config_filename}") + + if(NOT EXISTS ${json_blob}) + message(WARNING "Test config file not found: ${json_blob}") + return() + endif() + + # Prepare build directory for this configuration + file(MAKE_DIRECTORY ${working_path}) + + # STEP 1: Discovery phase - list all valid kernel configurations + execute_process( + COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_universal_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${json_blob} + --list_kernels + --gpu_target "${GEMM_TEST_GPU_TARGETS}" + WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR} + RESULT_VARIABLE ret + OUTPUT_VARIABLE list_output + ERROR_VARIABLE list_error + ) + + if(NOT ret EQUAL 0) + message(WARNING "Failed to list kernels for ${datatype}_${layout}_${config_name}: ${list_error}") + return() + endif() + + # Verify kernel list file was generated + if(NOT EXISTS ${working_path}/gemm_kernel_list.txt) + message(DEBUG "No kernels found for ${datatype}_${layout}_${config_name} (validation filtered out all combinations)") + return() + endif() + + message(DEBUG "Building tests for ${datatype}_${layout}_${config_name}") + + # STEP 2a: Extract test parameters from config + set(test_params_file "${working_path}/test_params.hpp") + execute_process( + COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_SOURCE_DIR}/extract_test_params.py + --config_file ${json_blob} + --output_file ${test_params_file} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE extract_ret + OUTPUT_VARIABLE extract_output + ERROR_VARIABLE extract_error + ) + + if(NOT extract_ret EQUAL 0) + message(WARNING "Failed to extract test parameters for ${datatype}_${layout}: ${extract_error}") + return() + endif() + + # STEP 2b: Header generation phase - generate headers using --gen_single + message(STATUS " Generating headers using --gen_single...") + + file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines) + set(gen_count 0) + + foreach(line IN LISTS kernel_lines) + # Parse kernel specification format: kernel_name|tile_config|trait_combo + string(REPLACE "|" ";" parts "${line}") + list(LENGTH parts parts_len) + if(parts_len EQUAL 3) + list(GET parts 0 kernel_name) + list(GET parts 1 tile_config) + list(GET parts 2 trait_combo) + + # Generate header using --gen_single + execute_process( + COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_universal_instance_builder.py + --working_path ${working_path} + --gpu_target "${GEMM_TEST_GPU_TARGETS}" + --datatype ${datatype} + --layout ${layout} + --config_json ${json_blob} + --gen_single + --kernel_name "${kernel_name}" + --tile_config "${tile_config}" + --trait_combo "${trait_combo}" + WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR} + RESULT_VARIABLE gen_ret + OUTPUT_VARIABLE gen_output + ERROR_VARIABLE gen_error + ) + + if(NOT gen_ret EQUAL 0) + message(WARNING "Failed to generate header for ${kernel_name}: ${gen_error}") + else() + math(EXPR gen_count "${gen_count} + 1") + endif() + endif() + endforeach() + + message(STATUS " Generated ${gen_count} headers for ${datatype}_${layout}") + + # STEP 3: Target creation phase - create test targets + message(STATUS " Creating test targets...") + file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines) + set(test_count 0) + foreach(line IN LISTS kernel_lines) + # Parse kernel specification format: kernel_name|tile_config|trait_combo + string(REPLACE "|" ";" parts "${line}") + list(LENGTH parts parts_len) + if(parts_len EQUAL 3) + list(GET parts 0 kernel_name) + list(GET parts 1 tile_config) + list(GET parts 2 trait_combo) + + # Generate test target for this kernel configuration + create_individual_gemm_test_target("${datatype}" "${layout}" "${config_name}" "${trait_combo}" "${tile_config}" "${json_blob}") + math(EXPR test_count "${test_count} + 1") + endif() + endforeach() + message(STATUS " Created ${test_count} test targets for ${datatype}_${layout}") +endfunction()# ============================================================================ +# MAIN EXECUTION - Test Target Generation +# ============================================================================ + +message(STATUS "=== Starting GEMM Tile Engine Test Configuration ===") +message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + +# GPU architecture filtering - only build tests for supported architectures +set(GEMM_TEST_GPU_TARGETS "") +set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201;gfx12-generic") + +foreach(target IN LISTS SUPPORTED_GPU_TARGETS) + if(target IN_LIST DESIRED_TARGETS) + list(APPEND GEMM_TEST_GPU_TARGETS ${target}) + message(STATUS " Adding GPU target for tests: ${target}") + endif() +endforeach() + +# Early exit if no compatible GPU architectures are available +if(NOT GEMM_TEST_GPU_TARGETS) + message(WARNING "Skipping GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + return() +endif() + +message(STATUS "Building GEMM tile engine tests for GPU targets: ${GEMM_TEST_GPU_TARGETS}") + + # Enable parallel compilation optimizations + # Set up job pools for better parallel compilation control + set_property(GLOBAL PROPERTY JOB_POOLS + compile_heavy=4 # Limit heavy compilations to prevent OOM + compile_normal=16 # Allow more parallel normal compilations + ) + + # Enable compiler cache if available and explicitly requested + # Disabled by default due to permission issues in CI environments + option(ENABLE_CCACHE_TESTS "Enable ccache for test compilation" OFF) + if(ENABLE_CCACHE_TESTS) + find_program(CCACHE_PROGRAM ccache) + if(CCACHE_PROGRAM) + set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM}) + message(STATUS "Using ccache for faster test compilation") + else() + message(WARNING "ccache requested but not found") + endif() + else() + message(STATUS "ccache disabled for tests (use -DENABLE_CCACHE_TESTS=ON to enable)") + endif() + +# ============================================================================ +# Test Configuration Matrix - Clean Focused Design +# ============================================================================ + +# All supported data types and layouts for comprehensive testing +# Note: fp64 not included (no MFMA hardware support) +set(TEST_DATATYPES "fp16;fp8;bf16;fp32") +set(TEST_LAYOUTS "rcr;rrr;ccr;crr") + +# ============================================================================ +# Test Target Generation - Datatype-Specific Categories +# ============================================================================ + +# 1. SMALL DATATYPES: Test optimized config for small data types (fp8, fp16, bf16) +# These data types can use larger warp tiles due to smaller memory footprint +set(SMALL_DATATYPE_CONFIG "small_datatype_config") +set(SMALL_DATATYPE_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${SMALL_DATATYPE_CONFIG}.json") +set(SMALL_DATATYPES "fp8;fp16;bf16") + +if(EXISTS ${SMALL_DATATYPE_CONFIG_FILE}) + message(STATUS "Processing small datatype config: ${SMALL_DATATYPE_CONFIG} (fp8, fp16, bf16)") + foreach(datatype IN LISTS SMALL_DATATYPES) + # fp8, fp16, bf16: testing all layouts (rcr, rrr, ccr, crr) + foreach(layout IN LISTS TEST_LAYOUTS) + build_gemm_test_targets("${datatype}" "${layout}" "${SMALL_DATATYPE_CONFIG}") + endforeach() + endforeach() +else() + message(WARNING "Small datatype config file not found: ${SMALL_DATATYPE_CONFIG_FILE}") +endif() + +# 2. PADDING COVERAGE: Test padding combinations with fixed fp16/rcr configuration +# This focuses on padding behavior (pad_m, pad_n, pad_k) +set(PADDING_CONFIG "padding_coverage_config") +set(PADDING_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${PADDING_CONFIG}.json") + +if(EXISTS ${PADDING_CONFIG_FILE}) + message(STATUS "Processing padding config: ${PADDING_CONFIG} (fp16/rcr only)") + build_gemm_test_targets("fp16" "rcr" "${PADDING_CONFIG}") +else() + message(WARNING "Padding config file not found: ${PADDING_CONFIG_FILE}") +endif() + +# 3. COVERAGE LEVEL: Quick or comprehensive testing +# Quick: ~144 kernels with multiple tile sizes and trait combinations +# Comprehensive: Several thousand kernels with extensive tile sizes, warp configurations, and all trait combinations +set(COVERAGE_LEVEL "quick" CACHE STRING "Coverage level: quick or comprehensive") +set_property(CACHE COVERAGE_LEVEL PROPERTY STRINGS "quick" "comprehensive") + +if(COVERAGE_LEVEL STREQUAL "quick") + set(COVERAGE_CONFIG "quick_coverage_config") + set(COVERAGE_DESC "Quick - approximately 144 kernels with trait combinations") +elseif(COVERAGE_LEVEL STREQUAL "comprehensive") + set(COVERAGE_CONFIG "comprehensive_coverage_config") + set(COVERAGE_DESC "Comprehensive - several thousand kernels with extensive tile and trait coverage") +else() + message(FATAL_ERROR "Invalid COVERAGE_LEVEL: ${COVERAGE_LEVEL}. Must be 'quick' or 'comprehensive'") +endif() + +set(COVERAGE_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${COVERAGE_CONFIG}.json") + +if(EXISTS ${COVERAGE_CONFIG_FILE}) + message(STATUS "Processing coverage config: ${COVERAGE_LEVEL} - ${COVERAGE_DESC}") + build_gemm_test_targets("fp16" "rcr" "${COVERAGE_CONFIG}") +else() + message(WARNING "Coverage config file not found: ${COVERAGE_CONFIG_FILE}") +endif() +# ============================================================================ + + +message(STATUS "GEMM tile engine tests configured with datatype-specific design:") +message(STATUS " - Small datatypes: fp8/fp16/bf16 (all layouts)") +message(STATUS " - Padding coverage with fp16/rcr") +message(STATUS " - Coverage level: ${COVERAGE_LEVEL} (~144 kernels quick, several thousand comprehensive)") +message(STATUS " Use -DCOVERAGE_LEVEL=comprehensive for extensive testing") diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/README.md b/projects/composablekernel/test/ck_tile/gemm_tile_engine/README.md new file mode 100644 index 000000000000..87ce0c9fd05c --- /dev/null +++ b/projects/composablekernel/test/ck_tile/gemm_tile_engine/README.md @@ -0,0 +1,85 @@ +# GEMM Tile Engine Unit Tests + +## How It Works + +This unit test system integrates **tile_engine's kernel generation** into automated testing: + +1. **Uses tile_engine scripts directly**: Same Python scripts that generate tile_engine kernels +2. **JSON-based configuration**: Define test parameters in JSON files (like tile_engine) +3. **Build-time generation**: CMake calls tile_engine scripts to generate kernel headers +4. **Individual test executables**: Each kernel configuration becomes a separate test +5. **Tile_engine verification**: Uses exact same error thresholds and validation as tile_engine + +## Tile Engine Integration + +``` +JSON Config → tile_engine Python scripts → Generated Headers → Test Executables +``` + +- **`--list_kernels`**: Get available kernel configurations from JSON +- **`--gen_individual`**: Generate all kernel headers in parallel during CMake configuration +- **`--gen_single`**: Generate individual kernel header for each configuration +- **Same verification**: Uses tile_engine's adaptive error thresholds and reference calculations +- **Same patterns**: Follows tile_engine's tensor initialization, stride calculation, and kernel launching + +### Config-Specific Test Parameters + +Each test configuration can specify optimized problem sizes in its JSON file: +- **`test_params.problem_sizes`**: Array of `{m, n, k, split_k}` configurations +- **CMake extraction**: `extract_test_params.py` generates config-specific test parameter files +- **Build integration**: Each test target uses parameters appropriate for its kernel configuration +- **Optimized testing**: Different configs test different problem sizes that showcase their strengths + + +The key idea: **Unit tests that use tile_engine's exact kernel generation and verification methodology** instead of creating separate test infrastructure. + +## Test Configurations + +### 1. **Simple Test** (`simple_test_config.json`) +- **Purpose**: Basic functionality validation +- **Config**: 128x128x64, warp 2x2x1, warp_tile 16x16x16 +- **Traits**: compv3 + compv4 pipelines +- **Coverage**: ~2 kernels per datatype/layout + +### 2. **Small Datatype** (`small_datatype_config.json`) +- **Purpose**: Optimized for fp8/fp16/bf16 data types +- **Config**: 128x128x32, warp 2x2x1, warp_tile 32x32x16 +- **Traits**: compv3 pipeline only +- **Coverage**: All 4 layouts (rcr, rrr, ccr, crr) for fp8, fp16, bf16 + +### 3. **Padding Coverage** (`padding_coverage_config.json`) +- **Purpose**: Test padding behavior with all padding flags enabled +- **Config**: Fixed 64x64x32, warp 2x2x1, warp_tile 32x32x16 +- **Padding**: All enabled (pad_m=true, pad_n=true, pad_k=true) +- **Problem sizes**: Vector-aligned but not tile-aligned (104×104×56, 200×152×80, 152×200×64) +- **Coverage**: 1 kernel configuration testing padding with irregular sizes + +### 4. **Coverage Testing** (Quick or Comprehensive) +- **Purpose**: Comprehensive testing across tile sizes, warp configurations, and trait combinations +- **Quick** (`quick_coverage_config.json`): Approximately 144 kernels + - tile_m/n: [32, 64, 256], tile_k: [16, 32] + - warp config: 2×2×1, warp_tile 16×16×16 + - Traits: 3 pipelines × 2 epilogues × 2 schedulers (persistent=false only) + - Focused set testing trait combinations with multiple tile sizes +- **Comprehensive** (`comprehensive_coverage_config.json`): Several thousand kernels + - tile_m/n: [16-256 step 16] + - tile_k: [16, 32, 64] + - warp_m/n: [1, 2, 4], warp_tile_m/n: [16, 32], warp_tile_k: [16, 32] + - Traits: 3 pipelines × 2 epilogues × 2 schedulers × 2 persistent + - Extensive coverage across all tile sizes, warp configurations, and trait combinations + - Exact count varies based on validation filtering +- **Note**: Use CMake option `-DCOVERAGE_LEVEL=comprehensive` to enable comprehensive testing (default is quick) + +## Data Type Support +- ✅ **fp8, fp16, bf16**: Fully supported - all layouts (rcr, rrr, ccr, crr) +- ❌ **fp64**: Not supported (hardware MFMA limitation) +- ⏳ **fp32, bf8, pk-int4-t**: Not yet supported by gemm_instance_builder (will be added later) + +## Test Result Behavior + +Tests automatically handle unsupported configurations through runtime validation: +- **PASSED**: Kernel executed correctly with results within error thresholds ✅ +- **SKIPPED**: Kernel validation returned "Arguments not supported" (expected for certain problem sizes/configurations) ⚠️ +- **FAILED**: Actual error or incorrect computation results ❌ + +When a kernel's `IsSupportedArgument()` check fails (e.g., due to vector alignment requirements, dimension constraints, or padding limitations), the test is automatically skipped rather than failed. This allows comprehensive testing across various problem sizes while gracefully handling configurations that don't meet specific kernel requirements. diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/comprehensive_coverage_config.json b/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/comprehensive_coverage_config.json new file mode 100644 index 000000000000..f2524e4a619d --- /dev/null +++ b/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/comprehensive_coverage_config.json @@ -0,0 +1,37 @@ +{ + "problem": { + "description": "Comprehensive coverage testing - extensive tile size coverage (16-256, step 16) with multiple warp configurations and all trait combinations. Several thousand kernels." + }, + "test_params": { + "problem_sizes": [ + {"m": 512, "n": 512, "k": 256, "split_k": 1}, + {"m": 1024, "n": 512, "k": 512, "split_k": 1}, + {"m": 512, "n": 1024, "k": 512, "split_k": 1}, + {"m": 1024, "n": 1024, "k": 256, "split_k": 1}, + {"m": 1024, "n": 1024, "k": 256, "split_k": 2}, + {"m": 1024, "n": 1024, "k": 256, "split_k": 4} + ] + }, + "tile_config": { + "tile_m": {"values": [16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256]}, + "tile_n": {"values": [16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256]}, + "tile_k": {"values": [16, 32, 64]}, + "warp_m": {"values": [1, 2, 4]}, + "warp_n": {"values": [1, 2, 4]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [16, 32]}, + "warp_tile_n": {"values": [16, 32]}, + "warp_tile_k": {"values": [8, 16, 32, 64, 128]} + }, + "trait_config": { + "pipeline": {"values": ["mem", "compv3", "compv4"]}, + "epilogue": {"values": ["default", "cshuffle"]}, + "scheduler": {"values": ["intrawave", "interwave"]}, + "pad_m": {"values": [false]}, + "pad_n": {"values": [false]}, + "pad_k": {"values": [false]}, + "persistent": {"values": [true, false]} + }, + "k_block_per_cu": 1, + "permute_n": false +} diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/large_datatype_config.json b/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/large_datatype_config.json new file mode 100644 index 000000000000..e9fcb6fb8007 --- /dev/null +++ b/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/large_datatype_config.json @@ -0,0 +1,34 @@ +{ + "problem": { + "description": "Configuration optimized for large data types (fp32) with smaller warp tiles due to memory constraints" + }, + "test_params": { + "problem_sizes": [ + {"m": 512, "n": 512, "k": 128, "split_k": 1}, + {"m": 512, "n": 256, "k": 192, "split_k": 1}, + {"m": 256, "n": 384, "k": 192, "split_k": 1} + ] + }, + "tile_config": { + "tile_m": {"values": [256]}, + "tile_n": {"values": [128]}, + "tile_k": {"values": [32]}, + "warp_m": {"values": [2]}, + "warp_n": {"values": [2]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [16]}, + "warp_tile_n": {"values": [16]}, + "warp_tile_k": {"values": [16]} + }, + "trait_config": { + "pipeline": {"values": ["compv3"]}, + "epilogue": {"values": ["default"]}, + "scheduler": {"values": ["intrawave"]}, + "pad_m": {"values": [false]}, + "pad_n": {"values": [false]}, + "pad_k": {"values": [false]}, + "persistent": {"values": [false]} + }, + "k_block_per_cu": 1, + "permute_n": false +} diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/padding_coverage_config.json b/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/padding_coverage_config.json new file mode 100644 index 000000000000..33bada839de5 --- /dev/null +++ b/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/padding_coverage_config.json @@ -0,0 +1,34 @@ +{ + "problem": { + "description": "Padding coverage testing - fixed config with fp16/rcr, varying only padding combinations" + }, + "test_params": { + "problem_sizes": [ + {"m": 104, "n": 104, "k": 56, "split_k": 1}, + {"m": 200, "n": 152, "k": 80, "split_k": 1}, + {"m": 152, "n": 200, "k": 64, "split_k": 1} + ] + }, + "tile_config": { + "tile_m": {"values": [64]}, + "tile_n": {"values": [64]}, + "tile_k": {"values": [32]}, + "warp_m": {"values": [2]}, + "warp_n": {"values": [2]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [32]}, + "warp_tile_n": {"values": [32]}, + "warp_tile_k": {"values": [16]} + }, + "trait_config": { + "pipeline": {"values": ["compv3"]}, + "epilogue": {"values": ["default"]}, + "scheduler": {"values": ["intrawave"]}, + "pad_m": {"values": [true]}, + "pad_n": {"values": [true]}, + "pad_k": {"values": [true]}, + "persistent": {"values": [false]} + }, + "k_block_per_cu": 1, + "permute_n": false +} diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/quick_coverage_config.json b/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/quick_coverage_config.json new file mode 100644 index 000000000000..dcc6e99aee5a --- /dev/null +++ b/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/quick_coverage_config.json @@ -0,0 +1,34 @@ +{ + "problem": { + "description": "Quick coverage testing - tests multiple tile sizes with all trait combinations (pipelines, epilogues, schedulers). Approximately 144 kernels." + }, + "test_params": { + "problem_sizes": [ + {"m": 512, "n": 1024, "k": 512, "split_k": 1}, + {"m": 1024, "n": 1024, "k": 256, "split_k": 2}, + {"m": 1024, "n": 1024, "k": 256, "split_k": 4} + ] + }, + "tile_config": { + "tile_m": {"values": [32, 64, 256]}, + "tile_n": {"values": [32, 64, 256]}, + "tile_k": {"values": [16, 32]}, + "warp_m": {"values": [2]}, + "warp_n": {"values": [2]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [16]}, + "warp_tile_n": {"values": [16]}, + "warp_tile_k": {"values": [16]} + }, + "trait_config": { + "pipeline": {"values": ["mem", "compv3", "compv4"]}, + "epilogue": {"values": ["default", "cshuffle"]}, + "scheduler": {"values": ["intrawave", "interwave"]}, + "pad_m": {"values": [false]}, + "pad_n": {"values": [false]}, + "pad_k": {"values": [false]}, + "persistent": {"values": [false]} + }, + "k_block_per_cu": 1, + "permute_n": false +} diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json b/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json new file mode 100644 index 000000000000..498ef9fa33a1 --- /dev/null +++ b/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json @@ -0,0 +1,34 @@ +{ + "problem": { + "description": "Basic functionality validation with moderate problem sizes" + }, + "test_params": { + "problem_sizes": [ + {"m": 256, "n": 256, "k": 128, "split_k": 1}, + {"m": 512, "n": 256, "k": 256, "split_k": 1}, + {"m": 256, "n": 512, "k": 256, "split_k": 1} + ] + }, + "tile_config": { + "tile_m": {"values": [128]}, + "tile_n": {"values": [128]}, + "tile_k": {"values": [64]}, + "warp_m": {"values": [2]}, + "warp_n": {"values": [2]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [16]}, + "warp_tile_n": {"values": [16]}, + "warp_tile_k": {"values": [16]} + }, + "trait_config": { + "pipeline": {"values": ["compv3", "compv4"]}, + "epilogue": {"values": ["default"]}, + "scheduler": {"values": ["intrawave"]}, + "pad_m": {"values": [false]}, + "pad_n": {"values": [false]}, + "pad_k": {"values": [false]}, + "persistent": {"values": [false]} + }, + "k_block_per_cu": 1, + "permute_n": false +} diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/small_datatype_config.json b/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/small_datatype_config.json new file mode 100644 index 000000000000..d0d9f99a0cc7 --- /dev/null +++ b/projects/composablekernel/test/ck_tile/gemm_tile_engine/configs/small_datatype_config.json @@ -0,0 +1,35 @@ +{ + "problem": { + "description": "Configuration optimized for small data types (fp8, fp16, bf16) with larger warp tiles" + }, + "test_params": { + "problem_sizes": [ + {"m": 512, "n": 512, "k": 256, "split_k": 1}, + {"m": 1024, "n": 512, "k": 512, "split_k": 1}, + {"m": 512, "n": 1024, "k": 512, "split_k": 1}, + {"m": 1024, "n": 1024, "k": 256, "split_k": 1} + ] + }, + "tile_config": { + "tile_m": {"values": [128]}, + "tile_n": {"values": [128]}, + "tile_k": {"values": [32]}, + "warp_m": {"values": [2]}, + "warp_n": {"values": [2]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [32]}, + "warp_tile_n": {"values": [32]}, + "warp_tile_k": {"values": [16]} + }, + "trait_config": { + "pipeline": {"values": ["compv3"]}, + "epilogue": {"values": ["default"]}, + "scheduler": {"values": ["intrawave"]}, + "pad_m": {"values": [false]}, + "pad_n": {"values": [false]}, + "pad_k": {"values": [false]}, + "persistent": {"values": [false]} + }, + "k_block_per_cu": 1, + "permute_n": false +} diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/extract_test_params.py b/projects/composablekernel/test/ck_tile/gemm_tile_engine/extract_test_params.py new file mode 100644 index 000000000000..48ec8dba8352 --- /dev/null +++ b/projects/composablekernel/test/ck_tile/gemm_tile_engine/extract_test_params.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + + +import json +import argparse +import os +from pathlib import Path + + +def extract_test_params(config_file, output_file): + """Extract test parameters from config JSON and write to output file""" + + # Read config file + with open(config_file, "r") as f: + config = json.load(f) + + # Extract test parameters + test_params = [] + if "test_params" in config and "problem_sizes" in config["test_params"]: + test_params = config["test_params"]["problem_sizes"] + else: + # Default test parameters if none specified + test_params = [ + {"m": 256, "n": 256, "k": 128, "split_k": 1}, + {"m": 256, "n": 256, "k": 1024, "split_k": 1}, + {"m": 256, "n": 512, "k": 512, "split_k": 1}, + {"m": 512, "n": 256, "k": 512, "split_k": 1}, + ] + + # Write to output file in C++ format + output_dir = Path(output_file).parent + output_dir.mkdir(parents=True, exist_ok=True) + + with open(output_file, "w") as f: + f.write("// Generated test parameters for this configuration\n") + f.write("// This file is auto-generated during CMake configuration\n\n") + f.write("static const std::vector CONFIG_TEST_PARAMS = {\n") + + for i, params in enumerate(test_params): + comma = "," if i < len(test_params) - 1 else "" + f.write( + f" {{{params['m']}, {params['n']}, {params['k']}, {params['split_k']}}}{comma}\n" + ) + + f.write("};\n") + + print( + f"Extracted {len(test_params)} test parameters from {config_file} -> {output_file}" + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Extract test parameters from config JSON" + ) + parser.add_argument("--config_file", required=True, help="Input config JSON file") + parser.add_argument( + "--output_file", required=True, help="Output test parameters file" + ) + + args = parser.parse_args() + + if not os.path.exists(args.config_file): + print(f"Error: Config file not found: {args.config_file}") + return 1 + + extract_test_params(args.config_file, args.output_file) + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/projects/composablekernel/test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp b/projects/composablekernel/test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp new file mode 100644 index 000000000000..e44e8c4182ac --- /dev/null +++ b/projects/composablekernel/test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp @@ -0,0 +1,241 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file test_gemm_simple.cpp + * @brief Unit tests for GEMM kernels generated by gemm_instance_builder + * + * This test includes kernels generated during CMake configuration by + * gemm_instance_builder.py and tests them with problem sizes extracted + * from the corresponding JSON configuration files. + */ + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "tile_engine/ops/gemm/gemm_common.hpp" + +// The kernel header is included via compile command line with -include flag +// It defines SelectedKernel struct, KERNEL_NAME, and tensor data types + +// Adaptive error threshold calculation matching tile_engine's implementation +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +/// @brief Function to compare the results of the device and host computations (from tile_engine) +template +bool compare_results(std::string instanceName, + ck_tile::index_t K, + ck_tile::index_t kbatch, + ck_tile::HostTensor& c_m_n_dev_result, + ck_tile::HostTensor& c_m_n_host_result) +{ + const float max_accumulated_value = + *std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_result, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "For " << instanceName << " Relative error threshold is " + << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is " + << rtol_atol.at(ck_tile::number<1>{}) << std::endl; + std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl; + + return pass; +} + +// Test parameter structure for matrix dimensions and split_k values +struct GemmTestParams +{ + int m, n, k, split_k; +}; + +// Include config-specific test parameters (after GemmTestParams struct is defined) +#ifdef GEMM_TEST_PARAMS_HPP +#include GEMM_TEST_PARAMS_HPP +#endif + +class GemmTileEngineTest : public ::testing::TestWithParam +{ + protected: + void SetUp() override + { + auto params = GetParam(); + m_ = params.m; + n_ = params.n; + k_ = params.k; + split_k_ = params.split_k; + + // Calculate strides (following tile_engine pattern) + if constexpr(std::is_same_v) + { + stride_a_ = k_; + } + else + { + stride_a_ = m_; + } + + if constexpr(std::is_same_v) + { + stride_b_ = n_; + } + else + { + stride_b_ = k_; + } + + if constexpr(std::is_same_v) + { + stride_c_ = n_; + } + else + { + stride_c_ = m_; + } + } + + // Test dimensions + int m_, n_, k_, split_k_; + int stride_a_, stride_b_, stride_c_; +}; + +TEST_P(GemmTileEngineTest, BasicFunctionality) +{ + // Get tensor layouts from generated kernel + const ALayout layout_a = ALayout{}; + const BLayout layout_b = BLayout{}; + const CLayout layout_c = CLayout{}; + + // Use split_k from test parameters + int split_k = split_k_; + int stride_a_calc = ck_tile::get_default_stride(m_, k_, 0, is_row_major(layout_a)); + int stride_b_calc = ck_tile::get_default_stride(k_, n_, 0, is_row_major(layout_b)); + int stride_c_calc = ck_tile::get_default_stride(m_, n_, 0, is_row_major(layout_c)); + + // Create host tensors with proper descriptors + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(m_, k_, stride_a_calc, is_row_major(layout_a))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(k_, n_, stride_b_calc, is_row_major(layout_b))); + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c))); + ck_tile::HostTensor c_m_n_host_result( + ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c))); + + // Initialize input tensors with uniform random distribution [-1.0, 1.0] (matches tile_engine) + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); + + // Allocate GPU device memory + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + // Copy data to device and zero output buffer + a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + // Calculate reference result on host for verification + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_result); + + // Create GEMM kernel arguments + ck_tile::GemmHostArgs gemm_args(a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + split_k, + m_, + n_, + k_, + stride_a_calc, + stride_b_calc, + stride_c_calc); + + // Configure kernel execution for maximum speed (no timing, no debug output) + ck_tile::stream_config stream_config{nullptr, // stream + false, // time_kernel (disable timing for speed) + 0, // log_level (disable debug output) + 0, // n_warmup + 1, // n_repeat + false, // is_gpu_timer (unused when time_kernel=false) + false, // flush_cache + 1}; // rotating_count + + // Launch the generated kernel (no timing overhead for fastest execution) + try + { + SelectedKernel::launch(gemm_args, stream_config); + // Kernel launched successfully if no exception thrown + } + catch(const std::exception& e) + { + std::string error_msg(e.what()); + // If arguments not supported, skip the test (configuration validation failure, not a bug) + if(error_msg.find("Arguments not supported") != std::string::npos) + { + GTEST_SKIP() << "Configuration not supported: " << e.what(); + } + else + { + FAIL() << "Kernel launch failed: " << e.what(); + } + } + + // Copy result back from device + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + + // Verify results using tile_engine's adaptive error thresholds + bool verification_passed = compare_results( + KERNEL_NAME, k_, split_k, c_m_n_dev_result, c_m_n_host_result); + + EXPECT_TRUE(verification_passed) << "GEMM result verification failed"; +} + +TEST_P(GemmTileEngineTest, KernelInfo) +{ + // Simple test to verify kernel information is available + EXPECT_TRUE(strlen(KERNEL_NAME) > 0) << "Kernel name should not be empty"; + + std::cout << "Testing kernel: " << KERNEL_NAME << std::endl; + std::cout << "Problem size: " << m_ << "x" << n_ << "x" << k_ << " with split_k=" << split_k_ + << std::endl; +} + +// Use config-specific test parameters (included via compile flags) +// CONFIG_TEST_PARAMS is defined in the auto-generated test_params.hpp file +INSTANTIATE_TEST_SUITE_P(GemmVerification, + GemmTileEngineTest, + ::testing::ValuesIn(CONFIG_TEST_PARAMS), + [](const ::testing::TestParamInfo& param_info) { + return std::to_string(param_info.param.m) + "x" + + std::to_string(param_info.param.n) + "x" + + std::to_string(param_info.param.k) + "_splitk" + + std::to_string(param_info.param.split_k); + }); diff --git a/projects/composablekernel/tile_engine/ops/gemm/CMakeLists.txt b/projects/composablekernel/tile_engine/ops/gemm/CMakeLists.txt index c7f6e48930ad..b50a6790105a 100644 --- a/projects/composablekernel/tile_engine/ops/gemm/CMakeLists.txt +++ b/projects/composablekernel/tile_engine/ops/gemm/CMakeLists.txt @@ -15,7 +15,7 @@ if(NOT "${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "") if(_te_budget GREATER 0) # Detect active ops from their DATATYPE variables set(_active_ops "") - foreach(_op gemm_multi_d gemm_preshuffle grouped_gemm gemm_streamk batched_contraction batched_gemm gemm_multi_abd mx_gemm gemm_rowcolquant gemm_tensor_quant grouped_gemm_rowcolquant grouped_gemm_tensorquant) + foreach(_op gemm_universal gemm_multi_d gemm_preshuffle grouped_gemm gemm_streamk batched_contraction batched_gemm gemm_multi_abd mx_gemm gemm_rowcolquant gemm_tensor_quant grouped_gemm_rowcolquant grouped_gemm_tensorquant) string(TOUPPER ${_op} _OP_UPPER) if(NOT "${${_OP_UPPER}_DATATYPE}" STREQUAL "") list(APPEND _active_ops ${_op}) @@ -45,7 +45,7 @@ if(NOT "${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "") message(STATUS "Sampling budget allocation:\n${_alloc_output}") # Read per-op allocations (only if not already overridden) - foreach(_op gemm_multi_d gemm_preshuffle grouped_gemm gemm_streamk batched_contraction batched_gemm gemm_multi_abd mx_gemm gemm_rowcolquant gemm_tensor_quant grouped_gemm_rowcolquant grouped_gemm_tensorquant) + foreach(_op gemm_universal gemm_multi_d gemm_preshuffle grouped_gemm gemm_streamk batched_contraction batched_gemm gemm_multi_abd mx_gemm gemm_rowcolquant gemm_tensor_quant grouped_gemm_rowcolquant grouped_gemm_tensorquant) string(TOUPPER ${_op} _OP_UPPER) if("${${_OP_UPPER}_MAX_INSTANCES}" STREQUAL "") if(EXISTS "${_alloc_dir}/${_op}_budget.txt") @@ -73,6 +73,7 @@ if(NOT "${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "") endif() endif() +add_subdirectory(gemm_universal EXCLUDE_FROM_ALL) add_subdirectory(gemm_multi_d EXCLUDE_FROM_ALL) add_subdirectory(gemm_preshuffle EXCLUDE_FROM_ALL) add_subdirectory(grouped_gemm EXCLUDE_FROM_ALL) diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/CMakeLists.txt b/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/CMakeLists.txt new file mode 100644 index 000000000000..e0624b7067b2 --- /dev/null +++ b/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/CMakeLists.txt @@ -0,0 +1,338 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +set(GEMM_UNIVERSAL_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM Universal (semicolon-separated)") +set(GEMM_UNIVERSAL_LAYOUT "rcr;rrr;crr;ccr" CACHE STRING "List of layout for GEMM Universal (semicolon-separated)") +set(GEMM_UNIVERSAL_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)") +set(GEMM_UNIVERSAL_MAX_INSTANCES "" CACHE STRING "Max kernel instances per (dtype, layout) combo (empty = no cap)") +option(ENABLE_CCACHE_GEMM_UNIVERSAL "Enable ccache for GEMM Universal ops compilation" OFF) + +# Store the directory path for use in functions +set(GEMM_UNIVERSAL_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}) + +# Function to create individual GEMM Universal targets +function(create_individual_gemm_universal_target datatype layout trait tile_config config_json) + # Use the parent scope GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL variable + if(NOT GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL) + message(WARNING "Skipping individual GEMM Universal target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets") + return() + endif() + + # Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k + # First split by underscore to get three groups + string(REPLACE "_" ";" config_groups ${tile_config}) + list(GET config_groups 0 tile_dims) # e.g., 256x256x32 + list(GET config_groups 1 warp_dims) # e.g., 4x1x1 + list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16 + + # Parse tile dimensions + string(REPLACE "x" ";" tile_parts ${tile_dims}) + list(GET tile_parts 0 tile_m) + list(GET tile_parts 1 tile_n) + list(GET tile_parts 2 tile_k) + + # Parse warp dimensions + string(REPLACE "x" ";" warp_parts ${warp_dims}) + list(GET warp_parts 0 warp_m) + list(GET warp_parts 1 warp_n) + list(GET warp_parts 2 warp_k) + + # Parse warp tile dimensions + string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims}) + list(GET warp_tile_parts 0 warp_tile_m) + list(GET warp_tile_parts 1 warp_tile_n) + list(GET warp_tile_parts 2 warp_tile_k) + + set(target_name "benchmark_gemm_universal_${datatype}_${layout}_${trait}_${tile_config}") + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") + + # Generate the single instance header for this kernel + set(instance_header "${working_path}/gemm_universal_single_${datatype}_${layout}_${trait}_${tile_config}.hpp") + + # Add custom command to generate the header file at build time + add_custom_command( + OUTPUT ${instance_header} + COMMAND ${Python3_EXECUTABLE} ${GEMM_UNIVERSAL_SOURCE_DIR}/gemm_universal_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${config_json} + --gen_single + --kernel_name "gemm_universal_${datatype}_${layout}_${trait}_${tile_config}" + --tile_config "${tile_config}" + --trait_combo "${trait}" + --gpu_target "${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}" + DEPENDS ${GEMM_UNIVERSAL_SOURCE_DIR}/gemm_universal_instance_builder.py ${config_json} + COMMENT "Generating ${instance_header}" + ) + + # Create the executable + add_executable(${target_name} + EXCLUDE_FROM_ALL + ${GEMM_UNIVERSAL_SOURCE_DIR}/gemm_universal_benchmark_single.cpp + ${instance_header} + ) + + # Set GPU architectures + set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}) + + # Set compile definitions + target_compile_definitions(${target_name} PRIVATE + GEMM_UNIVERSAL_SINGLE_INSTANCE_HPP="${instance_header}" + ) + + # Include directories + target_include_directories(${target_name} PRIVATE + ${GEMM_UNIVERSAL_SOURCE_DIR} + ${working_path} + ) + + # Compile options + target_compile_options(${target_name} PRIVATE + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + -include ${instance_header} + ) + + # Add to collection targets + add_dependencies(benchmark_gemm_universal_all ${target_name}) + add_dependencies(benchmark_gemm_universal_${datatype} ${target_name}) + add_dependencies(benchmark_gemm_universal_${layout} ${target_name}) + add_dependencies(benchmark_gemm_universal_${datatype}_${layout} ${target_name}) + + # Add to trait-specific targets + string(REPLACE "_" ";" trait_parts ${trait}) + list(GET trait_parts 0 pipeline) + list(GET trait_parts 1 epilogue) + list(GET trait_parts 2 scheduler) + + add_dependencies(benchmark_gemm_universal_${pipeline}_pipeline ${target_name}) + add_dependencies(benchmark_gemm_universal_${epilogue}_epilogue ${target_name}) + add_dependencies(benchmark_gemm_universal_${scheduler}_scheduler ${target_name}) +endfunction() + +# Function to build individual GEMM Universal targets +function(build_individual_gemm_universal_targets datatype layout) + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") + + # Choose config file + # Priority order: + # 1. Environment variable GEMM_UNIVERSAL_CONFIG_FILE + # 2. CMake variable GEMM_UNIVERSAL_CONFIG_FILE + # 3. Default based on layout + + # Check environment variable first + if(DEFINED ENV{GEMM_UNIVERSAL_CONFIG_FILE} AND NOT "$ENV{GEMM_UNIVERSAL_CONFIG_FILE}" STREQUAL "") + set(config_filename "$ENV{GEMM_UNIVERSAL_CONFIG_FILE}") + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}") + message(VERBOSE " Using config from environment variable: ${config_filename}") + elseif(NOT "${GEMM_UNIVERSAL_CONFIG_FILE}" STREQUAL "") + # Use CMake variable if set + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_UNIVERSAL_CONFIG_FILE}") + message(VERBOSE " Using custom config: ${GEMM_UNIVERSAL_CONFIG_FILE}") + else() + # Use default config for all layouts + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") + message(VERBOSE " Using default config for layout ${layout}") + endif() + + # Check if config file exists + if(NOT EXISTS ${json_blob}) + message(FATAL_ERROR "Config file not found: ${json_blob}") + endif() + + # Determine number of workers for parallel generation + if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL}) + set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL}) + else() + # Use processor count but limit to avoid memory issues + cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES) + math(EXPR num_workers "${num_cores}") + if(num_workers GREATER 8) + set(num_workers 8) + endif() + endif() + + # Generate individual kernel files using parallel version + message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...") + message(VERBOSE " Working path: ${working_path}") + message(VERBOSE " Config file: ${json_blob}") + message(VERBOSE " Python executable: ${Python3_EXECUTABLE}") + message(VERBOSE " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_universal_instance_builder.py") + + # Create working directory first + file(MAKE_DIRECTORY ${working_path}) + + message(VERBOSE "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_universal_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${json_blob} + --gpu_target ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL} + --list_kernels ") + + # Build optional args for instance builder + set(extra_list_args "") + if(NOT "${GEMM_UNIVERSAL_MAX_INSTANCES}" STREQUAL "") + list(APPEND extra_list_args --max-instances ${GEMM_UNIVERSAL_MAX_INSTANCES}) + endif() + if(NOT "${TILE_ENGINE_SAMPLING_TIER}" STREQUAL "") + list(APPEND extra_list_args --tier ${TILE_ENGINE_SAMPLING_TIER}) + list(APPEND extra_list_args --manifest-path ${working_path}) + endif() + if(NOT "${TILE_ENGINE_SAMPLING_SEED}" STREQUAL "") + list(APPEND extra_list_args --seed ${TILE_ENGINE_SAMPLING_SEED}) + endif() + + # First, just list the kernels (fast operation) + message(VERBOSE " Listing kernel configurations...") + execute_process( + COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_universal_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${json_blob} + --gpu_target ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL} + --list_kernels + ${extra_list_args} + WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR} + RESULT_VARIABLE ret + OUTPUT_VARIABLE list_output + ERROR_VARIABLE list_error + ) + + if(NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}") + endif() + + # Read kernel count + if(EXISTS ${working_path}/gemm_universal_kernel_count.txt) + file(READ ${working_path}/gemm_universal_kernel_count.txt kernel_count) + string(STRIP "${kernel_count}" kernel_count) + message(VERBOSE " Found ${kernel_count} kernel configurations") + else() + message(FATAL_ERROR "Kernel count file not found") + endif() + + # Read kernel list and create targets + if(EXISTS ${working_path}/gemm_universal_kernel_list.txt) + file(STRINGS ${working_path}/gemm_universal_kernel_list.txt kernel_lines) + foreach(line IN LISTS kernel_lines) + # Parse line: kernel_name|tile_config|trait_combo + string(REPLACE "|" ";" parts "${line}") + list(GET parts 0 kernel_name) + list(GET parts 1 tile_config) + list(GET parts 2 trait_combo) + + # Create individual target + create_individual_gemm_universal_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}") + endforeach() + else() + message(FATAL_ERROR "Kernel list file not found") + endif() +endfunction() + +# Main build logic - Only individual builds supported +message(VERBOSE "=== Starting Tile Engine GEMM Universal Configuration ===") +message(VERBOSE "GEMM_UNIVERSAL_DATATYPE: ${GEMM_UNIVERSAL_DATATYPE}") +message(VERBOSE "GEMM_UNIVERSAL_LAYOUT: ${GEMM_UNIVERSAL_LAYOUT}") +message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + +# Filter GPU targets to only gfx90a, gfx942, gfx950, gfx1201 +set(GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL "") +set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201;gfx12-generic") + +foreach(target IN LISTS SUPPORTED_GPU_TARGETS) + if(target IN_LIST DESIRED_TARGETS) + list(APPEND GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL ${target}) + message(VERBOSE " Adding GPU target: ${target}") + endif() +endforeach() + +# Skip build if no matching targets found +if(NOT GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL) + message(WARNING "Skipping Tile Engine GEMM Universal build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") +else() + message(VERBOSE "Building individual GEMM Universal targets for GPU targets: ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}") + + # Enable parallel compilation optimizations + # Set up job pools for better parallel compilation control + set_property(GLOBAL PROPERTY JOB_POOLS + compile_heavy=4 # Limit heavy compilations to prevent OOM + compile_normal=16 # Allow more parallel normal compilations + ) + + # Enable compiler cache if available and explicitly requested + # Disabled by default due to permission issues in CI environments + if(ENABLE_CCACHE_GEMM_UNIVERSAL) + find_program(CCACHE_PROGRAM ccache) + if(CCACHE_PROGRAM) + set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM}) + message(VERBOSE "Using ccache for faster compilation") + else() + message(WARNING "ccache requested but not found") + endif() + else() + message(VERBOSE "ccache disabled for GEMM Universal ops (use -DENABLE_CCACHE_GEMM_UNIVERSAL=ON to enable)") + endif() + + # Create master collection targets + add_custom_target(benchmark_gemm_universal_all) + + # Create datatype collection targets + foreach(dt IN LISTS GEMM_UNIVERSAL_DATATYPE) + add_custom_target(benchmark_gemm_universal_${dt}) + endforeach() + + # Create layout collection targets + foreach(l IN LISTS GEMM_UNIVERSAL_LAYOUT) + add_custom_target(benchmark_gemm_universal_${l}) + endforeach() + + # Create combined collection targets + foreach(dt IN LISTS GEMM_UNIVERSAL_DATATYPE) + foreach(l IN LISTS GEMM_UNIVERSAL_LAYOUT) + add_custom_target(benchmark_gemm_universal_${dt}_${l}) + endforeach() + endforeach() + + # Create trait-based collection targets + # These are common trait components used across all GEMM Universal kernels + set(GEMM_UNIVERSAL_PIPELINES "mem;compv3;compv4") + set(GEMM_UNIVERSAL_EPILOGUES "default;cshuffle") + set(GEMM_UNIVERSAL_SCHEDULERS "intrawave;interwave") + + foreach(pipeline IN LISTS GEMM_UNIVERSAL_PIPELINES) + add_custom_target(benchmark_gemm_universal_${pipeline}_pipeline) + endforeach() + + foreach(epilogue IN LISTS GEMM_UNIVERSAL_EPILOGUES) + add_custom_target(benchmark_gemm_universal_${epilogue}_epilogue) + endforeach() + + foreach(scheduler IN LISTS GEMM_UNIVERSAL_SCHEDULERS) + add_custom_target(benchmark_gemm_universal_${scheduler}_scheduler) + endforeach() + + # Divide MAX_INSTANCES budget across all active (dtype, layout) combos so that + # sampling fires per-combo rather than being a single cap larger than any combo's + # feasible set (which would make sampling a no-op for most combos). + if(NOT "${GEMM_UNIVERSAL_MAX_INSTANCES}" STREQUAL "") + list(LENGTH GEMM_UNIVERSAL_DATATYPE _gu_n_dt) + list(LENGTH GEMM_UNIVERSAL_LAYOUT _gu_n_lay) + math(EXPR _gu_n_combos "${_gu_n_dt} * ${_gu_n_lay}") + if(_gu_n_combos GREATER 0) + math(EXPR GEMM_UNIVERSAL_MAX_INSTANCES + "${GEMM_UNIVERSAL_MAX_INSTANCES} / ${_gu_n_combos}") + message(STATUS " gemm_universal: per-combo budget = ${GEMM_UNIVERSAL_MAX_INSTANCES} (${_gu_n_combos} combos)") + endif() + endif() + + # Build individual targets for each datatype/layout combination + foreach(dt IN LISTS GEMM_UNIVERSAL_DATATYPE) + foreach(l IN LISTS GEMM_UNIVERSAL_LAYOUT) + build_individual_gemm_universal_targets(${dt} ${l}) + endforeach() + endforeach() +endif() diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.hpp b/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.hpp new file mode 100644 index 000000000000..23338a6cd008 --- /dev/null +++ b/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.hpp @@ -0,0 +1,73 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "gemm/gemm_benchmark.hpp" + +#if __clang_major__ >= 23 +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" +#endif +// Data types and Layouts are defined by the generated kernel headers +// No hardcoded type definitions here to avoid conflicts + +/// @brief Function to get the kernel output with reference implementation on CPU/GPU +void gemm_host_reference(int verify, + ck_tile::HostTensor& a_m_k, + ck_tile::HostTensor& b_k_n, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C) +{ + if(verify == 1) + { + c_m_n_host_result.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_result); + } + else if(verify == 2) + { + if constexpr(std::is_same_v) + { + // Restore input for B for gpu reference + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_host_result.get_element_space_size_in_bytes()); + c_m_n_host_result.SetZero(); + c_m_n_gpu_buf_ref.SetZero(); + + ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); + BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); + CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); + + ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data()); + } +} +#if __clang_major__ >= 23 +#pragma clang diagnostic pop +#endif diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py b/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py new file mode 100755 index 000000000000..73ba1261a849 --- /dev/null +++ b/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +import os +import sys +import argparse +import time +import importlib.util + + +def _import_gemm_benchmark(): + """Import gemm benchmark from parent directory.""" + current_dir = os.path.dirname(os.path.abspath(__file__)) + parent_dir = os.path.dirname(current_dir) + + # Load the module dynamically + spec = importlib.util.spec_from_file_location( + "gemm_benchmark", + os.path.join(parent_dir, "gemm_benchmark.py"), + ) + gemm_benchmark_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(gemm_benchmark_module) + + return gemm_benchmark_module.GemmBenchmark + + +def _import_benchmark_utils(): + """Import benchmark utilities from commons directory.""" + current_dir = os.path.dirname(os.path.abspath(__file__)) + parent_dir = os.path.dirname(os.path.dirname(current_dir)) + + # Load the module dynamically + spec = importlib.util.spec_from_file_location( + "benchmark_utils", + os.path.join(parent_dir, "common", "benchmark_utils.py"), + ) + benchmark_utils = importlib.util.module_from_spec(spec) + spec.loader.exec_module(benchmark_utils) + + return benchmark_utils + + +GemmBenchmark = _import_gemm_benchmark() +benchmark_utils = _import_benchmark_utils() + + +class GemmUniversalBenchmark(GemmBenchmark): + def __init__(self, build_dir: str, verbose: bool = False): + super().__init__(build_dir, verbose, name="benchmark_gemm_universal_") + + +def main(): + parser = argparse.ArgumentParser( + description="Universal GEMM Kernel Benchmarking Tool" + ) + parser.add_argument( + "build_dir", help="Build directory containing kernel executables" + ) + parser.add_argument( + "--problem-sizes", + nargs="+", + default=["1024,1024,1024", "2048,2048,2048", "4096,4096,4096"], + help="Problem sizes as M,N,K tuples", + ) + parser.add_argument( + "--split-k", nargs="+", type=int, default=[1], help="Split-K values to test" + ) + parser.add_argument("--verify", action="store_true", help="Enable verification") + parser.add_argument( + "--csv", + default="gemm_universal_benchmark_results.csv", + help="CSV output filename", + ) + parser.add_argument( + "--best", default="best_kernels.txt", help="Best kernels output filename" + ) + parser.add_argument("--verbose", action="store_true", help="Verbose output") + parser.add_argument( + "--warmup", + type=int, + default=50, + help="Number of warmup iterations (default: 50)", + ) + parser.add_argument( + "--repeat", + type=int, + default=100, + help="Number of benchmark iterations (default: 100)", + ) + parser.add_argument( + "--flush-cache", + action="store_true", + default=True, + help="Enable cache flushing (default: True)", + ) + parser.add_argument( + "--rotating-count", + type=int, + default=1000, + help="Number of iterations to rotate cache (default: 1000)", + ) + parser.add_argument("--json", help="JSON output filename (optional)") + + args = parser.parse_args() + + # Parse problem sizes + problem_sizes = [] + for size_str in args.problem_sizes: + try: + m, n, k = map(int, size_str.split(",")) + problem_sizes.append((m, n, k)) + except ValueError: + print(f"Invalid problem size: {size_str}") + return 1 + + # Create benchmark instance + benchmark = GemmUniversalBenchmark(args.build_dir, verbose=args.verbose) + + # Run benchmark sweep + print("Starting Universal GEMM kernel benchmark sweep...") + start_time = time.time() + + best_kernels = benchmark.benchmark_sweep( + problem_sizes=problem_sizes, + split_k_values=args.split_k, + verify=args.verify, + warmup=args.warmup, + repeat=args.repeat, + flush_cache=args.flush_cache, + rotating_count=args.rotating_count, + ) + + elapsed_time = time.time() - start_time + print(f"\nBenchmark completed in {elapsed_time:.2f} seconds") + + # Export results + benchmark_utils.export_csv(benchmark.results, args.csv) + benchmark_utils.export_best_kernels(best_kernels, args.best) + + # Export JSON if requested + if args.json: + benchmark_utils.export_json(benchmark.results, args.json, best_kernels) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark_single.cpp b/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark_single.cpp new file mode 100644 index 000000000000..9e73077e2895 --- /dev/null +++ b/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark_single.cpp @@ -0,0 +1,102 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "gemm/gemm_common.hpp" +#include "gemm_universal_profiler.hpp" + +// The kernel header is included via the compile command line with -include flag +// It defines SelectedKernel struct and KERNEL_NAME + +void benchmark_single(const ck_tile::ArgParser& arg_parser) +{ + // Use DataTypeTraits to get the actual type names from the generated header + // The generated header defines ADataType, BDataType, AccDataType, CDataType + std::string dtype_a = ck_tile::DataTypeTraits::name; + std::string dtype_b = ck_tile::DataTypeTraits::name; + std::string dtype_acc = ck_tile::DataTypeTraits::name; + std::string dtype_c = ck_tile::DataTypeTraits::name; + + // Layout names from the layout types + std::string layout_a = ALayout::name; + std::string layout_b = BLayout::name; + std::string layout_c = CLayout::name; + + // Create GemmProblem struct + GemmProblem gemm_problem{arg_parser.get_int("split_k"), + arg_parser.get_int("m"), + arg_parser.get_int("n"), + arg_parser.get_int("k"), + arg_parser.get_int("stride_a"), + arg_parser.get_int("stride_b"), + arg_parser.get_int("stride_c"), + dtype_a, + dtype_b, + dtype_acc, + dtype_c, + layout_a, + layout_b, + layout_c, + arg_parser.get_bool("structured_sparsity")}; + + // Create Settings struct + Settings setting{arg_parser.get_int("warmup"), + arg_parser.get_int("repeat"), + arg_parser.get_bool("timer"), + arg_parser.get_int("verify"), + arg_parser.get_int("init"), + arg_parser.get_bool("log"), + arg_parser.get_str("csv_filename"), + arg_parser.get_bool("flush_cache"), + arg_parser.get_int("rotating_count"), + arg_parser.get_bool("json_output")}; + + // Get the profiler instance + auto& profiler = UniversalGemmProfiler::GemmProfiler::instance(setting); + + try + { + // Create a lambda that wraps the kernel launch + auto kernel_func = [](const ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& stream) { + return SelectedKernel::launch(args, stream); + }; + + // Benchmark the kernel + profiler.benchmark(gemm_problem, kernel_func); + + // Select best instance based on metric + profiler.select_best_instance(static_cast(arg_parser.get_int("metric"))); + } + catch(const std::exception& e) + { + std::cerr << "Benchmark failed: " << e.what() << std::endl; + } +} + +int main(int argc, char* argv[]) +{ + try + { + auto [result, parser] = create_args(argc, argv); + if(!result) + return EXIT_FAILURE; + + benchmark_single(parser); + return 0; + } + catch(const std::exception& e) + { + std::cerr << "Error: " << e.what() << "\n"; + return EXIT_FAILURE; + } +} diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_instance_builder.py b/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_instance_builder.py new file mode 100644 index 000000000000..0d13584ca065 --- /dev/null +++ b/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_instance_builder.py @@ -0,0 +1,344 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +import os +import argparse +import importlib.util +import multiprocessing +import concurrent.futures + + +def _import_gemm_kernel_builder(): + """Import validation utilities from commons directory.""" + current_dir = os.path.dirname(os.path.abspath(__file__)) + parent_dir = os.path.dirname(current_dir) + + # Load the module dynamically + spec = importlib.util.spec_from_file_location( + "gemm_instance_builder", + os.path.join(parent_dir, "gemm_instance_builder.py"), + ) + gemm_builder_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(gemm_builder_module) + + return gemm_builder_module.GemmKernelBuilder + + +GemmKernelBuilder = _import_gemm_kernel_builder() + + +class GemmUniversalKernelBuilder(GemmKernelBuilder): + def __init__( + self, + kernel_name_prefix, + working_path, + gpu_target, + datatype, + layout, + config_json=None, + max_instances=None, + seed=None, + tier=None, + manifest_path=None, + ): + super().__init__( + kernel_name_prefix, + working_path, + gpu_target, + datatype, + layout, + config_json, + max_instances=max_instances, + seed=seed, + tier=tier, + manifest_path=manifest_path, + ) + + def _generate_all_individual(self, num_workers=None): + """Generate individual kernel files for separate compilation with parallel processing""" + if num_workers is None: + num_workers = min( + multiprocessing.cpu_count(), 8 + ) # Limit to avoid memory issues + + tile_configs = self._get_tile_configs() + trait_combos = self._generate_trait_combinations() + + # Prepare work items for parallel processing + work_items = [] + for tile_config in tile_configs: + for trait_combo in trait_combos: + work_items.append( + ( + tile_config, + trait_combo, + self.kernel_name_prefix, + self.working_path, + self.gpu_target, + self.datatype, + self.layout, + self.config_json, + ) + ) + + # Apply RFC-compliant sampling (Sobol + LHS + maximin) + if self.max_instances is not None and len(work_items) > self.max_instances: + kernel_dicts = [ + {"tile_config": item[0], "trait_combo": item[1], "_work_item": item} + for item in work_items + ] + sampled = self._apply_sampling(kernel_dicts) + work_items = [k["_work_item"] for k in sampled] + + print( + f"Generating {len(work_items)} individual kernel files using {num_workers} workers..." + ) + print(f" Tile configs: {len(tile_configs)}") + print(f" Trait combinations: {len(trait_combos)}") + print(f" Total kernels: {len(work_items)}") + + # Show first few work items for debugging + if work_items: + print(" First work item example:") + tile_config, trait_combo = work_items[0][:2] + print(f" Tile config: {tile_config}") + print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits + + # Process work items in parallel + kernel_list = [] + completed = 0 + + with concurrent.futures.ProcessPoolExecutor( + max_workers=num_workers + ) as executor: + # Submit all work items + print(f" Submitting {len(work_items)} tasks to executor...") + future_to_item = { + executor.submit(_generate_single_kernel_individual, item): item + for item in work_items + } + print(" All tasks submitted, waiting for completion...") + + # Collect results with progress reporting + for future in concurrent.futures.as_completed(future_to_item): + completed += 1 + if completed % 100 == 0 or completed == len(work_items): + print( + f" Progress: {completed}/{len(work_items)} kernels generated" + ) + try: + result = future.result() + if result: + kernel_list.append(result) + except Exception as exc: + item = future_to_item[future] + print(f"Kernel generation failed for {item}: {exc}") + + # Sort kernel list for consistent ordering + kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name + + # Generate CMake include file for individual targets + self._generate_cmake_individual_targets(kernel_list) + + print( + f"Generated {len(kernel_list)} individual kernel files in {self.working_path}" + ) + + +def _generate_single_kernel_individual(work_item): + """Worker function to generate a single individual kernel file""" + ( + tile_config, + trait_combo, + kernel_name_prefix, + working_path, + gpu_target, + datatype, + layout, + config_json, + ) = work_item + + # Create a temporary builder instance for this worker + builder = GemmUniversalKernelBuilder( + kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json + ) + + try: + kernel_name, instance_code = builder._generate_kernel_instance( + tile_config, trait_combo + ) + + # Create simplified filename without the "gemm_universal_" prefix + # Remove "gemm_universal_" from the beginning of kernel_name for the filename + simplified_name = kernel_name + if simplified_name.startswith("gemm_universal_"): + simplified_name = simplified_name[ + len(kernel_name_prefix) + 1 : + ] # Remove "gemm_universal" prefix + + # Write individual header file + header_file = working_path / f"gemm_universal_single_{simplified_name}.hpp" + with open(header_file, "w") as f: + f.write(instance_code) + + return (kernel_name, trait_combo, tile_config) + except Exception as e: + print(f"Error generating individual kernel: {e}") + return None + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM Universal kernel instance builder with parallel support" + ) + parser.add_argument("--working_path", required=True, help="Working directory path") + parser.add_argument( + "--gpu_target", + required=True, + help="GPU target architecture", + ) + parser.add_argument( + "--datatype", + required=True, + choices=["fp16", "fp8", "bf16", "bf8"], + help="Data type", + ) + parser.add_argument( + "--layout", + required=True, + choices=["rcr", "rrr", "ccr", "crr"], + help="Matrix layout", + ) + parser.add_argument("--config_json", help="Configuration JSON file") + parser.add_argument( + "--num_workers", type=int, help="Number of parallel workers (default: auto)" + ) + parser.add_argument( + "--gen_all_individual", + action="store_true", + help="Generate individual kernel files", + ) + parser.add_argument( + "--gen_single", action="store_true", help="Generate a single kernel file" + ) + parser.add_argument("--kernel_name", help="Kernel name for single generation") + parser.add_argument( + "--tile_config", help="Tile configuration string for single generation" + ) + parser.add_argument( + "--trait_combo", help="Trait combination string for single generation" + ) + parser.add_argument( + "--list_kernels", + action="store_true", + help="List kernel configurations without generating files", + ) + parser.add_argument( + "--max-instances", + type=int, + default=None, + help="Cap on number of kernel instances per (dtype, layout) combo", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="RNG seed for deterministic sampling; if omitted, derived from today's date", + ) + parser.add_argument( + "--tier", + default=None, + help="Sampling tier (daily/weekly)", + ) + parser.add_argument( + "--manifest-path", + default=None, + help="Directory for chosen_instances.json", + ) + + args = parser.parse_args() + + assert args.datatype in ["fp16", "bf16", "fp8", "bf8"], ( + f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])" + ) + + layout_parts = args.layout.lower() + assert len(layout_parts) == 3, ( + f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)" + ) + assert layout_parts[0] in ["r", "c"] and layout_parts[1] in ["r", "c"], ( + f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)" + ) + assert layout_parts[2] == "r", ( + f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)" + ) + + kernel_name_prefix = "gemm_universal" + builder = GemmUniversalKernelBuilder( + kernel_name_prefix, + args.working_path, + args.gpu_target, + args.datatype, + args.layout, + args.config_json, + max_instances=args.max_instances, + seed=args.seed, + tier=args.tier, + manifest_path=args.manifest_path, + ) + + if args.list_kernels: + builder._list_kernels() + elif args.gen_single: + # Generate a single kernel file input validation + if not args.kernel_name or not args.tile_config or not args.trait_combo: + parser.error( + "--gen_single requires --kernel_name, --tile_config, and --trait_combo" + ) + + # Parse tile config + tile_parts = args.tile_config.split("_") + tile_dims = tile_parts[0].split("x") + warp_dims = tile_parts[1].split("x") + warp_tile_dims = tile_parts[2].split("x") + + tile_config = { + "tile_m": int(tile_dims[0]), + "tile_n": int(tile_dims[1]), + "tile_k": int(tile_dims[2]), + "warp_m": int(warp_dims[0]), + "warp_n": int(warp_dims[1]), + "warp_k": int(warp_dims[2]), + "warp_tile_m": int(warp_tile_dims[0]), + "warp_tile_n": int(warp_tile_dims[1]), + "warp_tile_k": int(warp_tile_dims[2]), + } + + # Parse trait combo + trait_parts = args.trait_combo.split("_") + trait_combo = ( + trait_parts[0], # pipeline + trait_parts[1], # epilogue + trait_parts[2], # scheduler + trait_parts[3] == "True", # pad_m + trait_parts[4] == "True", # pad_n + trait_parts[5] == "True", # pad_k + trait_parts[6] == "True", # persistent + ) + + # Generate the kernel + builder._generate_kernel_instance( + tile_config, + trait_combo, + ) + elif args.gen_all_individual: + # Generate all individual kernel files + builder._generate_all_individual(args.num_workers) + else: + parser.error( + "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single" + ) + + +if __name__ == "__main__": + main() diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_profiler.hpp b/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_profiler.hpp new file mode 100644 index 000000000000..6eb4266aae88 --- /dev/null +++ b/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/gemm_universal_profiler.hpp @@ -0,0 +1,147 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +#include "ck_tile/host/device_prop.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "gemm/gemm_benchmark.hpp" +#include "gemm/gemm_profiler.hpp" +#include "gemm_universal_benchmark.hpp" + +class UniversalGemmProfiler + : public GemmProfiler +{ + public: + using BaseGemm = GemmProfiler; + using BaseGemm::benchmark; + + UniversalGemmProfiler(Settings setting) + : GemmProfiler(setting) + { + } + + void benchmark(GemmProblem& gemm_problem, + std::vector( + ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables) override + { + const ALayout layout_a = ALayout{}; + const BLayout layout_b = BLayout{}; + const CLayout layout_c = CLayout{}; + + gemm_problem.stride_a_ = ck_tile::get_default_stride( + gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a)); + gemm_problem.stride_b_ = ck_tile::get_default_stride( + gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b)); + gemm_problem.stride_c_ = ck_tile::get_default_stride( + gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)); + + ck_tile::HostTensor a_m_k(ck_tile::host_tensor_descriptor( + gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a))); + ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( + gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b))); + ck_tile::HostTensor c_m_n_dev_result(ck_tile::host_tensor_descriptor( + gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c))); + + if(setting_.init_method == 0) + { + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); + } + else if(setting_.init_method == 1) + { + ck_tile::FillMonotonicSeq{}(a_m_k); + ck_tile::FillMonotonicSeq{}(b_k_n); + } + else if(setting_.init_method == 2) + { + ck_tile::FillConstant{static_cast(1)}(a_m_k); + ck_tile::FillConstant{static_cast(1)}(b_k_n); + } + else + { + a_m_k.SetZero(); + b_k_n.SetZero(); + } + + if(gemm_problem.structured_sparsity_) + { + ck_tile::AdjustToStructuredSparsity{}(a_m_k); + } + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + if constexpr(std::is_same_v) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor b_k_n_dev = b_k_n; + // permute_tensor_b(b_k_n_dev); + ck_tile::permute_vectors_i4x4_b(b_k_n_dev); + b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); + } + else + { + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + ck_tile::GemmHostArgs gemm_args = { + a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + gemm_problem.split_k_, + gemm_problem.m_, + gemm_problem.n_, + gemm_problem.k_, + gemm_problem.stride_a_, + gemm_problem.stride_b_, + gemm_problem.stride_c_, + }; + + ck_tile::HostTensor c_m_n_host_result(ck_tile::host_tensor_descriptor( + gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c))); + + if(setting_.verify) + { + gemm_host_reference(setting_.verify, + a_m_k, + b_k_n, + c_m_n_host_result, + a_m_k_dev_buf, + b_k_n_dev_buf, + gemm_problem.m_, + gemm_problem.n_, + gemm_problem.k_, + gemm_problem.stride_a_, + gemm_problem.stride_b_, + gemm_problem.stride_c_); + } + + for(auto& callable : callables) + { + auto kernel_run_result = callable(gemm_args, + ck_tile::stream_config{nullptr, + true, + setting_.log, + setting_.n_warmup, + setting_.n_repeat, + setting_.is_gpu_timer, + setting_.flush_cache, + setting_.rotating_count}); + process_result(gemm_problem, + c_m_n_dev_buf, + c_m_n_host_result, + c_m_n_dev_result, + kernel_run_result); + } + } +}; From bf778923e0e17a54ee932d290e07e93c3367b943 Mon Sep 17 00:00:00 2001 From: ozturkosu Date: Fri, 26 Jun 2026 18:15:37 -0400 Subject: [PATCH 13/16] [CK_TILE] gemm bridge: add missing copyright headers + drop trailing whitespace - Add AMD copyright/SPDX header to gemm_full_benchmark.py and run_one_gemm_kernel.py (CK requires a header on every source file). - Remove a trailing-whitespace blank line in generated_tile_backend.hpp that would trip the whitespace/clang-format CI gate. --- .../ck_tile/dispatcher/backends/generated_tile_backend.hpp | 1 - .../tile_engine/ops/gemm/gemm_full_benchmark.py | 2 ++ .../tile_engine/ops/gemm/run_one_gemm_kernel.py | 2 ++ 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp b/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp index ff354f5523b2..79e619105a21 100644 --- a/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp +++ b/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp @@ -123,7 +123,6 @@ class GeneratedTileKernelInstance : public KernelInstance problem.N // stride_E/C (row-major C: stride = N) ); - const bool bench = this->benchmarking_; ck_tile::stream_config stream_cfg; stream_cfg.stream_id_ = reinterpret_cast(stream); diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_full_benchmark.py b/projects/composablekernel/tile_engine/ops/gemm/gemm_full_benchmark.py index 0228d17903e0..f6bdecfae56c 100644 --- a/projects/composablekernel/tile_engine/ops/gemm/gemm_full_benchmark.py +++ b/projects/composablekernel/tile_engine/ops/gemm/gemm_full_benchmark.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT """Full GEMM benchmark sweep driven through the Dispatcher bridge. Phases: diff --git a/projects/composablekernel/tile_engine/ops/gemm/run_one_gemm_kernel.py b/projects/composablekernel/tile_engine/ops/gemm/run_one_gemm_kernel.py index e1638297b82f..54e138198a96 100644 --- a/projects/composablekernel/tile_engine/ops/gemm/run_one_gemm_kernel.py +++ b/projects/composablekernel/tile_engine/ops/gemm/run_one_gemm_kernel.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT """Worker script for running GEMM kernels in an isolated subprocess. Mirrors grouped_conv's run_one_grouped_conv_kernel.py: From a7e01c4429c0b86e99b9aedd2130f5ee0477c73a Mon Sep 17 00:00:00 2001 From: ozturkosu Date: Fri, 26 Jun 2026 18:38:55 -0400 Subject: [PATCH 14/16] [CK_TILE] gemm codegen: reject invalid non-power-of-2-repeat tiles (e.g. 192) The CShuffle epilogue stores the accumulator back through LDS in power-of-two MRepeat/NRepeat chunks, where MRepeat = tile_m / (wave_m * warp_tile_m) (and likewise N). A tile whose per-wave repeat is not a power of two (or whose tile dim is not divisible by wave*warp_tile) is mis-stored and produces numerically WRONG results at runtime -- yet it still passes the ctypes validator and the epilogue's static_asserts, so it compiles and silently returns garbage. Observed on MI350 for tile_m=192 (MRepeat = 192/(2*32) = 3) and tile_n=192 (e.g. 64x192x64_1x4x1, 192 not divisible by 4*32): both verified incorrect (fp32 reference, max_rel ~1.2-1.4) on the bridge AND Tile Engine, at every shape including shapes divisible by 192. Power-of-two tiles (64/128/256) are unaffected; a control 256-tile verifies cleanly (max_rel ~4e-4). Add a validity gate in both tile-expansion paths: * unified_gemm_codegen.py::_get_tile_configs (codegen CLI path) * gemm_utils.py::expand_sweep (bridge .so build path; this path only ran the ctypes validate_kernel_config, which does not catch this) so invalid tiles are dropped instead of emitted/run. tile_k is unaffected (the K reduction has no CShuffle store constraint). --- .../codegen/unified_gemm_codegen.py | 20 ++++++++++++++++++ .../dispatcher/python/gemm_utils.py | 21 +++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py b/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py index ec525ddd5c4c..0dd2280c159c 100755 --- a/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py +++ b/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py @@ -187,6 +187,10 @@ def is_preshuffle_config_valid( log = logging.getLogger(__name__) +def _is_power_of_two(x: int) -> bool: + return x > 0 and (x & (x - 1)) == 0 + + # ============================================================================ # Configuration and Data Structures # ============================================================================ @@ -1148,6 +1152,22 @@ def _get_tile_configs(self) -> List[TileConfig]: rejected_count += 1 continue + # CShuffle-store correctness gate. The CShuffle epilogue stores the + # accumulator back through LDS in power-of-two MRepeat/NRepeat chunks, + # so a tile whose per-wave repeat count -- tile / (warp * warp_tile) -- + # is not a power of two is mis-stored and yields numerically WRONG + # results at runtime. The kernel still compiles (the epilogue's + # static_asserts only check divisibility, which such tiles satisfy), + # so it must be filtered here. Observed on MI350 for tile_m=192 + # (MRepeat = 192 / (2*32) = 3): verified incorrect on BOTH the bridge + # and Tile Engine at every shape, including shapes divisible by 192. + # Power-of-two tiles (64/128/256) are unaffected. + m_repeat = tile.tile_m // (tile.warp_m * tile.warp_tile_m) + n_repeat = tile.tile_n // (tile.warp_n * tile.warp_tile_n) + if not (_is_power_of_two(m_repeat) and _is_power_of_two(n_repeat)): + rejected_count += 1 + continue + # Architecture-specific validation. This is a pre-filter run before # tiles are paired with traits, so keep a tile if it is legal under # ANY configured pipeline/scheduler; the precise per-trait check diff --git a/projects/composablekernel/dispatcher/python/gemm_utils.py b/projects/composablekernel/dispatcher/python/gemm_utils.py index e02635e1c183..ed4d11fa1ada 100644 --- a/projects/composablekernel/dispatcher/python/gemm_utils.py +++ b/projects/composablekernel/dispatcher/python/gemm_utils.py @@ -763,6 +763,10 @@ def _expand_values(entry: Optional[Dict[str, Any]], default: List[Any]) -> List[ return list(entry.get("values", default)) +def _is_power_of_two(x: int) -> bool: + return x > 0 and (x & (x - 1)) == 0 + + def expand_sweep( config_path: str, arch: str, @@ -873,6 +877,23 @@ def expand_sweep( val = _cu.validate_kernel_config(c.to_ctypes_config()) if not val.is_valid: continue + # Tile/CShuffle correctness gate (mirrors unified_gemm_codegen's + # TileConfig.is_valid + the power-of-two repeat rule; the ctypes + # validate_kernel_config above does NOT enforce either). A block tile must + # split evenly across its waves -- tile % (wave * warp_tile) == 0 -- and + # the CShuffle epilogue stores the accumulator through LDS in power-of-two + # MRepeat/NRepeat chunks, so the per-wave repeat must be a power of two. + # Tiles that violate either still compile but produce numerically WRONG + # results at runtime. Observed on MI350 for tile_m=192 (MRepeat=3) and + # tile_n=192 (e.g. 64x192x64_1x4x1, 192 not divisible by 4*32) -- both + # verified incorrect on the bridge and Tile Engine. Power-of-two tiles + # (64/128/256) are unaffected. + m_div = wm * wtm + n_div = wn * wtn + if m_div <= 0 or n_div <= 0 or tm % m_div != 0 or tn % n_div != 0: + continue + if not _is_power_of_two(tm // m_div) or not _is_power_of_two(tn // n_div): + continue seen.add(c.name) configs.append(c) From e30be7677828a5891b489613242a41efa867aa0e Mon Sep 17 00:00:00 2001 From: ozturkosu Date: Sat, 27 Jun 2026 18:38:49 -0400 Subject: [PATCH 15/16] [CK_TILE] gemm bridge: extend TE<->Dispatcher bridge to fp8/bf8/int8 (all layouts) Adds the remaining data types Tile Engine's plain GEMM has MFMA warp tiles for beyond the fp16/bf16 surface of PR #8479: fp8 (E4M3) and bf8 (E5M2) accumulating into fp16, and int8 accumulating into int32 (gfx942). Covers all four A/B layout combinations per dtype (row-major C only, as ck_tile rejects column-major C). Codegen (codegen_common.py, unified_gemm_codegen.py): - add int32 to the CK / qualified / dispatcher dtype maps - get_output_dtype: int8 -> int32 (fp8/bf8 -> fp16 unchanged) - new get_acc_dtype: int8 -> int32, else fp32 - derive AccDataType, CDataType, the GEMM_KEY_DTYPE_{C,ACC} macros and the registry dtype_c/dtype_acc from the dtype instead of hard-coding float/fp32 Host harness (gemm_utils.py): - fp8/bf8 FNUZ (gfx942) uint8 codecs: exact decode (matches device fp8_t/bf8_t), nearest-representable saturating encode, mirroring the existing bf16 helper - GpuGemmRunner.run encodes A/B and sizes the C buffer per dtype (fp16 for fp8/bf8, int32 for int8) - expand_sweep sets dtype_c/dtype_acc from the input dtype Tests: - test_gemm_utils.py: fp8/bf8 codec round-trip, format ranges, NaN/zero slots, saturation, byte size; output-dtype mapping (CPU-only) - test_gemm_parity.py: fp8/bf8/int8 cases with dtype-aware inputs, references and tolerances (int8 exact); GPU-gated like the existing fp16/bf16 cases GPU parity validation deferred to a follow-up run on an MI300X node. --- .../dispatcher/codegen/codegen_common.py | 25 +++- .../codegen/unified_gemm_codegen.py | 15 +- .../dispatcher/python/gemm_utils.py | 133 +++++++++++++++++- .../dispatcher/tests/test_gemm_parity.py | 108 ++++++++++---- .../dispatcher/tests/test_gemm_utils.py | 78 +++++++++- 5 files changed, 313 insertions(+), 46 deletions(-) diff --git a/projects/composablekernel/dispatcher/codegen/codegen_common.py b/projects/composablekernel/dispatcher/codegen/codegen_common.py index a0486da66d74..a5f022021f8c 100644 --- a/projects/composablekernel/dispatcher/codegen/codegen_common.py +++ b/projects/composablekernel/dispatcher/codegen/codegen_common.py @@ -118,6 +118,7 @@ class CommonTypeMappings: "fp8": "fp8_t", "bf8": "bf8_t", "int8": "int8_t", + "int32": "int32_t", } DTYPE_TO_CK_QUALIFIED = { @@ -127,6 +128,7 @@ class CommonTypeMappings: "fp8": "ck_tile::fp8_t", "bf8": "ck_tile::bf8_t", "int8": "int8_t", + "int32": "int32_t", } DTYPE_TO_DISPATCHER = { @@ -136,6 +138,7 @@ class CommonTypeMappings: "fp8": "DataType::FP8", "bf8": "DataType::BF8", "int8": "DataType::INT8", + "int32": "DataType::INT32", } # GEMM-specific layout mappings ("r"/"c" for row/column major). @@ -202,8 +205,26 @@ class CommonTypeMappings: @staticmethod def get_output_dtype(dtype: str) -> str: - """Get output datatype (fp8/bf8 -> fp16).""" - return "fp16" if dtype in ("fp8", "bf8") else dtype + """Get output (C) datatype for an A/B element dtype. + + Low-precision float inputs accumulate into and store as fp16 + (fp8/bf8 -> fp16); int8 stores its int32 accumulator (int8 -> int32). + Everything else stores in its own dtype. + """ + if dtype in ("fp8", "bf8"): + return "fp16" + if dtype == "int8": + return "int32" + return dtype + + @staticmethod + def get_acc_dtype(dtype: str) -> str: + """Get accumulator datatype for an A/B element dtype. + + Integer GEMM accumulates in int32; every float dtype accumulates in + fp32. + """ + return "int32" if dtype == "int8" else "fp32" # ============================================================================ diff --git a/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py b/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py index 0dd2280c159c..e40abbd1097f 100755 --- a/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py +++ b/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py @@ -414,12 +414,13 @@ def _types(self, config: KernelConfig, kernel_name: str) -> str: def _kernel_local_types(self, config: KernelConfig) -> str: """Generate data type and layout definitions inside kernel namespace""" output_dtype = self.tm.get_output_dtype(self.datatype) + acc_dtype = self.tm.get_acc_dtype(self.datatype) return f""" // Data types (inside namespace to avoid conflicts across layouts) using ADataType = {self.tm.DTYPE_TO_CK[self.datatype]}; using BDataType = {self.tm.DTYPE_TO_CK[self.datatype]}; - using AccDataType = float; + using AccDataType = {self.tm.DTYPE_TO_CK[acc_dtype]}; using CDataType = {self.tm.DTYPE_TO_CK[output_dtype]}; // Layouts (inside namespace to avoid conflicts when mixing layouts) @@ -452,6 +453,7 @@ def _selected_kernel_struct(self, config: KernelConfig, kernel_name: str) -> str t = config.tile tr = config.trait output_dtype = self.tm.get_output_dtype(self.datatype) + acc_dtype = self.tm.get_acc_dtype(self.datatype) # Generate unique struct name and namespace from kernel name struct_name = f"Kernel_{kernel_name}" @@ -467,7 +469,7 @@ def _selected_kernel_struct(self, config: KernelConfig, kernel_name: str) -> str // Data types (inside namespace to avoid conflicts across different kernels) using ADataType = {self.tm.DTYPE_TO_CK[self.datatype]}; using BDataType = {self.tm.DTYPE_TO_CK[self.datatype]}; -using AccDataType = float; +using AccDataType = {self.tm.DTYPE_TO_CK[acc_dtype]}; using CDataType = {self.tm.DTYPE_TO_CK[output_dtype]}; // Layouts (inside namespace to avoid conflicts when mixing layouts like RCR + RRR) @@ -522,8 +524,8 @@ def _selected_kernel_struct(self, config: KernelConfig, kernel_name: str) -> str constexpr const char* KERNEL_NAME = {ns_name}::KERNEL_NAME; using ADataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.datatype]}; using BDataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.datatype]}; -using CDataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.tm.get_output_dtype(self.datatype)]}; -using AccDataType = float; +using CDataType = {self.tm.DTYPE_TO_CK_QUALIFIED[output_dtype]}; +using AccDataType = {self.tm.DTYPE_TO_CK_QUALIFIED[acc_dtype]}; // KernelKey field descriptors for the force-included kernel. // The ctypes library builds the registry KernelKey from these so the @@ -534,7 +536,7 @@ def _selected_kernel_struct(self, config: KernelConfig, kernel_name: str) -> str #define GEMM_KEY_DTYPE_A "{self.datatype}" #define GEMM_KEY_DTYPE_B "{self.datatype}" #define GEMM_KEY_DTYPE_C "{output_dtype}" -#define GEMM_KEY_DTYPE_ACC "fp32" +#define GEMM_KEY_DTYPE_ACC "{acc_dtype}" #define GEMM_KEY_LAYOUT_A "{self.layout[0]}" #define GEMM_KEY_LAYOUT_B "{self.layout[1]}" #define GEMM_KEY_LAYOUT_C "{self.layout[2]}" @@ -815,6 +817,7 @@ def generate( """Generate dispatcher wrapper""" kernel_name = KernelNaming.generate(config, self.datatype, self.layout) output_dtype = self.tm.get_output_dtype(self.datatype) + acc_dtype = self.tm.get_acc_dtype(self.datatype) rel_path = kernel_path.relative_to(output_dir) return f"""// SPDX-License-Identifier: MIT @@ -849,7 +852,7 @@ def generate( key.signature.dtype_a = {self.tm.DTYPE_TO_DISPATCHER[self.datatype]}; key.signature.dtype_b = {self.tm.DTYPE_TO_DISPATCHER[self.datatype]}; key.signature.dtype_c = {self.tm.DTYPE_TO_DISPATCHER[output_dtype]}; - key.signature.dtype_acc = DataType::FP32; + key.signature.dtype_acc = {self.tm.DTYPE_TO_DISPATCHER[acc_dtype]}; key.signature.layout_a = {self.tm.LAYOUT_TO_DISPATCHER[self.layout[0]]}; key.signature.layout_b = {self.tm.LAYOUT_TO_DISPATCHER[self.layout[1]]}; key.signature.layout_c = {self.tm.LAYOUT_TO_DISPATCHER[self.layout[2]]}; diff --git a/projects/composablekernel/dispatcher/python/gemm_utils.py b/projects/composablekernel/dispatcher/python/gemm_utils.py index ed4d11fa1ada..ead6baf648da 100644 --- a/projects/composablekernel/dispatcher/python/gemm_utils.py +++ b/projects/composablekernel/dispatcher/python/gemm_utils.py @@ -400,6 +400,107 @@ def _bf16_u16_to_fp32(u16: np.ndarray) -> np.ndarray: return (u16.astype(np.uint32) << 16).view(np.float32) +# --------------------------------------------------------------------------- +# fp8 (E4M3) / bf8 (E5M2) -- FNUZ ("NANOO") encoding used by gfx942/MI300. +# +# numpy has no native 8-bit float, and the C ABI only cares about the 1-byte +# memory layout (sizeof(fp8_t) == sizeof(bf8_t) == 1). We carry the value as a +# uint8 bit pattern. As with bf16, the DECODE is the load-bearing half: it must +# return the exact value the device's fp8_t/bf8_t represents for a byte, so the +# NumPy reference multiplies bit-for-bit what the GPU multiplies. The ENCODE only +# needs to land on the nearest representable byte. +# +# FNUZ format (gfx942): bias = 2^(exp_bits-1); the all-1s exponent is a normal +# number (no Inf), the sole NaN is the sign=1/exp=0/mant=0 byte (0x80), and there +# is no negative zero. gfx950/MI350 uses the OCP fp8 format instead; this codec +# targets the gfx942 default and the OCP path needs separate handling. +# --------------------------------------------------------------------------- + + +def _fnuz_decode_table(exp_bits: int, mant_bits: int) -> np.ndarray: + """Build the 256-entry byte -> fp32 value table for an 8-bit FNUZ float.""" + bias = (1 << (exp_bits - 1)) + mant_max = 1 << mant_bits + sign_shift = exp_bits + mant_bits + exp_mask = (1 << exp_bits) - 1 + table = np.zeros(256, dtype=np.float32) + for b in range(256): + sign = (b >> sign_shift) & 1 + exp = (b >> mant_bits) & exp_mask + mant = b & (mant_max - 1) + if exp == 0 and mant == 0: + # +0 (0x00); the negative-zero slot (0x80) is the lone NaN. + table[b] = np.float32(np.nan) if sign else np.float32(0.0) + continue + if exp == 0: + val = (mant / mant_max) * (2.0 ** (1 - bias)) # subnormal + else: + val = (1.0 + mant / mant_max) * (2.0 ** (exp - bias)) # normal + table[b] = np.float32(-val if sign else val) + return table + + +def _fnuz_encode(x: np.ndarray, exp_bits: int, mant_bits: int) -> np.ndarray: + """Encode fp32 -> nearest 8-bit FNUZ float, returned as a uint8 bit pattern.""" + table = _fnuz_decode_table(exp_bits, mant_bits) + sign_byte = np.uint8(1 << (exp_bits + mant_bits)) # 0x80 + + # Positive half (bytes 0..127) holds every non-negative magnitude, sorted. + # Compare in float64: for very large inputs the gap between the two top + # magnitudes is below fp32 resolution, which would tie and mis-saturate. + pos_mag = table[: int(sign_byte)].astype(np.float64) + order = np.argsort(pos_mag) + sorted_mag = pos_mag[order] + sorted_byte = order.astype(np.uint8) + + xf = np.ascontiguousarray(x, dtype=np.float32) + ax = np.abs(xf).astype(np.float64) + # Both neighbours come from the raw insertion point: raw==size saturates to + # the top magnitude (lo==hi), raw==0 pins to zero, otherwise compare the two. + raw = np.searchsorted(sorted_mag, ax) + hi = np.clip(raw, 0, sorted_mag.size - 1) + lo = np.clip(raw - 1, 0, sorted_mag.size - 1) + pick_lo = np.abs(sorted_mag[lo] - ax) <= np.abs(sorted_mag[hi] - ax) + chosen = np.where(pick_lo, lo, hi) + out = sorted_byte[chosen] + + # Apply sign, but never the 0x80 (-0 == NaN) slot: zeros stay +0. + is_zero = sorted_mag[chosen] == 0 + out = np.where((xf < 0) & ~is_zero, out | sign_byte, out) + out = np.where(np.isnan(xf), sign_byte, out) # NaN inputs -> NaN byte + return out.astype(np.uint8).reshape(np.shape(x)) + + +def _fp32_to_fp8_u8(x: np.ndarray) -> np.ndarray: + """Encode fp32 -> fp8 E4M3 (FNUZ) bit pattern in a uint8 array.""" + return _fnuz_encode(x, exp_bits=4, mant_bits=3) + + +def _fp8_u8_to_fp32(u8: np.ndarray) -> np.ndarray: + """Decode an fp8 E4M3 (FNUZ) bit pattern back to fp32.""" + return _fnuz_decode_table(4, 3)[u8.astype(np.intp)] + + +def _fp32_to_bf8_u8(x: np.ndarray) -> np.ndarray: + """Encode fp32 -> bf8 E5M2 (FNUZ) bit pattern in a uint8 array.""" + return _fnuz_encode(x, exp_bits=5, mant_bits=2) + + +def _bf8_u8_to_fp32(u8: np.ndarray) -> np.ndarray: + """Decode a bf8 E5M2 (FNUZ) bit pattern back to fp32.""" + return _fnuz_decode_table(5, 2)[u8.astype(np.intp)] + + +# Output (C) element dtype for an A/B element dtype, mirroring the codegen's +# CommonTypeMappings.get_output_dtype: fp8/bf8 accumulate into fp16, int8 into +# int32, everything else stores in its own dtype. +_OUTPUT_DTYPE = {"fp8": "fp16", "bf8": "fp16", "int8": "int32"} + + +def _output_dtype(dtype: str) -> str: + return _OUTPUT_DTYPE.get(dtype, dtype) + + def _dtype_from_kernel_name(name: str) -> str: """Extract the dtype token from a kernel name like ``gemm___...``.""" parts = name.split("_") @@ -461,20 +562,39 @@ def run( B_lay = B if lb == "r" else B.T C_shape = (M, N) if lc == "r" else (N, M) + # Build A/B host buffers in the kernel's element dtype. The encode + # helpers (bf16/fp8/bf8) already force a contiguous float32 source, so an + # outer ascontiguousarray would only add a redundant copy; the native + # numpy dtypes (fp16/int8) still need it. if dtype == "bf16": - # _fp32_to_bf16_u16 already forces a contiguous float32 buffer, so - # an outer ascontiguousarray here would only add a redundant copy. A_h = _fp32_to_bf16_u16(A_lay) B_h = _fp32_to_bf16_u16(B_lay) - C_h = np.zeros(C_shape, dtype=np.uint16) + elif dtype == "fp8": + A_h = _fp32_to_fp8_u8(A_lay) + B_h = _fp32_to_fp8_u8(B_lay) + elif dtype == "bf8": + A_h = _fp32_to_bf8_u8(A_lay) + B_h = _fp32_to_bf8_u8(B_lay) + elif dtype == "int8": + A_h = np.ascontiguousarray(A_lay, dtype=np.int8) + B_h = np.ascontiguousarray(B_lay, dtype=np.int8) else: # fp16 (default) A_h = np.ascontiguousarray(A_lay, dtype=np.float16) B_h = np.ascontiguousarray(B_lay, dtype=np.float16) - C_h = np.zeros(C_shape, dtype=np.float16) + + # The C buffer's element size must equal sizeof(CDataType): fp8/bf8 + # accumulate into fp16, int8 into int32, otherwise the input dtype. + out_dtype = _output_dtype(dtype) + _C_NP = {"fp16": np.float16, "bf16": np.uint16, "int32": np.int32} + C_h = np.zeros(C_shape, dtype=_C_NP.get(out_dtype, np.float16)) status, time_ms = self.lib.run(A_h, B_h, C_h, M, N, K) - C_dec = _bf16_u16_to_fp32(C_h) if dtype == "bf16" else C_h + # Decode the output back to a comparable numeric array. + if out_dtype == "bf16": + C_dec = _bf16_u16_to_fp32(C_h) + else: # fp16 / int32 are already directly comparable + C_dec = C_h C_out = C_dec if lc == "r" else C_dec.T tflops = (problem.flops / (time_ms * 1e-3)) / 1e12 if time_ms > 0 else 0.0 @@ -850,7 +970,8 @@ def expand_sweep( c = GemmKernelConfig( dtype_a=dtype, dtype_b=dtype, - dtype_c=dtype, + dtype_c=_output_dtype(dtype), + dtype_acc=("int32" if dtype == "int8" else "fp32"), layout_a=_LAYOUT_WORD[la], layout_b=_LAYOUT_WORD[lb], layout_c=_LAYOUT_WORD[lc], diff --git a/projects/composablekernel/dispatcher/tests/test_gemm_parity.py b/projects/composablekernel/dispatcher/tests/test_gemm_parity.py index 308c5700672c..b9d1cd1cd9ab 100644 --- a/projects/composablekernel/dispatcher/tests/test_gemm_parity.py +++ b/projects/composablekernel/dispatcher/tests/test_gemm_parity.py @@ -44,21 +44,27 @@ setup_multiple_gemm_dispatchers, _fp32_to_bf16_u16, _bf16_u16_to_fp32, + _fp32_to_fp8_u8, + _fp8_u8_to_fp32, + _fp32_to_bf8_u8, + _bf8_u8_to_fp32, + _output_dtype, ) from ctypes_utils import detect_gpu_arch, get_build_dir # noqa: E402 # (dtype, layout) surface the regular bridge supports. Column-major C is rejected # by ck_tile's universal GEMM at build, so every layout keeps row-major C, which -# leaves exactly the four A/B combinations below. Both dtypes cover all four. +# leaves exactly the four A/B combinations below. Every dtype covers all four. +# +# fp16/bf16 are the PR #8479 surface; fp8 (E4M3), bf8 (E5M2) and int8 are the +# remaining dtypes TE's plain GEMM has MFMA warp tiles for (fp8/bf8 -> fp16 out, +# int8 -> int32 out). int8 only has warp tiles on gfx942; on other arches its +# kernels simply fail to build and the case skips (handled below). +_FLOAT_DTYPES = ("fp16", "bf16", "fp8", "bf8") +_INT_DTYPES = ("int8",) +_LAYOUTS = ("rcr", "rrr", "ccr", "crr") _CASES = [ - ("fp16", "rcr"), - ("fp16", "rrr"), - ("fp16", "ccr"), - ("fp16", "crr"), - ("bf16", "rcr"), - ("bf16", "rrr"), - ("bf16", "ccr"), - ("bf16", "crr"), + (dt, lay) for dt in (*_FLOAT_DTYPES, *_INT_DTYPES) for lay in _LAYOUTS ] # Padded default algorithm: pad_* all True so M/N need not divide the tile, which @@ -81,25 +87,72 @@ ("awkward", 257, 129, 512), ] -# Global-relative-error gates. fp16 measured ~3-4e-4 and bf16 ~8e-3 on gfx942; -# these leave headroom without masking a real regression. -_TOL = {"fp16": 2e-3, "bf16": 1.5e-2} +# Global-relative-error gates. fp16 measured ~3-4e-4 and bf16 ~8e-3 on gfx942. +# fp8/bf8 are far coarser (3- and 2-bit mantissa) so their gates are looser; int8 +# is an exact integer accumulation so it must match bit-for-bit. The fp8/bf8 +# gates are first-cut headroom values and may want tightening once measured on a +# GPU. +_TOL = { + "fp16": 2e-3, + "bf16": 1.5e-2, + "fp8": 1.5e-1, + "bf8": 3.0e-1, + "int8": 0.0, +} _LAYOUT_WORD = {"r": "row", "c": "col"} -def _emulate(x: np.ndarray, dtype: str) -> np.ndarray: - """Round fp32 to the kernel's storage dtype so the CPU reference matches what - the GPU actually multiplies (and stores).""" +def _emulate_input(x: np.ndarray, dtype: str) -> np.ndarray: + """Round an fp32 operand to the kernel's storage dtype so the CPU reference + multiplies exactly what the GPU does. int8 inputs are already integral.""" if dtype == "bf16": return _bf16_u16_to_fp32(_fp32_to_bf16_u16(x)) + if dtype == "fp8": + return _fp8_u8_to_fp32(_fp32_to_fp8_u8(x)) + if dtype == "bf8": + return _bf8_u8_to_fp32(_fp32_to_bf8_u8(x)) + if dtype == "int8": + return x.astype(np.float64) # exact; widened to avoid product overflow return x.astype(np.float16).astype(np.float32) +def _emulate_output(c: np.ndarray, out_dtype: str) -> np.ndarray: + """Round the fp32 accumulator to the kernel's C storage dtype.""" + if out_dtype == "bf16": + return _bf16_u16_to_fp32(_fp32_to_bf16_u16(c)) + if out_dtype == "int32": + return c # integer accumulation is exact + return c.astype(np.float16).astype(np.float32) # fp16 + + +def _make_inputs(dtype, M, N, K, rng): + """Random A (MxK), B (KxN) for a dtype: floats for the float dtypes, small + integers for int8 (kept small so the int32 accumulation cannot overflow).""" + if dtype == "int8": + A = rng.integers(-4, 5, size=(M, K)).astype(np.float32) + B = rng.integers(-4, 5, size=(K, N)).astype(np.float32) + return A, B + A = (rng.standard_normal((M, K)) * 0.1).astype(np.float32) + B = (rng.standard_normal((K, N)) * 0.1).astype(np.float32) + return A, B + + +def _reference(A, B, dtype): + """NumPy reference matching the kernel: round inputs to the storage dtype, + accumulate (fp32 for floats / exact int for int8), then round to C dtype.""" + out_dtype = _output_dtype(dtype) + acc = _emulate_input(A, dtype) @ _emulate_input(B, dtype) + ref = _emulate_output(acc, out_dtype) + return ref.astype(np.int32) if out_dtype == "int32" else ref + + def _config(dtype: str, layout: str, arch: str) -> GemmKernelConfig: la, lb, lc = layout return GemmKernelConfig( - dtype_a=dtype, dtype_b=dtype, dtype_c=dtype, + dtype_a=dtype, dtype_b=dtype, + dtype_c=_output_dtype(dtype), + dtype_acc=("int32" if dtype == "int8" else "fp32"), layout_a=_LAYOUT_WORD[la], layout_b=_LAYOUT_WORD[lb], layout_c=_LAYOUT_WORD[lc], gfx_arch=arch, **_ALGO, ) @@ -161,17 +214,13 @@ def _run_case(self, dtype, layout, shape): _, M, N, K = shape problem = GemmProblem(M=M, N=N, K=K) rng = np.random.default_rng(42) - A = (rng.standard_normal((M, K)) * 0.1).astype(np.float32) - B = (rng.standard_normal((K, N)) * 0.1).astype(np.float32) + A, B = _make_inputs(dtype, M, N, K, rng) runner = GpuGemmRunner(lib_path=so) # The .so is the contract endpoint: the name it reports must be the config - # name that drove codegen + the force-include build. - self.assertEqual(runner.kernel_name, GemmKernelConfig( - dtype_a=dtype, dtype_b=dtype, dtype_c=dtype, - layout_a=_LAYOUT_WORD[layout[0]], layout_b=_LAYOUT_WORD[layout[1]], - layout_c=_LAYOUT_WORD[layout[2]], gfx_arch=self.arch, **_ALGO, - ).name) + # name that drove codegen + the force-include build. The kernel name keys + # off the input dtype (dtype_a), not the C/acc dtype. + self.assertEqual(runner.kernel_name, _config(dtype, layout, self.arch).name) result = runner.run(A, B, problem) self.assertTrue( @@ -179,8 +228,8 @@ def _run_case(self, dtype, layout, shape): f"{dtype}/{layout} {shape[0]} run failed (status {result.status})", ) - ref = _emulate(_emulate(A, dtype) @ _emulate(B, dtype), dtype) - max_rel = _max_rel(result.output, ref) + ref = _reference(A, B, dtype) + max_rel = _max_rel(result.output.astype(np.float64), ref.astype(np.float64)) self.assertLessEqual( max_rel, _TOL[dtype], f"{dtype}/{layout} {shape[0]} max_rel={max_rel:.2e} > {_TOL[dtype]:.0e}", @@ -237,14 +286,13 @@ def _main() -> int: for sname, M, N, K in _SHAPES: total += 1 problem = GemmProblem(M=M, N=N, K=K) - A = (rng.standard_normal((M, K)) * 0.1).astype(np.float32) - B = (rng.standard_normal((K, N)) * 0.1).astype(np.float32) + A, B = _make_inputs(dtype, M, N, K, rng) result = runner.run(A, B, problem) if not result.success: print(f" {tag:<12} {sname:<12} {'RUN FAILED':>9} status={result.status}") continue - ref = _emulate(_emulate(A, dtype) @ _emulate(B, dtype), dtype) - mr = _max_rel(result.output, ref) + ref = _reference(A, B, dtype) + mr = _max_rel(result.output.astype(np.float64), ref.astype(np.float64)) ok = mr <= _TOL[dtype] passed += ok print(f" {tag:<12} {sname:<12} {result.tflops:>9.1f} " diff --git a/projects/composablekernel/dispatcher/tests/test_gemm_utils.py b/projects/composablekernel/dispatcher/tests/test_gemm_utils.py index 34e07ecfcb04..8ed188932e0b 100644 --- a/projects/composablekernel/dispatcher/tests/test_gemm_utils.py +++ b/projects/composablekernel/dispatcher/tests/test_gemm_utils.py @@ -8,6 +8,9 @@ Locks in the bit-level helpers that the TE -> Dispatcher GEMM bridge relies on: * bf16 <-> uint16 encoding (round-to-nearest-even), since numpy has no native bf16 and the runner carries bf16 as a uint16 bit pattern. + * fp8 (E4M3) / bf8 (E5M2) FNUZ <-> uint8 encoding, used for the gfx942 8-bit + float surface. The decode must be exact to the device format; the encode + only needs to land on the nearest representable byte. * dtype / layout parsing from the compiled kernel name, which drives how the runner lays out host buffers. @@ -29,6 +32,12 @@ GemmKernelConfig, _fp32_to_bf16_u16, _bf16_u16_to_fp32, + _fp32_to_fp8_u8, + _fp8_u8_to_fp32, + _fp32_to_bf8_u8, + _bf8_u8_to_fp32, + _fnuz_decode_table, + _output_dtype, _dtype_from_kernel_name, _layout_from_kernel_name, ) @@ -78,6 +87,70 @@ def test_dtype_and_size(self): self.assertEqual(u16.itemsize, 2) # must match sizeof(bf16_t) on device +class TestFp8Bf8Encoding(unittest.TestCase): + """fp8 E4M3 / bf8 E5M2 in the FNUZ format used by gfx942. + + The decode is the load-bearing half (it must equal the device value for a + byte); the encode must land on the nearest representable byte and saturate. + """ + + def test_format_ranges(self): + # FNUZ maxima: E4M3 -> 2^7 * 1.875 = 240; E5M2 -> 2^15 * 1.75 = 57344. + t43 = _fnuz_decode_table(4, 3) + t52 = _fnuz_decode_table(5, 2) + self.assertEqual(float(np.nanmax(t43)), 240.0) + self.assertEqual(float(np.nanmin(t43)), -240.0) + self.assertEqual(float(np.nanmax(t52)), 57344.0) + self.assertEqual(float(np.nanmin(t52)), -57344.0) + + def test_zero_and_nan_slots(self): + # 0x00 is +0; the negative-zero slot 0x80 is the lone NaN (FNUZ). + for tab in (_fnuz_decode_table(4, 3), _fnuz_decode_table(5, 2)): + self.assertEqual(float(tab[0x00]), 0.0) + self.assertTrue(np.isnan(tab[0x80])) + + def test_exactly_representable_roundtrip(self): + exact = np.array([0.0, 0.5, 1.0, -1.0, 2.0, -2.0, 1.5, -0.25, 4.0, 8.0], + dtype=np.float32) + np.testing.assert_array_equal( + _fp8_u8_to_fp32(_fp32_to_fp8_u8(exact)), exact) + np.testing.assert_array_equal( + _bf8_u8_to_fp32(_fp32_to_bf8_u8(exact)), exact) + + def test_decode_is_consistent_with_encode(self): + # The parity contract: ref multiplies decode(encode(x)), so the pair must + # be self-consistent and every encoded byte must decode finite. + rng = np.random.default_rng(1) + x = (rng.standard_normal(5000) * 0.1).astype(np.float32) + for enc, dec in ((_fp32_to_fp8_u8, _fp8_u8_to_fp32), + (_fp32_to_bf8_u8, _bf8_u8_to_fp32)): + d = dec(enc(x)) + self.assertTrue(np.all(np.isfinite(d))) + + def test_saturates_no_inf(self): + # FNUZ has no infinity: huge magnitudes clamp to the finite max. + big = np.array([1e30, -1e30], dtype=np.float32) + self.assertEqual(float(_fp8_u8_to_fp32(_fp32_to_fp8_u8(big))[0]), 240.0) + self.assertEqual(float(_bf8_u8_to_fp32(_fp32_to_bf8_u8(big))[1]), -57344.0) + + def test_dtype_and_size(self): + for enc in (_fp32_to_fp8_u8, _fp32_to_bf8_u8): + u8 = enc(np.zeros(4, dtype=np.float32)) + self.assertEqual(u8.dtype, np.uint8) + self.assertEqual(u8.itemsize, 1) # must match sizeof(fp8_t/bf8_t) + + +class TestOutputDtype(unittest.TestCase): + """Output (C) element dtype must mirror the codegen's get_output_dtype.""" + + def test_mapping(self): + self.assertEqual(_output_dtype("fp16"), "fp16") + self.assertEqual(_output_dtype("bf16"), "bf16") + self.assertEqual(_output_dtype("fp8"), "fp16") + self.assertEqual(_output_dtype("bf8"), "fp16") + self.assertEqual(_output_dtype("int8"), "int32") + + class TestKernelNameParsing(unittest.TestCase): """The runner reads dtype + layout straight from the compiled .so name.""" @@ -114,13 +187,14 @@ class TestConfigNameContract(unittest.TestCase): codegen -> runtime; parsing it back must recover dtype and layout.""" def test_name_roundtrips_through_parsers(self): - for dtype in ("fp16", "bf16"): + for dtype in ("fp16", "bf16", "fp8", "bf8", "int8"): for la, lb, lc in (("row", "col", "row"), ("row", "row", "row"), ("col", "col", "row"), ("col", "row", "row")): cfg = GemmKernelConfig( - dtype_a=dtype, dtype_b=dtype, dtype_c=dtype, + dtype_a=dtype, dtype_b=dtype, dtype_c=_output_dtype(dtype), + dtype_acc=("int32" if dtype == "int8" else "fp32"), layout_a=la, layout_b=lb, layout_c=lc, ) name = cfg.name From e3bde1bda5cf32f20bd001611a07fbb707363854 Mon Sep 17 00:00:00 2001 From: muozturk Date: Wed, 1 Jul 2026 13:55:48 -0400 Subject: [PATCH 16/16] [CK_TILE] fix fp8 ColMajor-A/4x1x1 perf regression: match Old-TE cshuffle epilogue unified_gemm_codegen forced CShuffleEpilogueProblem trailing template args (false, 1, 1, DoubleSmemBuffer) for the gemm_universal cshuffle epilogue, while Old-TE's gemm_universal_instance_builder stops at NumWaveGroups (letting the epilogue defaults apply). For RowMajor-A those forced values equal the defaults (parity), but for ColMajor-A and 4x1x1 block-maps they yield a higher-VGPR kernel (120/128 vs Old-TE 92/100) -> lower occupancy -> 30-75% slower. Drop the forced args so the bridge emits the same epilogue as Old-TE. On MI300X all 18 affected fp8 stems recover: 50/51 formerly >15% (M/N/K sweep) rows now within +/-15% (median 0.01%). multi_d variant left unchanged. --- .../composablekernel/dispatcher/codegen/unified_gemm_codegen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py b/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py index e40abbd1097f..6ddd3780788f 100755 --- a/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py +++ b/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py @@ -786,7 +786,7 @@ def _epilogue_code(self, config: KernelConfig) -> str: tuple<>, CLayout, element_wise::PassThrough, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, WarpPerBlock_M, WarpPerBlock_N, WarpTileM, WarpTileN, WarpTileK, - TransposeC, NumWaveGroups, false, 1, 1, DoubleSmemBuffer>; + TransposeC, NumWaveGroups>; using GemmEpilogue = CShuffleEpilogue;""" else: return """