diff --git a/projects/composablekernel/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp b/projects/composablekernel/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp index 85c0c2f2c13a..c415ae6c0cef 100644 --- a/projects/composablekernel/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp +++ b/projects/composablekernel/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include "ck_tile/dispatcher/dispatcher.hpp" #include "ck_tile/dispatcher/registry.hpp" @@ -65,15 +66,74 @@ int dispatcher_initialize() return 0; // Already initialized } - // Create kernel key from the force-included kernel header + // Create kernel key from the force-included kernel header. + // + // The GEMM_KEY_* macros are emitted by the codegen into the force-included + // header (see unified_gemm_codegen.py, CK_TILE_SINGLE_KERNEL_INCLUDE block). + // Building the key from them makes the registry entry truthful: it reflects + // THIS kernel's real dtypes/layouts/tile/traits instead of a hard-coded + // fp16/rcr/128x128x32 default. Enum fields use the string_to_* helpers from + // kernel_key.hpp, whose accepted strings match the codegen's emitted values + // byte-for-byte. KernelKey key; +#ifdef GEMM_KEY_DTYPE_A + key.signature.dtype_a = string_to_dtype(GEMM_KEY_DTYPE_A); + key.signature.dtype_b = string_to_dtype(GEMM_KEY_DTYPE_B); + key.signature.dtype_c = string_to_dtype(GEMM_KEY_DTYPE_C); + key.signature.dtype_acc = string_to_dtype(GEMM_KEY_DTYPE_ACC); + key.signature.layout_a = string_to_layout(GEMM_KEY_LAYOUT_A); + key.signature.layout_b = string_to_layout(GEMM_KEY_LAYOUT_B); + key.signature.layout_c = string_to_layout(GEMM_KEY_LAYOUT_C); + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = (GEMM_KEY_GROUPED != 0); + key.signature.split_k = GEMM_KEY_SPLIT_K; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {GEMM_KEY_TILE_M, GEMM_KEY_TILE_N, GEMM_KEY_TILE_K}; + key.algorithm.wave_shape = {GEMM_KEY_WAVE_M, GEMM_KEY_WAVE_N, GEMM_KEY_WAVE_K}; + key.algorithm.warp_tile_shape = {GEMM_KEY_WARP_TILE_M, GEMM_KEY_WARP_TILE_N, GEMM_KEY_WARP_TILE_K}; + key.algorithm.pipeline = string_to_pipeline(GEMM_KEY_PIPELINE); + key.algorithm.scheduler = string_to_scheduler(GEMM_KEY_SCHEDULER); + key.algorithm.epilogue = string_to_epilogue(GEMM_KEY_EPILOGUE); + key.algorithm.block_size = GEMM_KEY_BLOCK_SIZE; + key.algorithm.double_buffer = (GEMM_KEY_DOUBLE_BUFFER != 0); + key.algorithm.persistent = (GEMM_KEY_PERSISTENT != 0); + key.algorithm.preshuffle = (GEMM_KEY_PRESHUFFLE != 0); + key.algorithm.transpose_c = (GEMM_KEY_TRANSPOSE_C != 0); + key.algorithm.num_wave_groups = GEMM_KEY_NUM_WAVE_GROUPS; + // pad_m/n/k participate in both the key's hash/equality and the kernel + // name, so they must be derived from the codegen macros too -- otherwise a + // kernel built with padding disabled would register under a key claiming + // pad=true and disagree with its own name. + key.algorithm.pad_m = (GEMM_KEY_PAD_M != 0); + key.algorithm.pad_n = (GEMM_KEY_PAD_N != 0); + key.algorithm.pad_k = (GEMM_KEY_PAD_K != 0); + key.gfx_arch = GFX_ARCH; +#else + // Fallback default for headers generated before GEMM_KEY_* macros existed + // (fp16 / rcr / compv4-cshuffle-intrawave, 128x128x32). The macro path + // above is the source of truth for any freshly generated kernel. key.signature.dtype_a = DataType::FP16; key.signature.dtype_b = DataType::FP16; key.signature.dtype_c = DataType::FP16; key.signature.dtype_acc = DataType::FP32; - key.signature.layout_a = LayoutTag::RowMajor; - key.signature.layout_b = LayoutTag::ColMajor; - key.signature.layout_c = LayoutTag::RowMajor; + // Derive A/B/C layouts from the force-included kernel's own layout types + // instead of hardcoding rcr. The dispatcher's supports() gate is layout-aware + // (it only constrains a dimension that an operand's inner axis maps to), so a + // wrong key layout makes it reject valid problems -- e.g. a crr kernel does not + // gate K, but with a hardcoded rcr key supports() would apply rcr's K-gate and + // reject TileK=192 problems that Old-TE runs. ALayout/BLayout/CLayout are the + // global aliases exported by the kernel header under CK_TILE_SINGLE_KERNEL_INCLUDE. + using RowMajorLayout = ck_tile::tensor_layout::gemm::RowMajor; + key.signature.layout_a = + std::is_same_v ? LayoutTag::RowMajor : LayoutTag::ColMajor; + key.signature.layout_b = + std::is_same_v ? LayoutTag::RowMajor : LayoutTag::ColMajor; + key.signature.layout_c = + std::is_same_v ? LayoutTag::RowMajor : LayoutTag::ColMajor; key.signature.transpose_a = false; key.signature.transpose_b = false; key.signature.grouped = false; @@ -95,6 +155,7 @@ int dispatcher_initialize() key.algorithm.transpose_c = false; key.algorithm.num_wave_groups = 1; key.gfx_arch = GFX_ARCH; +#endif // GEMM_KEY_DTYPE_A // Register kernel using types from force-included header auto kernel = @@ -310,10 +371,40 @@ int dispatcher_run_gemm( } /** - * Get kernel information + * Get kernel information (legacy single-kernel ABI). + * + * Returns the compile-time KERNEL_NAME of the force-included kernel header. + * Kept for backward compatibility with one-kernel-per-.so callers. */ const char* dispatcher_get_kernel_name() { return KERNEL_NAME; } +/** + * Get the name of the kernel at a given registry index (multi-kernel ABI). + * + * Mirrors the conv/fmha ctypes libs: copies the index-th registered kernel's + * name into the caller-provided buffer so one .so can report a whole batch and + * be selected by name at runtime. Returns 0 on success, -1 on bad args or + * out-of-range index. + */ +int dispatcher_get_kernel_name_at(int index, char* buffer, int buffer_size) +{ + if(!buffer || buffer_size <= 0) + { + return -1; + } + + auto kernels = Registry::instance().get_all(); + if(index < 0 || index >= static_cast(kernels.size())) + { + return -1; + } + + std::string name = kernels[index]->get_name(); + std::strncpy(buffer, name.c_str(), static_cast(buffer_size) - 1); + buffer[buffer_size - 1] = '\0'; + return 0; +} + /** * Initialize dispatcher (alias) */ diff --git a/projects/composablekernel/dispatcher/codegen/codegen_common.py b/projects/composablekernel/dispatcher/codegen/codegen_common.py index a0486da66d74..a5f022021f8c 100644 --- a/projects/composablekernel/dispatcher/codegen/codegen_common.py +++ b/projects/composablekernel/dispatcher/codegen/codegen_common.py @@ -118,6 +118,7 @@ class CommonTypeMappings: "fp8": "fp8_t", "bf8": "bf8_t", "int8": "int8_t", + "int32": "int32_t", } DTYPE_TO_CK_QUALIFIED = { @@ -127,6 +128,7 @@ class CommonTypeMappings: "fp8": "ck_tile::fp8_t", "bf8": "ck_tile::bf8_t", "int8": "int8_t", + "int32": "int32_t", } DTYPE_TO_DISPATCHER = { @@ -136,6 +138,7 @@ class CommonTypeMappings: "fp8": "DataType::FP8", "bf8": "DataType::BF8", "int8": "DataType::INT8", + "int32": "DataType::INT32", } # GEMM-specific layout mappings ("r"/"c" for row/column major). @@ -202,8 +205,26 @@ class CommonTypeMappings: @staticmethod def get_output_dtype(dtype: str) -> str: - """Get output datatype (fp8/bf8 -> fp16).""" - return "fp16" if dtype in ("fp8", "bf8") else dtype + """Get output (C) datatype for an A/B element dtype. + + Low-precision float inputs accumulate into and store as fp16 + (fp8/bf8 -> fp16); int8 stores its int32 accumulator (int8 -> int32). + Everything else stores in its own dtype. + """ + if dtype in ("fp8", "bf8"): + return "fp16" + if dtype == "int8": + return "int32" + return dtype + + @staticmethod + def get_acc_dtype(dtype: str) -> str: + """Get accumulator datatype for an A/B element dtype. + + Integer GEMM accumulates in int32; every float dtype accumulates in + fp32. + """ + return "int32" if dtype == "int8" else "fp32" # ============================================================================ diff --git a/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py b/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py index c0fb08aa4436..6ddd3780788f 100755 --- a/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py +++ b/projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py @@ -187,6 +187,10 @@ def is_preshuffle_config_valid( log = logging.getLogger(__name__) +def _is_power_of_two(x: int) -> bool: + return x > 0 and (x & (x - 1)) == 0 + + # ============================================================================ # Configuration and Data Structures # ============================================================================ @@ -410,12 +414,13 @@ def _types(self, config: KernelConfig, kernel_name: str) -> str: def _kernel_local_types(self, config: KernelConfig) -> str: """Generate data type and layout definitions inside kernel namespace""" output_dtype = self.tm.get_output_dtype(self.datatype) + acc_dtype = self.tm.get_acc_dtype(self.datatype) return f""" // Data types (inside namespace to avoid conflicts across layouts) using ADataType = {self.tm.DTYPE_TO_CK[self.datatype]}; using BDataType = {self.tm.DTYPE_TO_CK[self.datatype]}; - using AccDataType = float; + using AccDataType = {self.tm.DTYPE_TO_CK[acc_dtype]}; using CDataType = {self.tm.DTYPE_TO_CK[output_dtype]}; // Layouts (inside namespace to avoid conflicts when mixing layouts) @@ -448,6 +453,7 @@ def _selected_kernel_struct(self, config: KernelConfig, kernel_name: str) -> str t = config.tile tr = config.trait output_dtype = self.tm.get_output_dtype(self.datatype) + acc_dtype = self.tm.get_acc_dtype(self.datatype) # Generate unique struct name and namespace from kernel name struct_name = f"Kernel_{kernel_name}" @@ -463,7 +469,7 @@ def _selected_kernel_struct(self, config: KernelConfig, kernel_name: str) -> str // Data types (inside namespace to avoid conflicts across different kernels) using ADataType = {self.tm.DTYPE_TO_CK[self.datatype]}; using BDataType = {self.tm.DTYPE_TO_CK[self.datatype]}; -using AccDataType = float; +using AccDataType = {self.tm.DTYPE_TO_CK[acc_dtype]}; using CDataType = {self.tm.DTYPE_TO_CK[output_dtype]}; // Layouts (inside namespace to avoid conflicts when mixing layouts like RCR + RRR) @@ -518,8 +524,45 @@ def _selected_kernel_struct(self, config: KernelConfig, kernel_name: str) -> str constexpr const char* KERNEL_NAME = {ns_name}::KERNEL_NAME; using ADataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.datatype]}; using BDataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.datatype]}; -using CDataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.tm.get_output_dtype(self.datatype)]}; -using AccDataType = float; +using CDataType = {self.tm.DTYPE_TO_CK_QUALIFIED[output_dtype]}; +using AccDataType = {self.tm.DTYPE_TO_CK_QUALIFIED[acc_dtype]}; + +// KernelKey field descriptors for the force-included kernel. +// The ctypes library builds the registry KernelKey from these so the +// registered entry reflects this kernel's real traits (not a hard-coded +// fp16/rcr default). Enum-valued fields are emitted as the exact strings +// consumed by string_to_dtype/layout/pipeline/scheduler/epilogue in +// kernel_key.hpp; shape/flag fields are emitted as numeric/0-1 literals. +#define GEMM_KEY_DTYPE_A "{self.datatype}" +#define GEMM_KEY_DTYPE_B "{self.datatype}" +#define GEMM_KEY_DTYPE_C "{output_dtype}" +#define GEMM_KEY_DTYPE_ACC "{acc_dtype}" +#define GEMM_KEY_LAYOUT_A "{self.layout[0]}" +#define GEMM_KEY_LAYOUT_B "{self.layout[1]}" +#define GEMM_KEY_LAYOUT_C "{self.layout[2]}" +#define GEMM_KEY_PIPELINE "{tr.pipeline}" +#define GEMM_KEY_SCHEDULER "{tr.scheduler}" +#define GEMM_KEY_EPILOGUE "{tr.epilogue}" +#define GEMM_KEY_TILE_M {t.tile_m} +#define GEMM_KEY_TILE_N {t.tile_n} +#define GEMM_KEY_TILE_K {t.tile_k} +#define GEMM_KEY_WAVE_M {t.warp_m} +#define GEMM_KEY_WAVE_N {t.warp_n} +#define GEMM_KEY_WAVE_K {t.warp_k} +#define GEMM_KEY_WARP_TILE_M {t.warp_tile_m} +#define GEMM_KEY_WARP_TILE_N {t.warp_tile_n} +#define GEMM_KEY_WARP_TILE_K {t.warp_tile_k} +#define GEMM_KEY_BLOCK_SIZE {config.block_size} +#define GEMM_KEY_NUM_WAVE_GROUPS {config.num_wave_groups} +#define GEMM_KEY_PAD_M {int(tr.pad_m)} +#define GEMM_KEY_PAD_N {int(tr.pad_n)} +#define GEMM_KEY_PAD_K {int(tr.pad_k)} +#define GEMM_KEY_PERSISTENT {int(tr.persistent)} +#define GEMM_KEY_DOUBLE_BUFFER {int(tr.pipeline == "compv4" or tr.pipeline == "preshufflev2")} +#define GEMM_KEY_PRESHUFFLE {int(config.preshuffle)} +#define GEMM_KEY_TRANSPOSE_C 0 +#define GEMM_KEY_GROUPED 0 +#define GEMM_KEY_SPLIT_K 1 #endif // CK_TILE_SINGLE_KERNEL_INCLUDE """ @@ -743,7 +786,7 @@ def _epilogue_code(self, config: KernelConfig) -> str: tuple<>, CLayout, element_wise::PassThrough, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, WarpPerBlock_M, WarpPerBlock_N, WarpTileM, WarpTileN, WarpTileK, - TransposeC, NumWaveGroups, false, 1, 1, DoubleSmemBuffer>; + TransposeC, NumWaveGroups>; using GemmEpilogue = CShuffleEpilogue;""" else: return """ @@ -774,6 +817,7 @@ def generate( """Generate dispatcher wrapper""" kernel_name = KernelNaming.generate(config, self.datatype, self.layout) output_dtype = self.tm.get_output_dtype(self.datatype) + acc_dtype = self.tm.get_acc_dtype(self.datatype) rel_path = kernel_path.relative_to(output_dir) return f"""// SPDX-License-Identifier: MIT @@ -808,7 +852,7 @@ def generate( key.signature.dtype_a = {self.tm.DTYPE_TO_DISPATCHER[self.datatype]}; key.signature.dtype_b = {self.tm.DTYPE_TO_DISPATCHER[self.datatype]}; key.signature.dtype_c = {self.tm.DTYPE_TO_DISPATCHER[output_dtype]}; - key.signature.dtype_acc = DataType::FP32; + key.signature.dtype_acc = {self.tm.DTYPE_TO_DISPATCHER[acc_dtype]}; key.signature.layout_a = {self.tm.LAYOUT_TO_DISPATCHER[self.layout[0]]}; key.signature.layout_b = {self.tm.LAYOUT_TO_DISPATCHER[self.layout[1]]}; key.signature.layout_c = {self.tm.LAYOUT_TO_DISPATCHER[self.layout[2]]}; @@ -1030,9 +1074,15 @@ def _get_configs_for_variant(self, variant: GemmVariant) -> List[KernelConfig]: trait_configs = self._get_trait_configs() for tile, trait in itertools.product(tile_configs, trait_configs): - # Perform variant-specific architecture validation + # Perform variant-specific architecture validation against the + # trait's ACTUAL pipeline/scheduler (not a hard-coded compv4). if self.arch_filter and HAS_ARCH_FILTER: - if not self._is_tile_arch_valid(tile, variant): + if not self._is_tile_arch_valid( + tile, + variant, + pipeline=trait.pipeline, + scheduler=trait.scheduler, + ): continue if variant == GemmVariant.STANDARD: @@ -1105,9 +1155,37 @@ def _get_tile_configs(self) -> List[TileConfig]: rejected_count += 1 continue - # Architecture-specific validation + # CShuffle-store correctness gate. The CShuffle epilogue stores the + # accumulator back through LDS in power-of-two MRepeat/NRepeat chunks, + # so a tile whose per-wave repeat count -- tile / (warp * warp_tile) -- + # is not a power of two is mis-stored and yields numerically WRONG + # results at runtime. The kernel still compiles (the epilogue's + # static_asserts only check divisibility, which such tiles satisfy), + # so it must be filtered here. Observed on MI350 for tile_m=192 + # (MRepeat = 192 / (2*32) = 3): verified incorrect on BOTH the bridge + # and Tile Engine at every shape, including shapes divisible by 192. + # Power-of-two tiles (64/128/256) are unaffected. + m_repeat = tile.tile_m // (tile.warp_m * tile.warp_tile_m) + n_repeat = tile.tile_n // (tile.warp_n * tile.warp_tile_n) + if not (_is_power_of_two(m_repeat) and _is_power_of_two(n_repeat)): + rejected_count += 1 + continue + + # Architecture-specific validation. This is a pre-filter run before + # tiles are paired with traits, so keep a tile if it is legal under + # ANY configured pipeline/scheduler; the precise per-trait check + # happens later in _get_configs_for_variant. Filtering here with a + # single hard-coded pipeline (compv4) wrongly dropped tiles that are + # legal under mem/compv3. if self.arch_filter and HAS_ARCH_FILTER: - if not self._is_tile_arch_valid(tile): + trait_cfg = self.config.get("trait_config", {}) + pipelines = trait_cfg.get("pipeline") or ["compv4"] + schedulers = trait_cfg.get("scheduler") or ["intrawave"] + if not any( + self._is_tile_arch_valid(tile, pipeline=pl, scheduler=sc) + for pl in pipelines + for sc in schedulers + ): rejected_count += 1 continue @@ -1119,13 +1197,23 @@ def _get_tile_configs(self) -> List[TileConfig]: return configs def _is_tile_arch_valid( - self, tile: TileConfig, variant: GemmVariant = None + self, + tile: TileConfig, + variant: GemmVariant = None, + pipeline: str = None, + scheduler: str = None, ) -> bool: """Check if tile configuration is valid for target architecture Args: tile: Tile configuration to validate variant: GEMM variant (affects operator-specific constraints) + pipeline: Trait pipeline to validate against. Pass the config's + actual pipeline -- omitting it falls back to ``compv4``, whose + MFMA constraints are stricter than ``mem``/``compv3`` and would + wrongly reject tiles that are legal under those pipelines. + scheduler: Trait scheduler to validate against (defaults to + ``intrawave`` for the same reason). """ if not self.arch_filter or not HAS_ARCH_FILTER: return True @@ -1146,8 +1234,10 @@ def _is_tile_arch_valid( # Map GEMM variant to operator type for validation operator = None - pipeline = "compv4" # Default - scheduler = "intrawave" # Default + if pipeline is None: + pipeline = "compv4" # Default (representative compute pipeline) + if scheduler is None: + scheduler = "intrawave" # Default if OperatorType is not None and variant is not None: variant_to_operator = { diff --git a/projects/composablekernel/dispatcher/examples/gemm/python/12_te_bridge.py b/projects/composablekernel/dispatcher/examples/gemm/python/12_te_bridge.py new file mode 100644 index 000000000000..53a44060ef35 --- /dev/null +++ b/projects/composablekernel/dispatcher/examples/gemm/python/12_te_bridge.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 12: Tile Engine -> Dispatcher bridge (gallery) + +Unlike examples 01-11 (which drive the Dispatcher's native ctypes Registry), +this example exercises the *Tile Engine -> Dispatcher bridge* in +``dispatcher/python/gemm_utils.py``. The bridge is the path the Tile Engine +itself uses: one common ``GemmKernelConfig`` feeds codegen, force-include +compile, and a flat extern "C" ABI, and ``GpuGemmRunner`` runs the resulting +.so against a NumPy reference. + +It is a small gallery of three demos that together cover the surface the bridge +gained over its original fp16/rcr-only slice: + + matrix every (dtype, layout) pair the universal GEMM supports -- fp16 and + bf16 across the four row/col A/B combinations (row-major C only). + shapes why padding matters: a padded kernel accepts an awkward, non-tile- + aligned problem (M, N not divisible by the tile) while the equivalent + no-pad kernel rejects it -- the same selection rule the Tile Engine + sees when it sweeps pad on/off. + sweep the "search space" idea: one fixed signature, several *algorithms* + (tile / wave / pipeline), built and ranked by measured TFLOPS -- a + miniature of what the Tile Engine driver does at scale. + +(Kernel *variants* such as Stream-K and grouped GEMM ride the same bridge but +live on separate branches; this example stays within the regular GEMM stack.) + +Usage: + python3 12_te_bridge.py # runs all three demos + python3 12_te_bridge.py --demo matrix + python3 12_te_bridge.py --demo shapes + python3 12_te_bridge.py --demo sweep + python3 12_te_bridge.py --size 1024 --rtol 2e-2 --arch gfx950 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np # noqa: E402 + +from gemm_utils import ( # noqa: E402 + GemmKernelConfig, + GemmProblem, + GpuGemmRunner, + setup_multiple_gemm_dispatchers, +) +from ctypes_utils import detect_gpu_arch # noqa: E402 + +# A single algorithm known to compile and run on gfx942. Only the Signature +# (dtype + layout) varies in the matrix demo; the Algorithm is held fixed so the +# demo isolates the bridge's dtype/layout generality. +_ALGO = dict( + tile_m=64, tile_n=64, tile_k=64, + wave_m=4, wave_n=1, wave_k=1, + warp_tile_m=16, warp_tile_n=16, warp_tile_k=16, + pipeline="compv3", scheduler="intrawave", epilogue="cshuffle", + pad_m=False, pad_n=False, pad_k=False, +) + +# (dtype, layout) pairs. Column-major C (e.g. rcc) is rejected at build by the +# universal GEMM, so every case keeps row-major C -- which leaves exactly four +# A/B combinations (rcr/rrr/ccr/crr). Both dtypes cover all four. +_CASES = [ + ("fp16", "rcr"), ("fp16", "rrr"), ("fp16", "ccr"), ("fp16", "crr"), + ("bf16", "rcr"), ("bf16", "rrr"), ("bf16", "ccr"), ("bf16", "crr"), +] + +_LAYOUT_WORD = {"r": "row", "c": "col"} + + +def _emulate(x: np.ndarray, dtype: str) -> np.ndarray: + """Round fp32 inputs to the kernel's storage dtype so the CPU reference + matches what the GPU actually multiplies.""" + if dtype == "bf16": + u32 = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32) + rounded = (u32 + ((u32 >> 16) & 1) + np.uint32(0x7FFF)) >> 16 + return (rounded.astype(np.uint32) << 16).view(np.float32) + return x.astype(np.float16).astype(np.float32) + + +def _max_rel(out: np.ndarray, ref: np.ndarray) -> float: + # Global relative error (normalize by the largest reference magnitude): + # per-element ratios explode on the near-zero entries that K-length + # accumulation of zero-mean data produces, so they are not meaningful. + denom = float(np.max(np.abs(ref))) + 1e-12 + return float(np.max(np.abs(out - ref))) / denom + + +def _config(dtype: str, layout: str, arch: str, **algo) -> GemmKernelConfig: + la, lb, lc = layout + return GemmKernelConfig( + dtype_a=dtype, dtype_b=dtype, dtype_c=dtype, + layout_a=_LAYOUT_WORD[la], layout_b=_LAYOUT_WORD[lb], layout_c=_LAYOUT_WORD[lc], + gfx_arch=arch, **(algo or _ALGO), + ) + + +def _reference(A, B, dtype): + # Emulate both input quantization (A,B stored as dtype) and the output store + # (GPU writes C back as dtype_c), so round on both ends before comparing. + return _emulate(_emulate(A, dtype) @ _emulate(B, dtype), dtype) + + +# --------------------------------------------------------------------------- +# Demo 1: dtype x layout matrix +# --------------------------------------------------------------------------- +def demo_matrix(size, rtol, arch): + print(f"\n[matrix] dtype x layout, M=N=K={size}, rtol={rtol:g}") + problem = GemmProblem(M=size, N=size, K=size) + configs = [_config(dt, lay, arch) for dt, lay in _CASES] + so_paths = setup_multiple_gemm_dispatchers(configs, verbose=False) + + rng = np.random.default_rng(42) + A = (rng.standard_normal((problem.M, problem.K)) * 0.1).astype(np.float32) + B = (rng.standard_normal((problem.K, problem.N)) * 0.1).astype(np.float32) + + n_pass = 0 + for (dtype, layout), so in zip(_CASES, so_paths): + tag = f"{dtype}/{layout}" + if so is None: + print(f" {tag:10s} BUILD FAILED") + continue + result = GpuGemmRunner(lib_path=so).run(A, B, problem) + if not result.success: + print(f" {tag:10s} RUN FAILED (status {result.status})") + continue + mr = _max_rel(result.output, _reference(A, B, dtype)) + ok = mr <= rtol + n_pass += ok + print(f" {tag:10s} tflops={result.tflops:7.1f} max_rel={mr:.2e} " + f"{'PASS' if ok else 'FAIL'}") + print(f" -> {n_pass}/{len(_CASES)} passed") + return n_pass, len(_CASES) + + +# --------------------------------------------------------------------------- +# Demo 2: padding vs an awkward (non-tile-aligned) shape +# --------------------------------------------------------------------------- +def demo_shapes(rtol, arch): + print("\n[shapes] padding lets a kernel accept a non-tile-aligned problem") + # 128-tile kernels; awkward M, N do not divide 128 (K stays divisible by 8 + # for the fp16 vectorized reduction load). + algo_pad = dict( + tile_m=128, tile_n=128, tile_k=32, wave_m=2, wave_n=2, wave_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, + pipeline="compv4", scheduler="intrawave", epilogue="cshuffle", + pad_m=True, pad_n=True, pad_k=True, + ) + algo_nopad = dict(algo_pad, pad_m=False, pad_n=False, pad_k=False) + + cfg_pad = _config("fp16", "rcr", arch, **algo_pad) + cfg_nopad = _config("fp16", "rcr", arch, **algo_nopad) + so_pad, so_nopad = setup_multiple_gemm_dispatchers([cfg_pad, cfg_nopad], verbose=False) + + M, N, K = 257, 129, 512 # awkward: 257, 129 not divisible by 128 + problem = GemmProblem(M=M, N=N, K=K) + rng = np.random.default_rng(7) + A = (rng.standard_normal((M, K)) * 0.1).astype(np.float32) + B = (rng.standard_normal((K, N)) * 0.1).astype(np.float32) + ref = _reference(A, B, "fp16") + + print(f" awkward problem M={M} N={N} K={K} (M,N not divisible by tile 128)") + n_pass = 0 + expectations = [("padded", so_pad, True), ("no-pad", so_nopad, False)] + for label, so, should_pass in expectations: + if so is None: + print(f" {label:8s} BUILD FAILED") + continue + result = GpuGemmRunner(lib_path=so).run(A, B, problem) + if result.success: + mr = _max_rel(result.output, ref) + accepted = mr <= rtol + outcome = f"ACCEPTED tflops={result.tflops:7.1f} max_rel={mr:.2e}" + else: + accepted = False + # status -2 == select_kernel found no kernel whose tiling fits. + outcome = f"REJECTED (status {result.status})" + # "Correct" = the no-pad kernel rejects and the padded one accepts. + correct = accepted == should_pass + n_pass += correct + print(f" {label:8s} {outcome:42s} {'as expected' if correct else 'UNEXPECTED'}") + print(f" -> {n_pass}/2 behaved as expected (padded accepts, no-pad rejects)") + return n_pass, 2 + + +# --------------------------------------------------------------------------- +# Demo 3: algorithm sweep over one fixed signature +# --------------------------------------------------------------------------- +def demo_sweep(size, rtol, arch): + print(f"\n[sweep] fixed fp16/rcr signature, several algorithms, M=N=K={size}") + # A handful of distinct algorithms (the Tile Engine sweeps thousands of + # these). Each is a different tile / wave / pipeline point in the search + # space; padding is on so any size is accepted. + algos = [ + dict(tile_m=128, tile_n=128, tile_k=32, wave_m=2, wave_n=2, wave_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, pipeline="compv4"), + dict(tile_m=256, tile_n=128, tile_k=32, wave_m=2, wave_n=2, wave_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, pipeline="compv4"), + dict(tile_m=128, tile_n=128, tile_k=64, wave_m=2, wave_n=2, wave_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, pipeline="compv3"), + dict(tile_m=64, tile_n=64, tile_k=64, wave_m=4, wave_n=1, wave_k=1, + warp_tile_m=16, warp_tile_n=16, warp_tile_k=16, pipeline="compv3"), + ] + common = dict(scheduler="intrawave", epilogue="cshuffle", + pad_m=True, pad_n=True, pad_k=True) + configs = [_config("fp16", "rcr", arch, **dict(a, **common)) for a in algos] + so_paths = setup_multiple_gemm_dispatchers(configs, verbose=False) + + problem = GemmProblem(M=size, N=size, K=size) + rng = np.random.default_rng(123) + A = (rng.standard_normal((size, size)) * 0.1).astype(np.float32) + B = (rng.standard_normal((size, size)) * 0.1).astype(np.float32) + ref = _reference(A, B, "fp16") + + rows = [] + for cfg, so in zip(configs, so_paths): + label = f"{cfg.tile_m}x{cfg.tile_n}x{cfg.tile_k}/{cfg.pipeline}" + if so is None: + rows.append((label, None, None)) + continue + result = GpuGemmRunner(lib_path=so).run(A, B, problem) + if not result.success: + rows.append((label, None, None)) + continue + rows.append((label, result.tflops, _max_rel(result.output, ref))) + + ranked = sorted((r for r in rows if r[1] is not None), + key=lambda r: r[1], reverse=True) + print(f" {'rank':>4} {'algorithm':<24} {'tflops':>9} {'max_rel':>10}") + for i, (label, tflops, mr) in enumerate(ranked, 1): + print(f" {i:>4} {label:<24} {tflops:>9.1f} {mr:>10.2e}") + for label, tflops, _ in rows: + if tflops is None: + print(f" {'-':>4} {label:<24} {'BUILD/RUN FAILED':>20}") + if ranked: + print(f" -> fastest: {ranked[0][0]} at {ranked[0][1]:.1f} TFLOPS") + n_ok = sum(1 for _, _, mr in ranked if mr <= rtol) + return n_ok, len(configs) + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Tile Engine -> Dispatcher bridge example (gallery)", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--demo", choices=["matrix", "shapes", "sweep", "all"], + default="all", help="which demo to run (default: all)") + parser.add_argument("--size", type=int, default=512, help="M=N=K (default 512)") + parser.add_argument("--rtol", type=float, default=2e-2, + help="relative tolerance (default 2e-2)") + parser.add_argument("--arch", default=detect_gpu_arch(), + help="GPU target arch (default: auto-detected via rocminfo)") + args = parser.parse_args() + + demos = ["matrix", "shapes", "sweep"] if args.demo == "all" else [args.demo] + total_pass = 0 + total = 0 + for d in demos: + if d == "matrix": + p, t = demo_matrix(args.size, args.rtol, args.arch) + elif d == "shapes": + p, t = demo_shapes(args.rtol, args.arch) + else: + p, t = demo_sweep(args.size, args.rtol, args.arch) + total_pass += p + total += t + + print(f"\n{total_pass}/{total} checks passed across {len(demos)} demo(s)") + return 0 if total_pass == total else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp b/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp index be22d94b3331..79e619105a21 100644 --- a/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp +++ b/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp @@ -12,6 +12,8 @@ #include #include #include +#include +#include namespace ck_tile { namespace dispatcher { @@ -50,26 +52,46 @@ class GeneratedTileKernelInstance : public KernelInstance bool supports(const Problem& problem) const override { - // Check dimension divisibility if padding not enabled + // Tile-divisibility gate, mirroring ck_tile::GemmKernel::IsSupportedArgument + // exactly. A dimension only needs to be a multiple of its tile size when an + // operand whose contiguous (inner) axis is that dimension participates AND + // padding for it is disabled. This is layout-dependent: + // + // layout RowMajor A -> inner axis K | layout ColMajor A -> inner axis M + // layout RowMajor B -> inner axis N | layout ColMajor B -> inner axis K + // layout RowMajor C -> inner axis N | layout ColMajor C -> inner axis M + // + // The old check blindly required M % TileM == 0 for every layout, which + // wrongly rejected e.g. rcr kernels (RowMajor A & C never gate M) on + // M-indivisible problems that Old-TE runs fine. Anything this lets through + // is still validated by the kernel's own IsSupportedArgument inside launch(), + // so the bridge stays a strict functional equivalent of Old-TE. constexpr bool pad_m = SelectedKernel::kPadM; constexpr bool pad_n = SelectedKernel::kPadN; constexpr bool pad_k = SelectedKernel::kPadK; - if(pad_m && pad_n && pad_k) - { - return true; // Padding enabled - supports any size - } - - // Check divisibility constexpr int tile_m = SelectedKernel::TileM; constexpr int tile_n = SelectedKernel::TileN; constexpr int tile_k = SelectedKernel::TileK; - if(!pad_m && problem.M % tile_m != 0) + const auto is_row = [](LayoutTag l) { return l == LayoutTag::RowMajor; }; + const bool row_a = is_row(key_.signature.layout_a); + const bool row_b = is_row(key_.signature.layout_b); + const bool row_c = is_row(key_.signature.layout_c); + + // Which problem dimensions are actually constrained for this layout combo. + const bool require_m = (!row_a) || (!row_c); // ColMajor A or C gate M + const bool require_n = row_b || row_c; // RowMajor B or C gate N + const bool require_k = row_a || (!row_b); // RowMajor A or ColMajor B gate K + + const std::int64_t k_grain = + static_cast(tile_k) * (problem.k_batch > 0 ? problem.k_batch : 1); + + if(require_m && !pad_m && problem.M % tile_m != 0) return false; - if(!pad_n && problem.N % tile_n != 0) + if(require_n && !pad_n && problem.N % tile_n != 0) return false; - if(!pad_k && problem.K % tile_k != 0) + if(require_k && !pad_k && problem.K % k_grain != 0) return false; return true; @@ -106,11 +128,11 @@ class GeneratedTileKernelInstance : public KernelInstance stream_cfg.stream_id_ = reinterpret_cast(stream); stream_cfg.time_kernel_ = bench; stream_cfg.log_level_ = 0; - stream_cfg.cold_niters_ = bench ? 5 : 0; - stream_cfg.nrepeat_ = bench ? 10 : 1; + stream_cfg.cold_niters_ = bench ? env_int("CK_TILE_BENCH_WARMUP", 50) : 0; + stream_cfg.nrepeat_ = bench ? env_int("CK_TILE_BENCH_REPEAT", 100) : 1; stream_cfg.is_gpu_timer_ = bench; - stream_cfg.flush_cache_ = false; - stream_cfg.rotating_count_ = 1; + stream_cfg.flush_cache_ = bench && env_bool("CK_TILE_BENCH_FLUSH", true); + stream_cfg.rotating_count_ = bench ? env_int("CK_TILE_BENCH_ROTATING", 1000) : 1; // Call the generated kernel's launch method return SelectedKernel::launch(args, stream_cfg); @@ -134,6 +156,33 @@ class GeneratedTileKernelInstance : public KernelInstance } private: + // Read an integer benchmark knob from the environment, falling back to + // `fallback` when unset or unparseable. + static int env_int(const char* name, int fallback) + { + const char* v = std::getenv(name); + if(v == nullptr || *v == '\0') + return fallback; + char* end = nullptr; + const long out = std::strtol(v, &end, 10); + if(end == v) + return fallback; + return static_cast(out); + } + + // Read a boolean benchmark knob ("0"/"false"/"off", any case => false, else true). + static bool env_bool(const char* name, bool fallback) + { + const char* v = std::getenv(name); + if(v == nullptr || *v == '\0') + return fallback; + std::string s(v); + for(char& c : s) + if(c >= 'A' && c <= 'Z') + c = static_cast(c - 'A' + 'a'); + return !(s == "0" || s == "false" || s == "off"); + } + KernelKey key_; std::string name_; }; diff --git a/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/tile_backend.hpp b/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/tile_backend.hpp index a3a0b0468562..e709c00e153f 100644 --- a/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/tile_backend.hpp +++ b/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/tile_backend.hpp @@ -29,27 +29,37 @@ class TileKernelInstance : public KernelInstance bool supports(const Problem& problem) const override { - // Check dimension divisibility if padding not enabled + // Tile-divisibility gate, layout-aware to match + // ck_tile::GemmKernel::IsSupportedArgument (see generated_tile_backend.hpp + // for the full rationale). A dimension is only constrained when an operand + // whose inner axis is that dimension participates and its padding is off: + // RowMajor A->K, ColMajor A->M; RowMajor B->N, ColMajor B->K; + // RowMajor C->N, ColMajor C->M. constexpr bool pad_m = SelectedKernel::kPadM; constexpr bool pad_n = SelectedKernel::kPadN; constexpr bool pad_k = SelectedKernel::kPadK; - if(pad_m && pad_n && pad_k) - { - // Padding enabled - supports any size - return true; - } - - // Check divisibility constexpr int tile_m = SelectedKernel::TileM; constexpr int tile_n = SelectedKernel::TileN; constexpr int tile_k = SelectedKernel::TileK; - if(!pad_m && problem.M % tile_m != 0) + const auto is_row = [](LayoutTag l) { return l == LayoutTag::RowMajor; }; + const bool row_a = is_row(key_.signature.layout_a); + const bool row_b = is_row(key_.signature.layout_b); + const bool row_c = is_row(key_.signature.layout_c); + + const bool require_m = (!row_a) || (!row_c); + const bool require_n = row_b || row_c; + const bool require_k = row_a || (!row_b); + + const std::int64_t k_grain = + static_cast(tile_k) * (problem.k_batch > 0 ? problem.k_batch : 1); + + if(require_m && !pad_m && problem.M % tile_m != 0) return false; - if(!pad_n && problem.N % tile_n != 0) + if(require_n && !pad_n && problem.N % tile_n != 0) return false; - if(!pad_k && problem.K % tile_k != 0) + if(require_k && !pad_k && problem.K % k_grain != 0) return false; // Check shared memory budget if specified diff --git a/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp b/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp index 24b20ecd9b8f..24d67cbf437a 100644 --- a/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp +++ b/projects/composablekernel/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp @@ -369,6 +369,11 @@ inline Scheduler string_to_scheduler(const std::string& str) { if(str == "auto") return Scheduler::Auto; + // Preshuffle kernels emit "default"; the codegen maps it to Scheduler::Auto + // (see codegen_common.py SCHEDULER_TO_DISPATCHER), so mirror that here + // instead of silently falling through to Intrawave. + if(str == "default") + return Scheduler::Auto; if(str == "intrawave") return Scheduler::Intrawave; if(str == "interwave") diff --git a/projects/composablekernel/dispatcher/python/ctypes_utils.py b/projects/composablekernel/dispatcher/python/ctypes_utils.py index d719d1405e5d..cc94ede685c9 100644 --- a/projects/composablekernel/dispatcher/python/ctypes_utils.py +++ b/projects/composablekernel/dispatcher/python/ctypes_utils.py @@ -1073,7 +1073,7 @@ def _generate_single_kernel_subprocess(args: dict) -> Tuple[bool, Optional[str], "--config", config_file, "--variants", - "standard", + args.get("variant", "standard"), ] res = subprocess.run(cmd, capture_output=True, text=True, timeout=300) diff --git a/projects/composablekernel/dispatcher/python/gemm_utils.py b/projects/composablekernel/dispatcher/python/gemm_utils.py new file mode 100644 index 000000000000..ead6baf648da --- /dev/null +++ b/projects/composablekernel/dispatcher/python/gemm_utils.py @@ -0,0 +1,1021 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT +""" +GEMM Tile Engine <-> Dispatcher bridge. + +This is the GEMM counterpart of ``grouped_conv_utils.py`` / ``fmha_utils.py``: +a single shared config dataclass (``GemmKernelConfig``) that Tile Engine imports +and hands back to the dispatcher. There is no translator between two +vocabularies -- both sides share the one object whose ``.name`` mirrors the +kernel identifier baked into the generated kernel header. + +Public surface (mirrors the grouped_conv bridge): + + GemmKernelConfig -- the shared contract dataclass + .name -- registry/runtime lookup key (byte-exact) + .to_codegen_json() -- feeds unified_gemm_codegen.py + GemmProblem -- a single (M, N, K) problem + setup_multiple_gemm_dispatchers -- codegen + hipcc -> .so paths (NO GPU) + GemmDispatcherLib -- thin ctypes ABI wrapper + GpuGemmRunner -- GPU memory + run + time (from a .so path) + expand_sweep -- TE JSON sweep config -> [GemmKernelConfig] + +The heavy lifting for codegen and compilation is reused from ``ctypes_utils`` +so there is a single source of truth for how a kernel header is produced and +how it is compiled into a ``.so``. +""" + +from __future__ import annotations + +import ctypes +import functools +import itertools +import json +import multiprocessing +import os +import shutil +import subprocess +import tempfile +from concurrent.futures import ProcessPoolExecutor, as_completed +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +# Reuse the proven codegen/compile leaf helpers from the dispatcher's own +# python layer. gemm_utils is a thin bridge on top of these. +import ctypes_utils as _cu + + +# ============================================================================ +# Layout / dtype helpers +# ============================================================================ + +_LAYOUT_CHAR = {"row": "r", "col": "c", "r": "r", "c": "c"} +_LAYOUT_WORD = {"r": "row", "c": "col"} + + +def _cap(flag: bool) -> str: + """Reproduce Python ``str(bool).capitalize()`` -> 'True' / 'False'.""" + return "True" if flag else "False" + + +# ============================================================================ +# The shared contract: GemmKernelConfig +# ============================================================================ + + +@dataclass +class GemmKernelConfig: + """The common config struct shared by Tile Engine and the Dispatcher. + + Naming convention (the "warp/wave trap" lives here, in ONE place): + * ``wave_m/n/k`` -- warps per block (C++ ``wave_shape``; TE "warp"). + * ``warp_tile_m/n/k`` -- MFMA instruction shape (C++ ``warp_tile_shape``; + TE "warp_tile"). + """ + + # --- Signature: what operation is computed ----------------------------- + dtype_a: str = "fp16" + dtype_b: str = "fp16" + dtype_c: str = "fp16" + dtype_acc: str = "fp32" + layout_a: str = "row" + layout_b: str = "col" + layout_c: str = "row" + + # --- Algorithm: how it is implemented ---------------------------------- + tile_m: int = 128 + tile_n: int = 128 + tile_k: int = 32 + wave_m: int = 2 + wave_n: int = 2 + wave_k: int = 1 + warp_tile_m: int = 32 + warp_tile_n: int = 32 + warp_tile_k: int = 16 + + pipeline: str = "compv4" + scheduler: str = "intrawave" + epilogue: str = "cshuffle" + + pad_m: bool = True + pad_n: bool = True + pad_k: bool = True + persistent: bool = False + + gfx_arch: str = "gfx942" + variant: str = "standard" + + # ------------------------------------------------------------------ # + # Derived string fragments + # ------------------------------------------------------------------ # + @property + def layout(self) -> str: + """3-char layout string, e.g. 'rcr'.""" + return ( + _LAYOUT_CHAR[self.layout_a] + + _LAYOUT_CHAR[self.layout_b] + + _LAYOUT_CHAR[self.layout_c] + ) + + @property + def tile_str(self) -> str: + return f"{self.tile_m}x{self.tile_n}x{self.tile_k}" + + @property + def wave_str(self) -> str: + return f"{self.wave_m}x{self.wave_n}x{self.wave_k}" + + @property + def warp_tile_str(self) -> str: + return f"{self.warp_tile_m}x{self.warp_tile_n}x{self.warp_tile_k}" + + @property + def name(self) -> str: + """Registry / runtime lookup key. + + Reproduces, byte-for-byte, the ``KERNEL_NAME`` that + ``unified_gemm_codegen.py::KernelNaming.generate`` bakes into the + generated kernel header (and that the .so reports via + ``dispatcher_get_kernel_name``). This is the single thread tying + config -> codegen -> runtime together. + """ + name = ( + f"gemm_{self.dtype_a}_{self.layout}" + f"_{self.pipeline}_{self.epilogue}_{self.scheduler}" + f"_{_cap(self.pad_m)}_{_cap(self.pad_n)}_{_cap(self.pad_k)}" + f"_{_cap(self.persistent)}" + f"_{self.tile_str}_{self.wave_str}_{self.warp_tile_str}" + ) + if self.variant == "preshuffle": + name += "_preshuffle" + elif self.variant == "streamk": + name += "_streamk" + return name + + # ------------------------------------------------------------------ # + # Serialization + # ------------------------------------------------------------------ # + def to_codegen_json(self) -> Dict[str, Any]: + """Single-config JSON consumed by unified_gemm_codegen.py. + + Note the warp/wave mapping: the codegen calls the warps-per-block + triple ``warp_*`` and the MFMA triple ``warp_tile_*``. We translate + from dispatcher semantics here so the mapping cannot drift. + """ + return { + "tile_config": { + "tile_m": [self.tile_m], + "tile_n": [self.tile_n], + "tile_k": [self.tile_k], + # dispatcher wave_* -> codegen warp_* (warps per block) + "warp_m": [self.wave_m], + "warp_n": [self.wave_n], + "warp_k": [self.wave_k], + # dispatcher warp_tile_* -> codegen warp_tile_* (MFMA shape) + "warp_tile_m": [self.warp_tile_m], + "warp_tile_n": [self.warp_tile_n], + "warp_tile_k": [self.warp_tile_k], + }, + "trait_config": { + "pipeline": [self.pipeline], + "epilogue": [self.epilogue], + "scheduler": [self.scheduler], + "pad_m": [self.pad_m], + "pad_n": [self.pad_n], + "pad_k": [self.pad_k], + "persistent": [self.persistent], + }, + } + + def to_dict(self) -> Dict[str, Any]: + return { + "dtype_a": self.dtype_a, + "dtype_b": self.dtype_b, + "dtype_c": self.dtype_c, + "dtype_acc": self.dtype_acc, + "layout": self.layout, + "tile": [self.tile_m, self.tile_n, self.tile_k], + "wave": [self.wave_m, self.wave_n, self.wave_k], + "warp_tile": [self.warp_tile_m, self.warp_tile_n, self.warp_tile_k], + "pipeline": self.pipeline, + "scheduler": self.scheduler, + "epilogue": self.epilogue, + "pad": [self.pad_m, self.pad_n, self.pad_k], + "persistent": self.persistent, + "gfx_arch": self.gfx_arch, + "variant": self.variant, + "name": self.name, + } + + def to_ctypes_config(self) -> "_cu.KernelConfig": + """Convert to the ctypes_utils.KernelConfig used by the codegen/validate + helpers. ctypes_utils renames the MFMA triple ``warp_*`` (no _tile).""" + return _cu.KernelConfig( + dtype_a=self.dtype_a, + dtype_b=self.dtype_b, + dtype_c=self.dtype_c, + dtype_acc=self.dtype_acc, + layout_a=_LAYOUT_WORD[_LAYOUT_CHAR[self.layout_a]], + layout_b=_LAYOUT_WORD[_LAYOUT_CHAR[self.layout_b]], + layout_c=_LAYOUT_WORD[_LAYOUT_CHAR[self.layout_c]], + tile_m=self.tile_m, + tile_n=self.tile_n, + tile_k=self.tile_k, + wave_m=self.wave_m, + wave_n=self.wave_n, + wave_k=self.wave_k, + warp_m=self.warp_tile_m, + warp_n=self.warp_tile_n, + warp_k=self.warp_tile_k, + pipeline=self.pipeline, + scheduler=self.scheduler, + epilogue=self.epilogue, + pad_m=self.pad_m, + pad_n=self.pad_n, + pad_k=self.pad_k, + gfx_arch=self.gfx_arch, + variant=self.variant, + ) + + +# ============================================================================ +# Problem +# ============================================================================ + + +@dataclass +class GemmProblem: + """A single GEMM problem: C[MxN] = A[MxK] @ B[KxN].""" + + M: int + N: int + K: int + + @property + def flops(self) -> float: + return 2.0 * self.M * self.N * self.K + + def to_dict(self) -> Dict[str, int]: + return {"M": self.M, "N": self.N, "K": self.K} + + @classmethod + def from_dict(cls, d: Dict[str, int]) -> "GemmProblem": + return cls(M=int(d["M"]), N=int(d["N"]), K=int(d["K"])) + + +@dataclass +class GemmResult: + output: np.ndarray + time_ms: float + status: int + tflops: float + kernel_name: str + + @property + def success(self) -> bool: + return self.status == 0 + + +# ============================================================================ +# ctypes ABI wrapper +# ============================================================================ + + +class GemmDispatcherLib: + """Thin ctypes wrapper around a compiled GEMM dispatcher .so. + + Supports both the legacy single-kernel ABI (``dispatcher_get_kernel_name``) + and the multi-kernel ABI (``dispatcher_get_kernel_name_at(index, buf, n)``) + so one .so can report a whole batch and be selected by name. + """ + + def __init__(self, so_path: Path): + self._path = Path(so_path) + self._lib = ctypes.CDLL(str(self._path)) + self._has_indexed = hasattr(self._lib, "dispatcher_get_kernel_name_at") + self._setup_functions() + + def _setup_functions(self) -> None: + lib = self._lib + + lib.dispatcher_initialize.argtypes = [] + lib.dispatcher_initialize.restype = ctypes.c_int + + lib.dispatcher_get_kernel_count.argtypes = [] + lib.dispatcher_get_kernel_count.restype = ctypes.c_int + + lib.dispatcher_get_kernel_name.argtypes = [] + lib.dispatcher_get_kernel_name.restype = ctypes.c_char_p + + if self._has_indexed: + lib.dispatcher_get_kernel_name_at.argtypes = [ + ctypes.c_int, + ctypes.c_char_p, + ctypes.c_int, + ] + lib.dispatcher_get_kernel_name_at.restype = ctypes.c_int + + lib.dispatcher_run_gemm.argtypes = [ + ctypes.c_void_p, # A (host) + ctypes.c_void_p, # B (host) + ctypes.c_void_p, # C (host) + ctypes.c_int64, # M + ctypes.c_int64, # N + ctypes.c_int64, # K + ctypes.POINTER(ctypes.c_float), # time_ms + ] + lib.dispatcher_run_gemm.restype = ctypes.c_int + + lib.dispatcher_cleanup.argtypes = [] + lib.dispatcher_cleanup.restype = None + + @property + def path(self) -> Path: + return self._path + + def initialize(self) -> bool: + return self._lib.dispatcher_initialize() == 0 + + def get_kernel_count(self) -> int: + return int(self._lib.dispatcher_get_kernel_count()) + + @property + def kernel_names(self) -> List[str]: + """List every kernel the .so exposes, by index when available.""" + if self._has_indexed: + names: List[str] = [] + count = self.get_kernel_count() + buf = ctypes.create_string_buffer(256) + for i in range(count): + if self._lib.dispatcher_get_kernel_name_at(i, buf, 256) == 0: + names.append(buf.value.decode("utf-8")) + if names: + return names + # Legacy single-kernel fallback. + raw = self._lib.dispatcher_get_kernel_name() + return [raw.decode("utf-8")] if raw else [] + + def run( + self, A: np.ndarray, B: np.ndarray, C: np.ndarray, M: int, N: int, K: int + ) -> Tuple[int, float]: + time_ms = ctypes.c_float(0.0) + status = self._lib.dispatcher_run_gemm( + A.ctypes.data_as(ctypes.c_void_p), + B.ctypes.data_as(ctypes.c_void_p), + C.ctypes.data_as(ctypes.c_void_p), + M, + N, + K, + ctypes.byref(time_ms), + ) + return status, time_ms.value + + def cleanup(self) -> None: + self._lib.dispatcher_cleanup() + + +# ============================================================================ +# GPU runner (constructed from a .so path; loaded only inside a worker) +# ============================================================================ + + +def _fp32_to_bf16_u16(x: np.ndarray) -> np.ndarray: + """Encode fp32 -> bfloat16 bit pattern in a uint16 array (round-to-nearest-even). + + numpy has no native bf16, but the C ABI only cares about the 2-byte memory + layout (sizeof(bf16_t) == 2 == sizeof(uint16)). Truncating the low 16 bits of + the fp32 representation with round-to-nearest-even matches ck_tile's bf16. + """ + u32 = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32) + # round-to-nearest-even: add (lsb-of-kept-bits + 0x7FFF) before truncating + rounding = ((u32 >> 16) & 1) + np.uint32(0x7FFF) + return ((u32 + rounding) >> 16).astype(np.uint16) + + +def _bf16_u16_to_fp32(u16: np.ndarray) -> np.ndarray: + """Decode a uint16 bf16 bit pattern back to fp32 (low 16 mantissa bits zero).""" + return (u16.astype(np.uint32) << 16).view(np.float32) + + +# --------------------------------------------------------------------------- +# fp8 (E4M3) / bf8 (E5M2) -- FNUZ ("NANOO") encoding used by gfx942/MI300. +# +# numpy has no native 8-bit float, and the C ABI only cares about the 1-byte +# memory layout (sizeof(fp8_t) == sizeof(bf8_t) == 1). We carry the value as a +# uint8 bit pattern. As with bf16, the DECODE is the load-bearing half: it must +# return the exact value the device's fp8_t/bf8_t represents for a byte, so the +# NumPy reference multiplies bit-for-bit what the GPU multiplies. The ENCODE only +# needs to land on the nearest representable byte. +# +# FNUZ format (gfx942): bias = 2^(exp_bits-1); the all-1s exponent is a normal +# number (no Inf), the sole NaN is the sign=1/exp=0/mant=0 byte (0x80), and there +# is no negative zero. gfx950/MI350 uses the OCP fp8 format instead; this codec +# targets the gfx942 default and the OCP path needs separate handling. +# --------------------------------------------------------------------------- + + +def _fnuz_decode_table(exp_bits: int, mant_bits: int) -> np.ndarray: + """Build the 256-entry byte -> fp32 value table for an 8-bit FNUZ float.""" + bias = (1 << (exp_bits - 1)) + mant_max = 1 << mant_bits + sign_shift = exp_bits + mant_bits + exp_mask = (1 << exp_bits) - 1 + table = np.zeros(256, dtype=np.float32) + for b in range(256): + sign = (b >> sign_shift) & 1 + exp = (b >> mant_bits) & exp_mask + mant = b & (mant_max - 1) + if exp == 0 and mant == 0: + # +0 (0x00); the negative-zero slot (0x80) is the lone NaN. + table[b] = np.float32(np.nan) if sign else np.float32(0.0) + continue + if exp == 0: + val = (mant / mant_max) * (2.0 ** (1 - bias)) # subnormal + else: + val = (1.0 + mant / mant_max) * (2.0 ** (exp - bias)) # normal + table[b] = np.float32(-val if sign else val) + return table + + +def _fnuz_encode(x: np.ndarray, exp_bits: int, mant_bits: int) -> np.ndarray: + """Encode fp32 -> nearest 8-bit FNUZ float, returned as a uint8 bit pattern.""" + table = _fnuz_decode_table(exp_bits, mant_bits) + sign_byte = np.uint8(1 << (exp_bits + mant_bits)) # 0x80 + + # Positive half (bytes 0..127) holds every non-negative magnitude, sorted. + # Compare in float64: for very large inputs the gap between the two top + # magnitudes is below fp32 resolution, which would tie and mis-saturate. + pos_mag = table[: int(sign_byte)].astype(np.float64) + order = np.argsort(pos_mag) + sorted_mag = pos_mag[order] + sorted_byte = order.astype(np.uint8) + + xf = np.ascontiguousarray(x, dtype=np.float32) + ax = np.abs(xf).astype(np.float64) + # Both neighbours come from the raw insertion point: raw==size saturates to + # the top magnitude (lo==hi), raw==0 pins to zero, otherwise compare the two. + raw = np.searchsorted(sorted_mag, ax) + hi = np.clip(raw, 0, sorted_mag.size - 1) + lo = np.clip(raw - 1, 0, sorted_mag.size - 1) + pick_lo = np.abs(sorted_mag[lo] - ax) <= np.abs(sorted_mag[hi] - ax) + chosen = np.where(pick_lo, lo, hi) + out = sorted_byte[chosen] + + # Apply sign, but never the 0x80 (-0 == NaN) slot: zeros stay +0. + is_zero = sorted_mag[chosen] == 0 + out = np.where((xf < 0) & ~is_zero, out | sign_byte, out) + out = np.where(np.isnan(xf), sign_byte, out) # NaN inputs -> NaN byte + return out.astype(np.uint8).reshape(np.shape(x)) + + +def _fp32_to_fp8_u8(x: np.ndarray) -> np.ndarray: + """Encode fp32 -> fp8 E4M3 (FNUZ) bit pattern in a uint8 array.""" + return _fnuz_encode(x, exp_bits=4, mant_bits=3) + + +def _fp8_u8_to_fp32(u8: np.ndarray) -> np.ndarray: + """Decode an fp8 E4M3 (FNUZ) bit pattern back to fp32.""" + return _fnuz_decode_table(4, 3)[u8.astype(np.intp)] + + +def _fp32_to_bf8_u8(x: np.ndarray) -> np.ndarray: + """Encode fp32 -> bf8 E5M2 (FNUZ) bit pattern in a uint8 array.""" + return _fnuz_encode(x, exp_bits=5, mant_bits=2) + + +def _bf8_u8_to_fp32(u8: np.ndarray) -> np.ndarray: + """Decode a bf8 E5M2 (FNUZ) bit pattern back to fp32.""" + return _fnuz_decode_table(5, 2)[u8.astype(np.intp)] + + +# Output (C) element dtype for an A/B element dtype, mirroring the codegen's +# CommonTypeMappings.get_output_dtype: fp8/bf8 accumulate into fp16, int8 into +# int32, everything else stores in its own dtype. +_OUTPUT_DTYPE = {"fp8": "fp16", "bf8": "fp16", "int8": "int32"} + + +def _output_dtype(dtype: str) -> str: + return _OUTPUT_DTYPE.get(dtype, dtype) + + +def _dtype_from_kernel_name(name: str) -> str: + """Extract the dtype token from a kernel name like ``gemm___...``.""" + parts = name.split("_") + return parts[1] if len(parts) > 1 else "fp16" + + +def _layout_from_kernel_name(name: str) -> str: + """Extract the 3-char layout token (e.g. 'rcr') from a kernel name. + + Name format is ``gemm___...``; each char is 'r' (row-major) + or 'c' (column-major) for operands A, B, C respectively. + """ + parts = name.split("_") + if len(parts) > 2 and len(parts[2]) == 3 and set(parts[2]) <= {"r", "c"}: + return parts[2] + return "rcr" + + +class GpuGemmRunner: + """High-level runner: construct from a .so path, call run(A, B, problem). + + The GEMM ctypes ABI takes HOST pointers and manages GPU memory internally + (hipMalloc/hipMemcpy/hipFree), so this runner stays simple -- it hands + numpy arrays straight to the .so. + """ + + def __init__(self, lib_path: Path): + self.lib = GemmDispatcherLib(lib_path) + if not self.lib.initialize(): + raise RuntimeError(f"Failed to initialize dispatcher .so: {lib_path}") + names = self.lib.kernel_names + self._kernel_name = names[0] if names else "unknown" + + @property + def kernel_name(self) -> str: + return self._kernel_name + + def run( + self, A: np.ndarray, B: np.ndarray, problem: GemmProblem + ) -> GemmResult: + M, N, K = problem.M, problem.N, problem.K + + # Caller passes logical A (MxK) and B (KxN) row-major. The compiled + # kernel dictates both the element dtype and the memory layout of each + # operand (encoded in its name, e.g. gemm_bf16_rcr_...). The C ABI sizes + # its device buffers from sizeof(ADataType) and the kernel computes + # strides from its compiled layout + M,N,K -- so the host buffers must + # be laid out byte-for-byte in the order the kernel expects. + # + # For a 'c' (column-major) operand we transpose so the contiguous host + # buffer's flat memory matches column-major order: + # col-major A (MxK) <=> ascontiguousarray(A.T) (KxM row-major) + # Likewise column-major C (MxN) lands in memory as NxM row-major, so we + # allocate (N,M) and transpose the result back to logical MxN. + dtype = _dtype_from_kernel_name(self._kernel_name) + la, lb, lc = _layout_from_kernel_name(self._kernel_name) + + A_lay = A if la == "r" else A.T + B_lay = B if lb == "r" else B.T + C_shape = (M, N) if lc == "r" else (N, M) + + # Build A/B host buffers in the kernel's element dtype. The encode + # helpers (bf16/fp8/bf8) already force a contiguous float32 source, so an + # outer ascontiguousarray would only add a redundant copy; the native + # numpy dtypes (fp16/int8) still need it. + if dtype == "bf16": + A_h = _fp32_to_bf16_u16(A_lay) + B_h = _fp32_to_bf16_u16(B_lay) + elif dtype == "fp8": + A_h = _fp32_to_fp8_u8(A_lay) + B_h = _fp32_to_fp8_u8(B_lay) + elif dtype == "bf8": + A_h = _fp32_to_bf8_u8(A_lay) + B_h = _fp32_to_bf8_u8(B_lay) + elif dtype == "int8": + A_h = np.ascontiguousarray(A_lay, dtype=np.int8) + B_h = np.ascontiguousarray(B_lay, dtype=np.int8) + else: # fp16 (default) + A_h = np.ascontiguousarray(A_lay, dtype=np.float16) + B_h = np.ascontiguousarray(B_lay, dtype=np.float16) + + # The C buffer's element size must equal sizeof(CDataType): fp8/bf8 + # accumulate into fp16, int8 into int32, otherwise the input dtype. + out_dtype = _output_dtype(dtype) + _C_NP = {"fp16": np.float16, "bf16": np.uint16, "int32": np.int32} + C_h = np.zeros(C_shape, dtype=_C_NP.get(out_dtype, np.float16)) + + status, time_ms = self.lib.run(A_h, B_h, C_h, M, N, K) + + # Decode the output back to a comparable numeric array. + if out_dtype == "bf16": + C_dec = _bf16_u16_to_fp32(C_h) + else: # fp16 / int32 are already directly comparable + C_dec = C_h + C_out = C_dec if lc == "r" else C_dec.T + + tflops = (problem.flops / (time_ms * 1e-3)) / 1e12 if time_ms > 0 else 0.0 + return GemmResult( + output=C_out, + time_ms=time_ms, + status=status, + tflops=tflops, + kernel_name=self._kernel_name, + ) + + +# ============================================================================ +# Build API: codegen + hipcc -> .so paths (no GPU) +# ============================================================================ + +# AMDGPU codegen flags Tile Engine passes to hipcc for GEMM kernels. These MUST +# match, flag-for-flag, the set the Tile Engine gemm_universal benchmark TU is +# compiled with (projects/composablekernel/CMakeLists.txt) -- they steer inlining +# and register allocation, and because persistent kernels size their grid by +# occupancy, any mismatch produces large perf gaps vs Tile Engine and makes the +# parity comparison no longer apples-to-apples. +# +# Tile Engine's actual GEMM benchmark flags (verified from its compile_commands): +# -fno-offload-uniform-block +# -mllvm --lsr-drop-solution=1 +# -mllvm -enable-post-misched=0 +# -mllvm -amdgpu-early-inline-all=true +# -mllvm -amdgpu-function-calls=false +# -mllvm -amdgpu-coerce-illegal-types=1 (CMake adds this only when the +# compiler accepts it; see below) +# NOTE: -enable-noalias-to-md-conversion=0 is NOT a Tile Engine GEMM flag (it only +# appears in the standalone CK examples/tests), so it deliberately is NOT here. +_TILE_ENGINE_CODEGEN_FLAGS = ( + "-mllvm", "--lsr-drop-solution=1", + "-mllvm", "-enable-post-misched=0", + "-mllvm", "-amdgpu-early-inline-all=true", + "-mllvm", "-amdgpu-function-calls=false", + "-fno-offload-uniform-block", +) + +# Flags Tile Engine's CMake only adds when ``check_cxx_compiler_flag`` passes +# (newer -mllvm options that some clang builds reject). We mirror that probe so +# the bridge stays matched to Tile Engine on every toolchain: the flag is present +# exactly where TE would have it, and absent where TE's CMake would also skip it. +_PROBED_CODEGEN_FLAGS = ( + ("-mllvm", "-amdgpu-coerce-illegal-types=1"), +) + + +@functools.lru_cache(maxsize=None) +def _hipcc_accepts(flag_tuple: Tuple[str, ...]) -> bool: + """Mirror CMake check_cxx_compiler_flag: does hipcc compile a trivial TU with + these flags? Cached so the probe runs at most once per distinct flag set.""" + hipcc = os.environ.get("HIPCC") or shutil.which("hipcc") or "/opt/rocm/bin/hipcc" + try: + with tempfile.TemporaryDirectory() as d: + src = Path(d) / "probe.cpp" + src.write_text("int main(){}\n") + r = subprocess.run( + [hipcc, *flag_tuple, "-c", str(src), "-o", str(Path(d) / "probe.o")], + capture_output=True, timeout=120, + ) + return r.returncode == 0 + except Exception: + return False + + +@functools.lru_cache(maxsize=1) +def _tile_engine_codegen_flags() -> Tuple[str, ...]: + """Tile Engine's GEMM codegen flags plus any probe-gated flags the compiler + accepts -- the exact backend flag set the TE benchmark is built with.""" + flags = list(_TILE_ENGINE_CODEGEN_FLAGS) + for pair in _PROBED_CODEGEN_FLAGS: + if _hipcc_accepts(pair): + flags = list(pair) + flags + return tuple(flags) + + +def _build_compile_jobs( + config: GemmKernelConfig, header: Path +) -> Tuple[Dict[str, Any], Path]: + """Replicate the (validated) compile+link commands from ctypes_utils.""" + root = _cu.get_dispatcher_root() + ck_root = root.parent + build_dir = _cu.get_build_dir() + output_dir = _cu.get_generated_kernels_dir() + ctypes_source = root / "bindings" / "ctypes" / "gemm_ctypes_lib.cpp" + static_lib = build_dir / "libck_tile_dispatcher.a" + + lib_path = build_dir / "examples" / f"lib{config.name}.so" + obj_file = lib_path.with_suffix(".o") + + compile_cmd = [ + "/opt/rocm/bin/hipcc", + "-c", + "-fPIC", + "-O3", + f"-I{root / 'include'}", + f"-I{ck_root / 'include'}", + f"-I{ck_root}", + f"-I{str(output_dir)}", + "-DCK_TILE_SINGLE_KERNEL_INCLUDE", + f"-include{header}", + "-D__HIP_PLATFORM_AMD__", + f"--offload-arch={config.gfx_arch}", + f'-DGFX_ARCH="{config.gfx_arch}"', + # Match Tile Engine's AMDGPU codegen flags exactly (see + # _tile_engine_codegen_flags). Without them the kernel is compiled with + # different inlining/register allocation, which changes occupancy; + # persistent kernels size their grid by occupancy + # (UniversalGemmKernel::MaxOccupancyGridSize = #CUs x occupancy), so a + # mismatch shows up as large perf gaps vs Tile Engine on persistent tiles. + *_tile_engine_codegen_flags(), + "-Wno-undefined-func-template", + "-Wno-float-equal", + str(ctypes_source), + "-o", + str(obj_file), + ] + link_cmd = [ + "/opt/rocm/bin/hipcc", + "-shared", + "-fPIC", + f"--offload-arch={config.gfx_arch}", + "--hip-link", + str(obj_file), + str(static_lib), + "-o", + str(lib_path), + ] + job = {"compile_cmd": compile_cmd, "link_cmd": link_cmd, "lib_path": str(lib_path)} + return job, lib_path + + +def setup_multiple_gemm_dispatchers( + configs: List[GemmKernelConfig], + verbose: bool = True, + max_workers: Optional[int] = None, +) -> List[Optional[Path]]: + """Codegen + compile each config into its own .so. Returns .so paths. + + This is the build half of the bridge. It touches NO GPU -- pure CPU + codegen + hipcc, run massively in parallel -- and returns only ``Path`` + objects (``None`` for configs that failed to generate/compile), aligned to + the input order. Benchmarking happens later, in an isolated worker. + """ + import sys + + n = len(configs) + results: List[Optional[Path]] = [None] * n + if n == 0: + return results + + max_workers = max_workers or min(multiprocessing.cpu_count(), 8) + + # Dedupe identical configs by name; compile once, share the path. + first_index: Dict[str, int] = {} + unique: List[int] = [] + for i, c in enumerate(configs): + key = c.name + if key not in first_index: + first_index[key] = i + unique.append(i) + + codegen_script = _cu.get_codegen_path() + output_dir = _cu.get_generated_kernels_dir() + static_lib = _cu.get_build_dir() / "libck_tile_dispatcher.a" + ctypes_source = ( + _cu.get_dispatcher_root() / "bindings" / "ctypes" / "gemm_ctypes_lib.cpp" + ) + if not static_lib.exists() or not ctypes_source.exists(): + raise FileNotFoundError( + "Missing static lib or ctypes source required for compilation:\n" + f" {static_lib}\n {ctypes_source}\n" + "Build the dispatcher first (cmake + make)." + ) + + # -- Step 1: parallel codegen (one header per unique config) ----------- + codegen_args = [] + for i in unique: + c = configs[i] + codegen_args.append( + { + "index": i, + "python": sys.executable, + "codegen_script": str(codegen_script), + "output_dir": str(output_dir), + "dtype": c.dtype_a, + "layout": c.layout, + "gpu_target": c.gfx_arch, + "tile_config_json": c.to_codegen_json(), + "hpp_glob_pattern": f"{c.name}.hpp", + # Honor the config's variant so non-standard kernels are codegen'd + # as themselves; the kernel name (and thus hpp_glob_pattern) already + # carries the variant suffix, so a missing/standard value here would + # produce a header whose name never matches the requested pattern. + "variant": c.variant, + } + ) + + if verbose: + print( + f"[gemm-bridge] codegen: {len(codegen_args)} headers " + f"(workers={max_workers})..." + ) + + headers: Dict[int, Path] = {} + with ProcessPoolExecutor(max_workers=max_workers) as ex: + futs = { + ex.submit(_cu._generate_single_kernel_subprocess, a): a["index"] + for a in codegen_args + } + for fut in as_completed(futs): + i = futs[fut] + ok, hdr, err = fut.result() + if ok and hdr: + headers[i] = Path(hdr) + if verbose: + print(f" OK codegen [{i}] {configs[i].name}") + elif verbose: + print(f" FAIL codegen [{i}] {configs[i].name}: {err}") + + # -- Step 2: parallel compile + link ----------------------------------- + compile_jobs = [] + job_index: List[int] = [] + for i in unique: + hdr = headers.get(i) + if hdr is None: + continue + job, _ = _build_compile_jobs(configs[i], hdr) + compile_jobs.append(job) + job_index.append(i) + + if verbose and compile_jobs: + print( + f"[gemm-bridge] compile: {len(compile_jobs)} .so " + f"(workers={max_workers})..." + ) + + with ProcessPoolExecutor(max_workers=max_workers) as ex: + futs = { + ex.submit(_cu._run_hipcc_subprocess, job): job_index[j] + for j, job in enumerate(compile_jobs) + } + for fut in as_completed(futs): + i = futs[fut] + ok, lp, err = fut.result() + if ok and lp: + results[i] = Path(lp) + if verbose: + print(f" OK compile [{i}] {Path(lp).name}") + elif verbose: + print(f" FAIL compile [{i}] {configs[i].name}: {err}") + + # -- Fan the deduped result back out to every input index -------------- + for i, c in enumerate(configs): + if results[i] is None: + results[i] = results[first_index[c.name]] + + if verbose: + ok_count = sum(1 for r in results if r is not None) + print(f"[gemm-bridge] setup complete: {ok_count}/{n} configs -> .so") + + return results + + +# ============================================================================ +# TE sweep config expansion +# ============================================================================ + + +def _expand_range(entry: Dict[str, Any]) -> List[int]: + """Expand a tile_config entry: either {min,max,step} or {values:[...]}.""" + if "values" in entry: + return list(entry["values"]) + lo = int(entry["min"]) + hi = int(entry["max"]) + step = int(entry.get("step", 1)) + return list(range(lo, hi + 1, step)) + + +def _expand_values(entry: Optional[Dict[str, Any]], default: List[Any]) -> List[Any]: + if entry is None: + return list(default) + return list(entry.get("values", default)) + + +def _is_power_of_two(x: int) -> bool: + return x > 0 and (x & (x - 1)) == 0 + + +def expand_sweep( + config_path: str, + arch: str, + dtype: str = "fp16", + layout: str = "rcr", +) -> List[GemmKernelConfig]: + """Expand a Tile Engine GEMM JSON sweep config into GemmKernelConfig list. + + The TE config uses ``tile_config`` (ranges/value-lists for tile, warp and + warp_tile triples) and ``trait_config`` (value-lists for pipeline, + scheduler, epilogue, pad_*, persistent). Every valid combination becomes + one GemmKernelConfig. Invalid combinations are dropped via the dispatcher's + own validator, and duplicates (by .name) are collapsed. + + The signature is controlled by the `dtype` and `layout` arguments (defaults + to fp16 / rcr). + """ + with open(config_path) as f: + cfg = json.load(f) + + tc = cfg.get("tile_config", {}) + tr = cfg.get("trait_config", {}) + + tile_ms = _expand_range(tc["tile_m"]) + tile_ns = _expand_range(tc["tile_n"]) + tile_ks = _expand_range(tc["tile_k"]) + wave_ms = _expand_range(tc["warp_m"]) # TE "warp" == wave count + wave_ns = _expand_range(tc["warp_n"]) + wave_ks = _expand_range(tc["warp_k"]) + wt_ms = _expand_range(tc["warp_tile_m"]) + wt_ns = _expand_range(tc["warp_tile_n"]) + wt_ks = _expand_range(tc["warp_tile_k"]) + + pipelines = _expand_values(tr.get("pipeline"), ["compv3"]) + schedulers = _expand_values(tr.get("scheduler"), ["intrawave"]) + epilogues = _expand_values(tr.get("epilogue"), ["cshuffle"]) + pad_ms = _expand_values(tr.get("pad_m"), [False]) + pad_ns = _expand_values(tr.get("pad_n"), [False]) + pad_ks = _expand_values(tr.get("pad_k"), [False]) + persistents = _expand_values(tr.get("persistent"), [False]) + + la, lb, lc = layout[0], layout[1], layout[2] + + configs: List[GemmKernelConfig] = [] + seen: set = set() + for ( + tm, + tn, + tk, + wm, + wn, + wk, + wtm, + wtn, + wtk, + pipe, + sched, + epi, + pm, + pn, + pk, + persist, + ) in itertools.product( + tile_ms, + tile_ns, + tile_ks, + wave_ms, + wave_ns, + wave_ks, + wt_ms, + wt_ns, + wt_ks, + pipelines, + schedulers, + epilogues, + pad_ms, + pad_ns, + pad_ks, + persistents, + ): + c = GemmKernelConfig( + dtype_a=dtype, + dtype_b=dtype, + dtype_c=_output_dtype(dtype), + dtype_acc=("int32" if dtype == "int8" else "fp32"), + layout_a=_LAYOUT_WORD[la], + layout_b=_LAYOUT_WORD[lb], + layout_c=_LAYOUT_WORD[lc], + tile_m=tm, + tile_n=tn, + tile_k=tk, + wave_m=wm, + wave_n=wn, + wave_k=wk, + warp_tile_m=wtm, + warp_tile_n=wtn, + warp_tile_k=wtk, + pipeline=pipe, + scheduler=sched, + epilogue=epi, + pad_m=bool(pm), + pad_n=bool(pn), + pad_k=bool(pk), + persistent=bool(persist), + gfx_arch=arch, + ) + if c.name in seen: + continue + val = _cu.validate_kernel_config(c.to_ctypes_config()) + if not val.is_valid: + continue + # Tile/CShuffle correctness gate (mirrors unified_gemm_codegen's + # TileConfig.is_valid + the power-of-two repeat rule; the ctypes + # validate_kernel_config above does NOT enforce either). A block tile must + # split evenly across its waves -- tile % (wave * warp_tile) == 0 -- and + # the CShuffle epilogue stores the accumulator through LDS in power-of-two + # MRepeat/NRepeat chunks, so the per-wave repeat must be a power of two. + # Tiles that violate either still compile but produce numerically WRONG + # results at runtime. Observed on MI350 for tile_m=192 (MRepeat=3) and + # tile_n=192 (e.g. 64x192x64_1x4x1, 192 not divisible by 4*32) -- both + # verified incorrect on the bridge and Tile Engine. Power-of-two tiles + # (64/128/256) are unaffected. + m_div = wm * wtm + n_div = wn * wtn + if m_div <= 0 or n_div <= 0 or tm % m_div != 0 or tn % n_div != 0: + continue + if not _is_power_of_two(tm // m_div) or not _is_power_of_two(tn // n_div): + continue + seen.add(c.name) + configs.append(c) + + return configs diff --git a/projects/composablekernel/dispatcher/tests/test_gemm_parity.py b/projects/composablekernel/dispatcher/tests/test_gemm_parity.py new file mode 100644 index 000000000000..b9d1cd1cd9ab --- /dev/null +++ b/projects/composablekernel/dispatcher/tests/test_gemm_parity.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +"""GEMM bridge parity regression: Dispatcher GPU output vs NumPy reference. + +This is the in-tree, reproducible version of the ad-hoc ``parity/`` sweep used to +validate the Tile Engine -> Dispatcher GEMM bridge. For each (dtype, layout) the +bridge supports it codegens + hipcc-compiles a kernel, runs it through +``GpuGemmRunner``, and compares the result to a NumPy reference across a square, a +rectangular, and an awkward (non-tile-aligned) problem shape. + +Parity is checked as a GLOBAL relative error -- ``max|gpu - ref| / max|ref|`` -- +not per-element: K-length accumulation of zero-mean inputs produces near-zero +entries whose per-element ratios explode and carry no signal. + +The whole suite is GPU-gated: it skips cleanly (not fails) when hipcc, the +dispatcher static lib, or a GPU is unavailable, so CPU-only CI stays green while +GPU runners get real end-to-end coverage. The pure host-side helpers are covered +separately and cheaply by ``test_gemm_utils.py``. + +Run: + python3 -m pytest tests/test_gemm_parity.py -v # discovery / CI + python3 tests/test_gemm_parity.py # readable table +""" + +import os +import sys +import shutil +import unittest +from pathlib import Path + +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "python")) + +import numpy as np # noqa: E402 + +from gemm_utils import ( # noqa: E402 + GemmKernelConfig, + GemmProblem, + GpuGemmRunner, + setup_multiple_gemm_dispatchers, + _fp32_to_bf16_u16, + _bf16_u16_to_fp32, + _fp32_to_fp8_u8, + _fp8_u8_to_fp32, + _fp32_to_bf8_u8, + _bf8_u8_to_fp32, + _output_dtype, +) +from ctypes_utils import detect_gpu_arch, get_build_dir # noqa: E402 + +# (dtype, layout) surface the regular bridge supports. Column-major C is rejected +# by ck_tile's universal GEMM at build, so every layout keeps row-major C, which +# leaves exactly the four A/B combinations below. Every dtype covers all four. +# +# fp16/bf16 are the PR #8479 surface; fp8 (E4M3), bf8 (E5M2) and int8 are the +# remaining dtypes TE's plain GEMM has MFMA warp tiles for (fp8/bf8 -> fp16 out, +# int8 -> int32 out). int8 only has warp tiles on gfx942; on other arches its +# kernels simply fail to build and the case skips (handled below). +_FLOAT_DTYPES = ("fp16", "bf16", "fp8", "bf8") +_INT_DTYPES = ("int8",) +_LAYOUTS = ("rcr", "rrr", "ccr", "crr") +_CASES = [ + (dt, lay) for dt in (*_FLOAT_DTYPES, *_INT_DTYPES) for lay in _LAYOUTS +] + +# Padded default algorithm: pad_* all True so M/N need not divide the tile, which +# is what lets the awkward shape below pass. K must still be a multiple of 8 for +# the fp16/bf16 vectorized contiguous-reduction load, so every K here is divisible +# by 8. +_ALGO = dict( + tile_m=128, tile_n=128, tile_k=32, + wave_m=2, wave_n=2, wave_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, + pipeline="compv4", scheduler="intrawave", epilogue="cshuffle", + pad_m=True, pad_n=True, pad_k=True, +) + +# (name, M, N, K). 'awkward' deliberately uses M, N that do not divide the 128 +# tile to exercise padding; K stays divisible by 8. +_SHAPES = [ + ("square", 512, 512, 512), + ("rectangular", 1024, 512, 256), + ("awkward", 257, 129, 512), +] + +# Global-relative-error gates. fp16 measured ~3-4e-4 and bf16 ~8e-3 on gfx942. +# fp8/bf8 are far coarser (3- and 2-bit mantissa) so their gates are looser; int8 +# is an exact integer accumulation so it must match bit-for-bit. The fp8/bf8 +# gates are first-cut headroom values and may want tightening once measured on a +# GPU. +_TOL = { + "fp16": 2e-3, + "bf16": 1.5e-2, + "fp8": 1.5e-1, + "bf8": 3.0e-1, + "int8": 0.0, +} + +_LAYOUT_WORD = {"r": "row", "c": "col"} + + +def _emulate_input(x: np.ndarray, dtype: str) -> np.ndarray: + """Round an fp32 operand to the kernel's storage dtype so the CPU reference + multiplies exactly what the GPU does. int8 inputs are already integral.""" + if dtype == "bf16": + return _bf16_u16_to_fp32(_fp32_to_bf16_u16(x)) + if dtype == "fp8": + return _fp8_u8_to_fp32(_fp32_to_fp8_u8(x)) + if dtype == "bf8": + return _bf8_u8_to_fp32(_fp32_to_bf8_u8(x)) + if dtype == "int8": + return x.astype(np.float64) # exact; widened to avoid product overflow + return x.astype(np.float16).astype(np.float32) + + +def _emulate_output(c: np.ndarray, out_dtype: str) -> np.ndarray: + """Round the fp32 accumulator to the kernel's C storage dtype.""" + if out_dtype == "bf16": + return _bf16_u16_to_fp32(_fp32_to_bf16_u16(c)) + if out_dtype == "int32": + return c # integer accumulation is exact + return c.astype(np.float16).astype(np.float32) # fp16 + + +def _make_inputs(dtype, M, N, K, rng): + """Random A (MxK), B (KxN) for a dtype: floats for the float dtypes, small + integers for int8 (kept small so the int32 accumulation cannot overflow).""" + if dtype == "int8": + A = rng.integers(-4, 5, size=(M, K)).astype(np.float32) + B = rng.integers(-4, 5, size=(K, N)).astype(np.float32) + return A, B + A = (rng.standard_normal((M, K)) * 0.1).astype(np.float32) + B = (rng.standard_normal((K, N)) * 0.1).astype(np.float32) + return A, B + + +def _reference(A, B, dtype): + """NumPy reference matching the kernel: round inputs to the storage dtype, + accumulate (fp32 for floats / exact int for int8), then round to C dtype.""" + out_dtype = _output_dtype(dtype) + acc = _emulate_input(A, dtype) @ _emulate_input(B, dtype) + ref = _emulate_output(acc, out_dtype) + return ref.astype(np.int32) if out_dtype == "int32" else ref + + +def _config(dtype: str, layout: str, arch: str) -> GemmKernelConfig: + la, lb, lc = layout + return GemmKernelConfig( + dtype_a=dtype, dtype_b=dtype, + dtype_c=_output_dtype(dtype), + dtype_acc=("int32" if dtype == "int8" else "fp32"), + layout_a=_LAYOUT_WORD[la], layout_b=_LAYOUT_WORD[lb], layout_c=_LAYOUT_WORD[lc], + gfx_arch=arch, **_ALGO, + ) + + +def _max_rel(out: np.ndarray, ref: np.ndarray) -> float: + denom = float(np.max(np.abs(ref))) + 1e-12 + return float(np.max(np.abs(out - ref))) / denom + + +def _gpu_environment_reason(): + """Return None if the bridge can build+run here, else a human-readable reason + to skip.""" + if not Path("/opt/rocm/bin/hipcc").exists(): + return "hipcc not found at /opt/rocm/bin/hipcc" + if not (get_build_dir() / "libck_tile_dispatcher.a").exists(): + return "dispatcher static lib (libck_tile_dispatcher.a) not built" + if shutil.which("rocminfo") is None: + return "rocminfo not found (no ROCm runtime / GPU)" + return None + + +class GemmBridgeParity(unittest.TestCase): + """End-to-end GPU-vs-NumPy parity across the bridge's dtype/layout surface.""" + + arch = None + built = {} # (dtype, layout) -> Path(.so) + build_failures = {} + + @classmethod + def setUpClass(cls): + reason = _gpu_environment_reason() + if reason: + raise unittest.SkipTest(reason) + cls.arch = detect_gpu_arch() + + configs = [_config(dt, lay, cls.arch) for dt, lay in _CASES] + so_paths = setup_multiple_gemm_dispatchers(configs, verbose=False) + for (dt, lay), so in zip(_CASES, so_paths): + if so is None: + cls.build_failures[(dt, lay)] = "codegen/hipcc returned no .so" + else: + cls.built[(dt, lay)] = so + + if not cls.built: + raise unittest.SkipTest( + f"no bridge kernels built on {cls.arch} " + f"(failures: {cls.build_failures})" + ) + + def _run_case(self, dtype, layout, shape): + so = self.built.get((dtype, layout)) + if so is None: + self.skipTest( + f"{dtype}/{layout} did not build on {self.arch}: " + f"{self.build_failures.get((dtype, layout))}" + ) + + _, M, N, K = shape + problem = GemmProblem(M=M, N=N, K=K) + rng = np.random.default_rng(42) + A, B = _make_inputs(dtype, M, N, K, rng) + + runner = GpuGemmRunner(lib_path=so) + # The .so is the contract endpoint: the name it reports must be the config + # name that drove codegen + the force-include build. The kernel name keys + # off the input dtype (dtype_a), not the C/acc dtype. + self.assertEqual(runner.kernel_name, _config(dtype, layout, self.arch).name) + + result = runner.run(A, B, problem) + self.assertTrue( + result.success, + f"{dtype}/{layout} {shape[0]} run failed (status {result.status})", + ) + + ref = _reference(A, B, dtype) + max_rel = _max_rel(result.output.astype(np.float64), ref.astype(np.float64)) + self.assertLessEqual( + max_rel, _TOL[dtype], + f"{dtype}/{layout} {shape[0]} max_rel={max_rel:.2e} > {_TOL[dtype]:.0e}", + ) + + +def _add_parity_tests(): + """Generate one test method per (case, shape) so failures pinpoint exactly + which dtype/layout/shape regressed.""" + for dtype, layout in _CASES: + for shape in _SHAPES: + shape_name = shape[0] + + def _method(self, dtype=dtype, layout=layout, shape=shape): + self._run_case(dtype, layout, shape) + + _method.__name__ = f"test_{dtype}_{layout}_{shape_name}" + _method.__doc__ = f"{dtype} {layout} {shape_name} {shape[1:]} parity" + setattr(GemmBridgeParity, _method.__name__, _method) + + +_add_parity_tests() + + +def _main() -> int: + """Readable table run (mirrors test_fmha_parity.py's report style).""" + reason = _gpu_environment_reason() + if reason: + print(f"SKIP: {reason}") + return 0 + + arch = detect_gpu_arch() + print("=" * 78) + print(f"GEMM Bridge Parity: Dispatcher (GPU {arch}) vs NumPy reference") + print("=" * 78) + + configs = [_config(dt, lay, arch) for dt, lay in _CASES] + print(f" Building {len(configs)} bridge kernels (codegen + hipcc)...") + so_paths = setup_multiple_gemm_dispatchers(configs, verbose=False) + + print(f"\n {'case':<12} {'shape':<12} {'tflops':>9} {'max_rel':>10} {'tol':>8} {'':>6}") + print(" " + "-" * 60) + + rng = np.random.default_rng(42) + total = 0 + passed = 0 + for (dtype, layout), so in zip(_CASES, so_paths): + tag = f"{dtype}/{layout}" + if so is None: + print(f" {tag:<12} {'-':<12} {'BUILD FAILED':>35}") + total += len(_SHAPES) + continue + runner = GpuGemmRunner(lib_path=so) + for sname, M, N, K in _SHAPES: + total += 1 + problem = GemmProblem(M=M, N=N, K=K) + A, B = _make_inputs(dtype, M, N, K, rng) + result = runner.run(A, B, problem) + if not result.success: + print(f" {tag:<12} {sname:<12} {'RUN FAILED':>9} status={result.status}") + continue + ref = _reference(A, B, dtype) + mr = _max_rel(result.output.astype(np.float64), ref.astype(np.float64)) + ok = mr <= _TOL[dtype] + passed += ok + print(f" {tag:<12} {sname:<12} {result.tflops:>9.1f} " + f"{mr:>10.2e} {_TOL[dtype]:>8.0e} {'PASS' if ok else 'FAIL':>6}") + + print("\n" + "=" * 78) + print(f" {passed}/{total} parity checks passed") + print("=" * 78) + return 0 if passed == total else 1 + + +if __name__ == "__main__": + # Default to the readable table; `-m pytest` / `unittest` use the generated + # test methods instead. + if os.environ.get("GEMM_PARITY_UNITTEST"): + unittest.main() + else: + sys.exit(_main()) diff --git a/projects/composablekernel/dispatcher/tests/test_gemm_utils.py b/projects/composablekernel/dispatcher/tests/test_gemm_utils.py new file mode 100644 index 000000000000..8ed188932e0b --- /dev/null +++ b/projects/composablekernel/dispatcher/tests/test_gemm_utils.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +"""CPU-only unit tests for python/gemm_utils.py. + +Locks in the bit-level helpers that the TE -> Dispatcher GEMM bridge relies on: + * bf16 <-> uint16 encoding (round-to-nearest-even), since numpy has no native + bf16 and the runner carries bf16 as a uint16 bit pattern. + * fp8 (E4M3) / bf8 (E5M2) FNUZ <-> uint8 encoding, used for the gfx942 8-bit + float surface. The decode must be exact to the device format; the encode + only needs to land on the nearest representable byte. + * dtype / layout parsing from the compiled kernel name, which drives how the + runner lays out host buffers. + +No GPU is touched -- all functions under test are pure host-side logic. +Run: python3 -m pytest tests/test_gemm_utils.py -v +""" + +import sys +import unittest +from pathlib import Path + +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "python")) + +import numpy as np # noqa: E402 + +from gemm_utils import ( # noqa: E402 + GemmKernelConfig, + _fp32_to_bf16_u16, + _bf16_u16_to_fp32, + _fp32_to_fp8_u8, + _fp8_u8_to_fp32, + _fp32_to_bf8_u8, + _bf8_u8_to_fp32, + _fnuz_decode_table, + _output_dtype, + _dtype_from_kernel_name, + _layout_from_kernel_name, +) + + +class TestBf16Encoding(unittest.TestCase): + """bf16 = top 16 bits of fp32 with round-to-nearest-even.""" + + def test_exactly_representable_roundtrip(self): + # Values whose low 16 fp32 mantissa bits are zero are exact in bf16. + exact = np.array([0.0, 1.0, -1.0, 2.0, 0.5, -0.5, 4.0, 256.0], + dtype=np.float32) + out = _bf16_u16_to_fp32(_fp32_to_bf16_u16(exact)) + np.testing.assert_array_equal(out, exact) + + def test_roundtrip_within_bf16_tolerance(self): + rng = np.random.default_rng(0) + x = (rng.standard_normal(10000) * 100.0).astype(np.float32) + out = _bf16_u16_to_fp32(_fp32_to_bf16_u16(x)) + # bf16 has 8 bits of significand -> relative error <= 2^-8. + rel = np.abs(out - x) / (np.abs(x) + 1e-30) + self.assertLessEqual(float(rel.max()), 2.0 ** -8) + + def test_round_to_nearest_even_ties(self): + # Tie halfway between bf16 1.0 (0x3F80, even) and 0x3F81 (odd): + # fp32 0x3F808000 must round DOWN to the even neighbor 0x3F80. + tie_down = np.array([0x3F808000], dtype=np.uint32).view(np.float32) + self.assertEqual(int(_fp32_to_bf16_u16(tie_down)[0]), 0x3F80) + # Tie halfway between 0x3F81 (odd) and 0x3F82 (even): + # fp32 0x3F818000 must round UP to the even neighbor 0x3F82. + tie_up = np.array([0x3F818000], dtype=np.uint32).view(np.float32) + self.assertEqual(int(_fp32_to_bf16_u16(tie_up)[0]), 0x3F82) + + def test_special_values(self): + inf = np.array([np.inf, -np.inf], dtype=np.float32) + out = _bf16_u16_to_fp32(_fp32_to_bf16_u16(inf)) + self.assertTrue(np.isinf(out[0]) and out[0] > 0) + self.assertTrue(np.isinf(out[1]) and out[1] < 0) + + nan = np.array([np.nan], dtype=np.float32) + out_nan = _bf16_u16_to_fp32(_fp32_to_bf16_u16(nan)) + self.assertTrue(np.isnan(out_nan[0])) + + def test_dtype_and_size(self): + u16 = _fp32_to_bf16_u16(np.zeros(4, dtype=np.float32)) + self.assertEqual(u16.dtype, np.uint16) + self.assertEqual(u16.itemsize, 2) # must match sizeof(bf16_t) on device + + +class TestFp8Bf8Encoding(unittest.TestCase): + """fp8 E4M3 / bf8 E5M2 in the FNUZ format used by gfx942. + + The decode is the load-bearing half (it must equal the device value for a + byte); the encode must land on the nearest representable byte and saturate. + """ + + def test_format_ranges(self): + # FNUZ maxima: E4M3 -> 2^7 * 1.875 = 240; E5M2 -> 2^15 * 1.75 = 57344. + t43 = _fnuz_decode_table(4, 3) + t52 = _fnuz_decode_table(5, 2) + self.assertEqual(float(np.nanmax(t43)), 240.0) + self.assertEqual(float(np.nanmin(t43)), -240.0) + self.assertEqual(float(np.nanmax(t52)), 57344.0) + self.assertEqual(float(np.nanmin(t52)), -57344.0) + + def test_zero_and_nan_slots(self): + # 0x00 is +0; the negative-zero slot 0x80 is the lone NaN (FNUZ). + for tab in (_fnuz_decode_table(4, 3), _fnuz_decode_table(5, 2)): + self.assertEqual(float(tab[0x00]), 0.0) + self.assertTrue(np.isnan(tab[0x80])) + + def test_exactly_representable_roundtrip(self): + exact = np.array([0.0, 0.5, 1.0, -1.0, 2.0, -2.0, 1.5, -0.25, 4.0, 8.0], + dtype=np.float32) + np.testing.assert_array_equal( + _fp8_u8_to_fp32(_fp32_to_fp8_u8(exact)), exact) + np.testing.assert_array_equal( + _bf8_u8_to_fp32(_fp32_to_bf8_u8(exact)), exact) + + def test_decode_is_consistent_with_encode(self): + # The parity contract: ref multiplies decode(encode(x)), so the pair must + # be self-consistent and every encoded byte must decode finite. + rng = np.random.default_rng(1) + x = (rng.standard_normal(5000) * 0.1).astype(np.float32) + for enc, dec in ((_fp32_to_fp8_u8, _fp8_u8_to_fp32), + (_fp32_to_bf8_u8, _bf8_u8_to_fp32)): + d = dec(enc(x)) + self.assertTrue(np.all(np.isfinite(d))) + + def test_saturates_no_inf(self): + # FNUZ has no infinity: huge magnitudes clamp to the finite max. + big = np.array([1e30, -1e30], dtype=np.float32) + self.assertEqual(float(_fp8_u8_to_fp32(_fp32_to_fp8_u8(big))[0]), 240.0) + self.assertEqual(float(_bf8_u8_to_fp32(_fp32_to_bf8_u8(big))[1]), -57344.0) + + def test_dtype_and_size(self): + for enc in (_fp32_to_fp8_u8, _fp32_to_bf8_u8): + u8 = enc(np.zeros(4, dtype=np.float32)) + self.assertEqual(u8.dtype, np.uint8) + self.assertEqual(u8.itemsize, 1) # must match sizeof(fp8_t/bf8_t) + + +class TestOutputDtype(unittest.TestCase): + """Output (C) element dtype must mirror the codegen's get_output_dtype.""" + + def test_mapping(self): + self.assertEqual(_output_dtype("fp16"), "fp16") + self.assertEqual(_output_dtype("bf16"), "bf16") + self.assertEqual(_output_dtype("fp8"), "fp16") + self.assertEqual(_output_dtype("bf8"), "fp16") + self.assertEqual(_output_dtype("int8"), "int32") + + +class TestKernelNameParsing(unittest.TestCase): + """The runner reads dtype + layout straight from the compiled .so name.""" + + _NAME = ("gemm_bf16_rcr_compv3_cshuffle_intrawave_" + "False_False_False_False_64x64x64_4x1x1_16x16x16") + + def test_dtype_from_name(self): + self.assertEqual(_dtype_from_kernel_name(self._NAME), "bf16") + self.assertEqual( + _dtype_from_kernel_name("gemm_fp16_rrr_compv4_cshuffle_intrawave"), + "fp16", + ) + + def test_dtype_fallback(self): + # Malformed / single-token name falls back to fp16. + self.assertEqual(_dtype_from_kernel_name("gemm"), "fp16") + + def test_layout_from_name(self): + self.assertEqual(_layout_from_kernel_name(self._NAME), "rcr") + for lay in ("rrr", "ccr", "crr", "rcc"): + name = f"gemm_fp16_{lay}_compv3_cshuffle_intrawave" + self.assertEqual(_layout_from_kernel_name(name), lay) + + def test_layout_fallback(self): + # A token that is not a 3-char r/c string falls back to rcr. + self.assertEqual( + _layout_from_kernel_name("gemm_fp16_xyz_compv3"), "rcr" + ) + self.assertEqual(_layout_from_kernel_name("gemm"), "rcr") + + +class TestConfigNameContract(unittest.TestCase): + """GemmKernelConfig.name is the single source of truth tying config -> + codegen -> runtime; parsing it back must recover dtype and layout.""" + + def test_name_roundtrips_through_parsers(self): + for dtype in ("fp16", "bf16", "fp8", "bf8", "int8"): + for la, lb, lc in (("row", "col", "row"), + ("row", "row", "row"), + ("col", "col", "row"), + ("col", "row", "row")): + cfg = GemmKernelConfig( + dtype_a=dtype, dtype_b=dtype, dtype_c=_output_dtype(dtype), + dtype_acc=("int32" if dtype == "int8" else "fp32"), + layout_a=la, layout_b=lb, layout_c=lc, + ) + name = cfg.name + self.assertEqual(_dtype_from_kernel_name(name), dtype) + self.assertEqual(_layout_from_kernel_name(name), cfg.layout) + + +if __name__ == "__main__": + unittest.main() diff --git a/projects/composablekernel/tile_engine/ops/gemm/README.md b/projects/composablekernel/tile_engine/ops/gemm/README.md index 5e0bae70806d..3c497da7cef0 100644 --- a/projects/composablekernel/tile_engine/ops/gemm/README.md +++ b/projects/composablekernel/tile_engine/ops/gemm/README.md @@ -6,6 +6,7 @@ The CK Tile Engine GEMM module provides a comprehensive system for generating, b ## Table of Contents +0. [Dispatcher Bridge Workflow](#dispatcher-bridge-workflow) 1. [Build System Architecture](#build-system-architecture) 2. [Build Instructions](#build-instructions) 3. [Running Benchmarks](#running-benchmarks) @@ -16,6 +17,145 @@ The CK Tile Engine GEMM module provides a comprehensive system for generating, b 8. [Troubleshooting](#troubleshooting) 9. [Performance Tips](#performance-tips) +## Dispatcher Bridge Workflow + +The **Dispatcher bridge** is the recommended path for sweeping and benchmarking +GEMM kernels. Instead of building monolithic or per-kernel executables through +CMake, Tile Engine expands a sweep config into shared `GemmKernelConfig` objects +and hands them to the Dispatcher, which codegens and compiles each into its own +`.so`. The kernel name produced by the bridge is byte-for-byte identical to the +codegen `KERNEL_NAME`, so the bridge runs exactly the same kernels the native +Tile Engine does — it only swaps the harness. + +### Scripts + +| Script | Role | +|---|---| +| `gemm_full_benchmark.py` | Driver: compile (Phase 1) → load problems (Phase 2) → benchmark across all visible GPUs (Phase 3). | +| `run_one_gemm_kernel.py` | Disposable worker: loads one `.so` in an isolated subprocess and times it. A GPU fault kills only the worker. | + +### Folder layout + +The bridged regular-GEMM path follows the same op-root convention as the merged +`fmha/` and `grouped_conv/` bridges — driver + worker + a flat `configs/` at the +op root: + +``` +gemm/ +├── gemm_full_benchmark.py # bridge driver (op root) +├── run_one_gemm_kernel.py # disposable per-kernel worker (op root) +├── configs/ # bridged gemm_universal sweep configs (flat) +├── gemm_instance_builder.py # shared generator for the non-bridged variants +├── gemm_benchmark.{py,hpp}, gemm_common.hpp, gemm_profiler.hpp # shared harness +├── gemm_multi_d/ gemm_preshuffle/ grouped_gemm/ # legacy variants +└── README.md +``` + +`configs/` ships example sweep configs: + +- `default_ci_config.json` — small CI-sized sweep (the driver's default when no + config is passed). +- `default_config.json` — full sweep. +- `user_provided_config.json` — scratch space for custom sweeps. +- `example_problems.json` — example M/N/K problem set (used when `--problems` + is omitted). + +> The JSON used by **nightly** tests is intended to drop into the same +> `configs/` directory and be selected with a positional config — no driver +> changes needed. + +The not-yet-bridged variants (`gemm_multi_d/`, `gemm_preshuffle/`, +`grouped_gemm/`) keep their own per-variant `configs/` directories; the driver +selects them with `--variant`. + +### Running + +```bash +cd tile_engine/ops/gemm + +# Default: gemm_universal variant, its CI sweep + example problems, +# auto-detect and use all visible GPUs. +python gemm_full_benchmark.py + +# Full sweep, fp16/rcr, restricted to 4 GPUs, custom output: +python gemm_full_benchmark.py --variant gemm_universal \ + configs/default_config.json \ + --dtype fp16 --layout rcr --devices 4 --csv gemm_results.csv + +# Specific GPU ids and a custom problem file: +python gemm_full_benchmark.py --devices 0,2,5 \ + --problems configs/example_problems.json + +# Correctness mode: check every kernel against an fp32 numpy reference. +python gemm_full_benchmark.py --verify --max-kernels 8 +``` + +### Liveness vs correctness (`--verify`) + +By default a measurement is reported `OK` purely on **liveness** — the kernel +ran and produced a non-zero output (`ZERO` otherwise). It is *not* a correctness +check: a numerically wrong but non-zero result still reads `OK`. Pass `--verify` +to have each worker compare its output against an fp32 numpy reference +(`A @ B`) using the global relative metric `max|out - ref| / max|ref|`. With +`--verify`, results read `VERIFY` (within `--verify-tol`, default `2e-2`) or +`MISMATCH` (counted as a failure), and the `max_rel` / `verified` columns are +populated in the CSV. This gives self-contained per-kernel confidence; the +broader numeric parity against native Tile Engine remains a separate task. + +### Multi-GPU parallelism + +Phase 3 fans the `(kernel × problem)` work out across **every visible GPU** in +parallel. One worker thread per device pulls batches from a shared queue and +spawns a disposable subprocess pinned with `HIP_VISIBLE_DEVICES`, so an N-GPU box +benchmarks roughly N× faster while keeping per-batch fault isolation. Devices are +auto-detected (`HIP_VISIBLE_DEVICES`, then `rocm-smi`/`amd-smi`); override with +`--devices`. This supersedes the serial-GPU design inherited from grouped_conv. + +### Supported surface + +| Axis | Supported | +|---|---| +| dtype | `fp16`, `bf16` | +| layout | `rcr`, `rrr`, `crr`, `ccr` (row-major C only — ck_tile rejects column-major C at build) | + +### Variant scope + +The bridge is **one shared, variant-aware driver** (`gemm_full_benchmark.py` + +`run_one_gemm_kernel.py`), not a per-variant copy of the driver. The bridged +regular-GEMM path (`gemm_universal`) uses the op-root `configs/`; `--variant` +selects a not-yet-bridged variant's own `configs/` subdirectory. + +What that means for this PR: + +- **Only `gemm_universal` is wired and validated through the bridge here.** It is + the foundation variant; the dispatcher codegen path is exercised and parity- + checked for it alone. +- The `gemm_multi_d/`, `gemm_preshuffle/`, and `grouped_gemm/` `configs/` + directories are **scaffolding** that follows the per-variant convention so the + layout is ready. `--variant` will select them, but the bridge does **not** yet + produce correct kernels for those variants on this PR — do not treat their + presence as working support. +- Grouped GEMM and stream-K go through **separate bridge efforts** (stream-K in + #8136, grouped GEMM on its own branch), not this PR. + +### Removal note + +The legacy regular-GEMM standalone build path has been **removed**, and the +`gemm_universal/` folder is gone entirely. The per-config benchmark generator and +driver (`gemm_universal_instance_builder.py`, `gemm_universal_benchmark.py`, +`gemm_universal_benchmark*.{cpp,hpp}`, and `gemm_universal/CMakeLists.txt`) no +longer exist; its sweep configs were promoted to the op-root `configs/` directory +(matching the `fmha/` and `grouped_conv/` bridge convention) and are consumed by +the bridge. Regular GEMM now runs exclusively through the Dispatcher bridge +workflow above (`gemm_full_benchmark.py` / `run_one_gemm_kernel.py`). The other +variants (`gemm_multi_d/`, `gemm_preshuffle/`, `grouped_gemm/`) still use the +shared `gemm_instance_builder.py` generator. + +The build-system, build-instruction, and benchmark-execution sections below +describe that removed standalone path and are retained only as historical +reference for the non-bridged variants; the `benchmark_gemm_universal_*` targets +they mention are no longer produced. + ## Build System Architecture ### Individual Kernel Compilation (New Approach) @@ -171,8 +311,13 @@ The system uses JSON configuration files to specify kernel parameters: ### Python Scripts -#### gemm_universal_instance_builder.py -**Purpose**: Main kernel instance generation script that creates C++ kernel implementations based on configuration files. +#### gemm_instance_builder.py +**Purpose**: Shared kernel instance generator used by the non-bridged variants +(`gemm_multi_d`, `gemm_preshuffle`, `grouped_gemm`). Creates C++ kernel +implementations based on configuration files. + +> The regular-GEMM subclass `gemm_universal/gemm_universal_instance_builder.py` +> has been removed; regular GEMM now goes through the Dispatcher bridge. **Key Features**: - Generates individual kernel header files for separate compilation @@ -180,16 +325,6 @@ The system uses JSON configuration files to specify kernel parameters: - Validates tile configurations for correctness - Creates CMake integration files -**Usage**: -```bash -python gemm_universal_instance_builder.py \ - --working_path ./generated \ - --datatype fp16 \ - --layout rcr \ - --config_json configs/user_provided_config.json \ - --gen_all_individual -``` - #### gemm_instance_builder_parallel.py **Purpose**: Parallel version of the instance builder for faster generation of multiple kernel configurations. @@ -225,14 +360,6 @@ python test_validation.py - Trait combination validation - Full tile configuration validation -#### gemm_universal_benchmark.py -**Purpose**: Python script for running and analyzing GEMM benchmarks. - -**Features**: -- Automated benchmark execution -- Performance data collection -- Result analysis and reporting - #### json_config.py **Purpose**: Configuration file parsing and management. diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/configs/default_ci_config.json b/projects/composablekernel/tile_engine/ops/gemm/configs/default_ci_config.json similarity index 98% rename from projects/composablekernel/tile_engine/ops/gemm/gemm_universal/configs/default_ci_config.json rename to projects/composablekernel/tile_engine/ops/gemm/configs/default_ci_config.json index 38376a410b01..a2b83334245e 100644 --- a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/configs/default_ci_config.json +++ b/projects/composablekernel/tile_engine/ops/gemm/configs/default_ci_config.json @@ -32,17 +32,17 @@ }, "warp_tile_m": { "values": [ - 16 + 32 ] }, "warp_tile_n": { "values": [ - 16 + 32 ] }, "warp_tile_k": { "values": [ - 32 + 16 ] } }, diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/configs/default_config.json b/projects/composablekernel/tile_engine/ops/gemm/configs/default_config.json similarity index 100% rename from projects/composablekernel/tile_engine/ops/gemm/gemm_universal/configs/default_config.json rename to projects/composablekernel/tile_engine/ops/gemm/configs/default_config.json diff --git a/projects/composablekernel/tile_engine/ops/gemm/configs/example_problems.json b/projects/composablekernel/tile_engine/ops/gemm/configs/example_problems.json new file mode 100644 index 000000000000..4be0c5a82379 --- /dev/null +++ b/projects/composablekernel/tile_engine/ops/gemm/configs/example_problems.json @@ -0,0 +1,9 @@ +{ + "problems": [ + {"M": 512, "N": 512, "K": 512}, + {"M": 1024, "N": 1024, "K": 1024}, + {"M": 2048, "N": 2048, "K": 2048}, + {"M": 1024, "N": 512, "K": 256}, + {"M": 4096, "N": 4096, "K": 4096} + ] +} diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_universal/configs/user_provided_config.json b/projects/composablekernel/tile_engine/ops/gemm/configs/user_provided_config.json similarity index 100% rename from projects/composablekernel/tile_engine/ops/gemm/gemm_universal/configs/user_provided_config.json rename to projects/composablekernel/tile_engine/ops/gemm/configs/user_provided_config.json diff --git a/projects/composablekernel/tile_engine/ops/gemm/gemm_full_benchmark.py b/projects/composablekernel/tile_engine/ops/gemm/gemm_full_benchmark.py new file mode 100644 index 000000000000..f6bdecfae56c --- /dev/null +++ b/projects/composablekernel/tile_engine/ops/gemm/gemm_full_benchmark.py @@ -0,0 +1,510 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT +"""Full GEMM benchmark sweep driven through the Dispatcher bridge. + +Phases: + Phase 1: Compile all kernels (parallel, returns .so paths only -- no GPU) + Phase 2: Load problems (M, N, K shapes) + Phase 3: Benchmark via subprocess isolation, distributed across all visible + GPUs (one device-pinned worker per GPU, batched, fault-isolated) + +Tile Engine generates NO binaries here: it expands its sweep config into shared +``GemmKernelConfig`` objects and hands them to the dispatcher, which codegens + +compiles each into a .so. Each kernel runs in a disposable worker subprocess so +a GPU fault (or ctypes' inability to unload a .so) takes down only one worker. + +Unlike the serial-GPU design inherited from grouped_conv, Phase 3 here fans the +work out across every visible GPU in parallel: each device runs its own stream of +disposable worker subprocesses pinned with ``HIP_VISIBLE_DEVICES``, so an N-GPU +box benchmarks roughly N times faster while keeping per-batch fault isolation. + +Examples: + # Default: gemm_universal variant, its CI sweep config + example problems, + # auto-detect and use all visible GPUs. + python gemm_full_benchmark.py + + # Explicit variant + full sweep config on 4 GPUs: + python gemm_full_benchmark.py --variant gemm_universal \ + configs/default_config.json --devices 4 --csv out.csv + +When no config is given the driver uses the chosen variant's +``configs/default_ci_config.json`` (a small CI-sized sweep); +``configs/default_config.json`` is the full sweep, and the JSON used by nightly +tests is intended to drop into the same ``configs/`` directory. +""" + +import argparse +import csv +import json +import os +import queue +import re +import subprocess +import sys +import threading +import time +from pathlib import Path + +_THIS_DIR = Path(__file__).resolve().parent +_DISPATCHER_ROOT = _THIS_DIR.parents[2] / "dispatcher" +sys.path.insert(0, str(_DISPATCHER_ROOT / "python")) +sys.path.insert(0, str(_THIS_DIR)) + +from gemm_utils import setup_multiple_gemm_dispatchers, expand_sweep # noqa: E402 + +# Config layout. The bridged regular-GEMM path (gemm_universal) keeps its sweep +# configs in this op's flat ``configs/`` directory (matching the fmha/grouped_conv +# bridge convention): default_ci_config.json (small CI sweep), default_config.json +# (full sweep), user_provided_config.json, example_problems.json. The other, +# not-yet-bridged variants still live in their own per-variant ``configs/`` dirs; +# they are registered so ``--variant`` can select them once their bridge lands. +VARIANT_CONFIGS = { + "gemm_universal": "configs", + "gemm_multi_d": "gemm_multi_d/configs", + "gemm_preshuffle": "gemm_preshuffle/configs", + "grouped_gemm": "grouped_gemm/configs", +} +DEFAULT_VARIANT = "gemm_universal" +CI_CONFIG_NAME = "default_ci_config.json" +EXAMPLE_PROBLEMS_NAME = "example_problems.json" + +# Fallback problem set if a variant ships no example_problems.json. +DEFAULT_PROBLEMS = [ + {"M": 1024, "N": 1024, "K": 1024}, + {"M": 2048, "N": 2048, "K": 2048}, + {"M": 4096, "N": 4096, "K": 4096}, + {"M": 257, "N": 257, "K": 257}, +] + +SUPPORTED_DTYPES = ("fp16", "bf16") +# Row-major C only: ck_tile's universal GEMM rejects column-major C at build. +SUPPORTED_LAYOUTS = ("rcr", "rrr", "crr", "ccr") + + +def detect_devices(): + """Return a list of visible GPU id strings (best-effort).""" + env = os.environ.get("HIP_VISIBLE_DEVICES") or os.environ.get( + "CUDA_VISIBLE_DEVICES" + ) + if env: + ids = [d.strip() for d in env.split(",") if d.strip() != ""] + if ids: + return ids + try: + out = subprocess.check_output( + ["rocm-smi", "--showid"], stderr=subprocess.DEVNULL, text=True + ) + ids = sorted(set(re.findall(r"GPU\[(\d+)\]", out)), key=int) + if ids: + return ids + except Exception: + pass + try: + out = subprocess.check_output( + ["amd-smi", "list"], stderr=subprocess.DEVNULL, text=True + ) + ids = re.findall(r"^GPU:\s*(\d+)", out, re.MULTILINE) + if ids: + return ids + except Exception: + pass + return ["0"] + + +def resolve_devices(spec): + """Resolve --devices into a concrete list of device id strings. + + spec is None (auto: all visible), an int count, or a comma-list of ids. + A bare digit is a *count*, not an id; to target one specific id use the + comma form, e.g. "5,". + """ + detected = detect_devices() + if spec is None: + return detected + spec = str(spec).strip() + if "," in spec: + return [s.strip() for s in spec.split(",") if s.strip() != ""] + if spec.isdigit(): + n = int(spec) + if n <= 0: + return detected + # Treat a bare integer as a device *count*: take the first n detected ids. + # If the environment explicitly restricts visibility (HIP/CUDA_VISIBLE_DEVICES), + # do not invent additional ids beyond what's visible. + if len(detected) >= n: + return detected[:n] + if os.environ.get("HIP_VISIBLE_DEVICES") or os.environ.get("CUDA_VISIBLE_DEVICES"): + return detected + return [str(i) for i in range(n)] + return [spec] + + +def resolve_configs(args): + """Resolve positional configs -> concrete list of config paths.""" + if args.configs: + return args.configs + cfg = _THIS_DIR / VARIANT_CONFIGS[args.variant] / CI_CONFIG_NAME + return [str(cfg)] + + +def load_problems(path, variant): + if path: + with open(path) as f: + data = json.load(f) + return data["problems"] if isinstance(data, dict) else data + example = _THIS_DIR / VARIANT_CONFIGS[variant] / EXAMPLE_PROBLEMS_NAME + if example.exists(): + with open(example) as f: + data = json.load(f) + return data["problems"] if isinstance(data, dict) else data + return DEFAULT_PROBLEMS + + +def _run_batch_on_device(device_id, unit, args, worker_path, base_env): + """Run one (problem, kernel-batch) unit in a device-pinned subprocess. + + Returns (rows, lines, n_fail) where rows are dicts ready for the CSV writer, + lines are formatted strings to print, and n_fail counts failures. + """ + prob_idx, prob_dict, batch = unit + M, N, K = prob_dict["M"], prob_dict["N"], prob_dict["K"] + + items = [ + {"so_path": str(lib), "problem": prob_dict, "kernel_name": cfg.name} + for _, cfg, lib in batch + ] + payload = json.dumps( + {"items": items, "verify": args.verify, "verify_tol": args.verify_tol} + ) + + env = base_env.copy() + env["HIP_VISIBLE_DEVICES"] = str(device_id) + + rows, lines, n_fail = [], [], 0 + proc = None + try: + proc = subprocess.Popen( + [sys.executable, str(worker_path)], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + env=env, + ) + stdout_bytes, _ = proc.communicate( + input=payload.encode("utf-8"), + timeout=args.kernel_timeout * len(batch), + ) + + reported = set() + for line in stdout_bytes.decode("utf-8").strip().split("\n"): + if not line: + continue + try: + result = json.loads(line) + except json.JSONDecodeError: + lines.append(f" [gpu{device_id}] Warning: bad result line: {line[:50]}") + n_fail += 1 + continue + bidx = result.get("idx", 0) + _, cfg, _ = batch[bidx] + reported.add(bidx) + if result.get("ok", False): + status = "OK" if result.get("non_zero", 0) > 0 else "ZERO" + mismatch = False + if args.verify and "verified" in result: + if result["verified"]: + status = "VERIFY" + else: + status = "MISMATCH" + mismatch = True + extra = ( + f" rel={result['max_rel']:.2e}" if "max_rel" in result else "" + ) + lines.append( + f" [gpu{device_id}] {cfg.name:<58} {result['ms']:>10.3f} " + f"{result['tflops']:>10.2f} {status:>8}{extra}" + ) + rows.append( + { + "kernel": cfg.name, + "problem_idx": prob_idx, + "M": M, + "N": N, + "K": K, + "device": device_id, + "latency_ms": result["ms"], + "tflops": result["tflops"], + "non_zero": result.get("non_zero", 0), + "max_rel": result.get("max_rel", ""), + "verified": result.get("verified", ""), + } + ) + if mismatch: + n_fail += 1 + else: + lines.append(f" [gpu{device_id}] {cfg.name:<58} FAILED") + lines.append(f" Error: {result.get('error', 'unknown')[:100]}") + n_fail += 1 + + missing = set(range(len(batch))) - reported + if missing or proc.returncode != 0: + if proc.returncode != 0: + lines.append(f" [gpu{device_id}] worker exited code {proc.returncode}") + for idx in sorted(missing): + _, cfg, _ = batch[idx] + lines.append(f" [gpu{device_id}] {cfg.name:<58} MISSING (crash)") + n_fail += len(missing) + + except subprocess.TimeoutExpired: + lines.append(f" [gpu{device_id}] batch timeout ({len(batch)} kernels)") + try: + proc.kill() + proc.communicate(timeout=5) + except Exception: + pass + n_fail += len(batch) + except Exception as e: + lines.append(f" [gpu{device_id}] batch error: {e}") + try: + if proc and proc.poll() is None: + proc.kill() + except Exception: + pass + n_fail += len(batch) + + return rows, lines, n_fail + + +def main(): + parser = argparse.ArgumentParser(description="GEMM Benchmark Sweep (via Dispatcher)") + parser.add_argument( + "configs", + nargs="*", + help="TE sweep config JSON files (default: variant's default_ci_config.json)", + ) + parser.add_argument( + "--variant", + default=DEFAULT_VARIANT, + choices=tuple(VARIANT_CONFIGS), + help="GEMM variant (selects the configs/ directory)", + ) + parser.add_argument("--arch", default="gfx942") + parser.add_argument( + "--dtype", + default="fp16", + choices=SUPPORTED_DTYPES, + help=f"Input dtype (supported: {', '.join(SUPPORTED_DTYPES)})", + ) + parser.add_argument( + "--layout", + default="rcr", + choices=SUPPORTED_LAYOUTS, + help=f"A/B/C layout (supported: {', '.join(SUPPORTED_LAYOUTS)})", + ) + parser.add_argument("--problems", default=None, help="JSON file of M,N,K problems") + parser.add_argument("--csv", type=str, default="gemm_results.csv") + parser.add_argument("--workers", type=int, default=8, help="Parallel build workers") + parser.add_argument( + "--devices", + default=None, + help="GPUs to use: int count (e.g. 4) or comma-list of ids (e.g. 0,2,5); " + "for one specific id use the comma form (e.g. 5,) since a bare digit is " + "a count; default auto-detects all visible", + ) + parser.add_argument( + "--batch-size", + type=int, + default=20, + help="Kernels per subprocess (overhead vs fault isolation)", + ) + parser.add_argument( + "--kernel-timeout", type=int, default=30, help="Per-kernel timeout (s)" + ) + parser.add_argument( + "--max-kernels", type=int, default=0, help="Limit to first N kernels (0=all)" + ) + parser.add_argument( + "--verify", + action="store_true", + help="Check each kernel's output against an fp32 numpy reference " + "(global max|out-ref|/max|ref|); a mismatch counts as a failure", + ) + parser.add_argument( + "--verify-tol", + type=float, + default=2e-2, + help="Relative tolerance for --verify (default 2e-2, suits fp16)", + ) + args = parser.parse_args() + + config_paths = resolve_configs(args) + devices = resolve_devices(args.devices) + + # ======================================================================== + # Phase 1: Compile kernels (parallel, no GPU) + # ======================================================================== + print(f"\n{'=' * 80}") + print("Phase 1: Compile kernels") + print(f"{'=' * 80}") + print(f" Variant: {args.variant}") + print(f" Configs: {', '.join(config_paths)}") + + all_configs = [] + for cfg_path in config_paths: + all_configs.extend( + expand_sweep(cfg_path, args.arch, dtype=args.dtype, layout=args.layout) + ) + + if args.max_kernels > 0: + all_configs = all_configs[: args.max_kernels] + + print(f" Expanded configs: {len(all_configs)}") + print(f" Build workers: {args.workers}") + + t0 = time.perf_counter() + # CRITICAL: returns Path objects only, does NOT load any .so. + lib_paths = setup_multiple_gemm_dispatchers( + all_configs, verbose=True, max_workers=args.workers + ) + build_time = time.perf_counter() - t0 + + built_kernels = [ + (cfg, lib) for cfg, lib in zip(all_configs, lib_paths) if lib is not None + ] + + # Dedupe by .so path (distinct configs can map to the same physical kernel). + seen_libs = set() + unique_kernels = [] + duplicate_count = 0 + for cfg, lib in built_kernels: + lib_key = str(lib.resolve()) + if lib_key not in seen_libs: + seen_libs.add(lib_key) + unique_kernels.append((cfg, lib)) + else: + duplicate_count += 1 + built_kernels = unique_kernels + + print( + f"\n Built {len(all_configs)} configs -> {len(built_kernels)} unique kernels " + f"({duplicate_count} duplicates filtered) in {build_time:.0f}s" + ) + + if not built_kernels: + print(" ERROR: No kernels built successfully") + return 1 + + # ======================================================================== + # Phase 2: Load problems + # ======================================================================== + print(f"\n{'=' * 80}") + print("Phase 2: Load test problems") + print(f"{'=' * 80}") + + problems = load_problems(args.problems, args.variant) + print(f" Problems: {len(problems)}") + print( + f" Total measurements: {len(built_kernels)} x {len(problems)} = " + f"{len(built_kernels) * len(problems)}" + ) + + # ======================================================================== + # Phase 3: Benchmark across all visible GPUs (subprocess isolation, batched) + # ======================================================================== + print(f"\n{'=' * 80}") + print("Phase 3: Benchmark (multi-GPU, subprocess isolation, batched)") + print(f"{'=' * 80}") + print(f" Devices: {len(devices)} -> {', '.join(devices)}") + print(f" Batch size: {args.batch_size} kernels per subprocess") + print(f" Timeout: {args.kernel_timeout}s per kernel\n") + + csv_path = Path(args.csv) + csv_fields = [ + "kernel", + "problem_idx", + "M", + "N", + "K", + "device", + "latency_ms", + "tflops", + "non_zero", + "max_rel", + "verified", + ] + csv_file = open(csv_path, "w", newline="") + writer = csv.DictWriter(csv_file, fieldnames=csv_fields) + writer.writeheader() + + worker_path = _THIS_DIR / "run_one_gemm_kernel.py" + base_env = os.environ.copy() + base_env["GEMM_PYPATH"] = os.pathsep.join( + [str(_DISPATCHER_ROOT / "python"), str(_THIS_DIR)] + ) + + # Build a single work queue of (prob_idx, prob_dict, kernel-batch) units and + # fan them out across device-pinned worker threads. + work_q = queue.Queue() + for prob_idx, prob in enumerate(problems): + prob_dict = {"M": int(prob["M"]), "N": int(prob["N"]), "K": int(prob["K"])} + for start in range(0, len(built_kernels), args.batch_size): + end = min(start + args.batch_size, len(built_kernels)) + batch = [ + (start + j, cfg, lib) + for j, (cfg, lib) in enumerate(built_kernels[start:end]) + ] + work_q.put((prob_idx, prob_dict, batch)) + + io_lock = threading.Lock() + stats = {"measurements": 0, "failures": 0} + bench_t0 = time.perf_counter() + + def device_thread(device_id): + while True: + try: + unit = work_q.get_nowait() + except queue.Empty: + return + rows, lines, n_fail = _run_batch_on_device( + device_id, unit, args, worker_path, base_env + ) + with io_lock: + for ln in lines: + print(ln) + for row in rows: + writer.writerow(row) + csv_file.flush() + stats["measurements"] += len(rows) + stats["failures"] += n_fail + work_q.task_done() + + threads = [ + threading.Thread(target=device_thread, args=(d,), daemon=True) for d in devices + ] + for t in threads: + t.start() + for t in threads: + t.join() + + bench_time = time.perf_counter() - bench_t0 + csv_file.close() + + # ======================================================================== + # Summary + # ======================================================================== + print(f"\n{'=' * 80}") + print("BENCHMARK COMPLETE") + print(f"{'=' * 80}") + print(f" Build time: {build_time:.0f}s") + print(f" Benchmark time: {bench_time:.0f}s") + print(f" Total time: {build_time + bench_time:.0f}s") + print(f" Devices used: {len(devices)}") + print(f" Successful measurements: {stats['measurements']}") + print(f" Failed measurements: {stats['failures']}") + print(f" Output: {csv_path}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/projects/composablekernel/tile_engine/ops/gemm/run_one_gemm_kernel.py b/projects/composablekernel/tile_engine/ops/gemm/run_one_gemm_kernel.py new file mode 100644 index 000000000000..54e138198a96 --- /dev/null +++ b/projects/composablekernel/tile_engine/ops/gemm/run_one_gemm_kernel.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT +"""Worker script for running GEMM kernels in an isolated subprocess. + +Mirrors grouped_conv's run_one_grouped_conv_kernel.py: +- Receives kernel config + problem via stdin as JSON +- Loads the .so library ONLY inside this subprocess +- Outputs timing results as JSON to stdout (one line per kernel, flushed) +- A GPU fault kills only this process; the parent driver can continue + +Input JSON format: + Single: {"so_path": "...", "problem": {"M":.., "N":.., "K":..}, "kernel_name": "..."} + Batch: {"items": [{"so_path": "...", "problem": {...}, "kernel_name": "..."}, ...]} + +Optional top-level keys ``verify`` (bool) and ``verify_tol`` (float) enable an +fp32 numpy reference check; when set, each OK result also carries ``verified`` +and ``max_rel``. + +Output JSON format (one line per kernel): + {"idx": 0, "ok": true, "ms": 0.123, "tflops": 456.7, "non_zero": 1, "kernel": "..."} + {"idx": 0, "ok": true, ..., "verified": true, "max_rel": 3.1e-4} # with --verify + {"idx": 1, "ok": false, "error": "...", "kernel": "..."} +""" + +import json +import os +import sys + +# Add dispatcher python paths from environment (os.pathsep-separated). +gemm_pypath = os.environ.get("GEMM_PYPATH", "") +if gemm_pypath: + for p in gemm_pypath.split(os.pathsep): + if p and p not in sys.path: + sys.path.insert(0, p) + +from gemm_utils import GemmProblem, GpuGemmRunner # noqa: E402 +import numpy as np # noqa: E402 + + +def _run_one(idx, so_path, prob_dict, kernel_name, verify=False, verify_tol=2e-2): + """Run a single kernel and emit its result as one JSON line. + + When ``verify`` is set, the kernel output is checked against an fp32 numpy + reference (``A @ B``) using the global relative metric + ``max|out - ref| / max|ref|``; the emitted ``verified`` field then reflects + correctness, not just liveness (``non_zero``). + """ + try: + problem = GemmProblem.from_dict(prob_dict) + + # Cache host matrices per shape so batch mode doesn't regenerate huge inputs per kernel. + cache = getattr(_run_one, "_ab_cache", {}) + key = (problem.M, problem.N, problem.K) + if key not in cache: + rng = np.random.RandomState(42) + cache[key] = ( + (rng.randn(problem.M, problem.K) * 0.1).astype(np.float32), + (rng.randn(problem.K, problem.N) * 0.1).astype(np.float32), + ) + _run_one._ab_cache = cache + A, B = cache[key] + + # CRITICAL: load the library ONLY inside this subprocess. + runner = GpuGemmRunner(lib_path=so_path) + result = runner.run(A, B, problem) + + if result.success: + non_zero = ( + int(np.count_nonzero(result.output)) + if result.output is not None + else 0 + ) + out = { + "idx": idx, + "ok": True, + "ms": result.time_ms, + "tflops": result.tflops, + "non_zero": non_zero, + "kernel": kernel_name, + } + if verify: + ref = A.astype(np.float32) @ B.astype(np.float32) + got = result.output.astype(np.float32) + denom = float(np.max(np.abs(ref))) or 1.0 + max_rel = float(np.max(np.abs(got - ref)) / denom) + out["max_rel"] = max_rel + out["verified"] = bool(max_rel <= verify_tol) + print(json.dumps(out), flush=True) + else: + print( + json.dumps( + { + "idx": idx, + "ok": False, + "error": f"kernel returned status {result.status}", + "kernel": kernel_name, + } + ), + flush=True, + ) + + except Exception as e: + print( + json.dumps( + {"idx": idx, "ok": False, "error": str(e), "kernel": kernel_name} + ), + flush=True, + ) + + +def main(): + """Read JSON from stdin, run kernel(s), output results.""" + try: + d = json.loads(sys.stdin.buffer.read()) + except Exception as e: + print( + json.dumps({"idx": 0, "ok": False, "error": f"JSON parse error: {e}"}), + flush=True, + ) + sys.exit(1) + + verify = bool(d.get("verify", False)) + verify_tol = float(d.get("verify_tol", 2e-2)) + + if "items" in d: + for i, item in enumerate(d["items"]): + _run_one( + i, + item["so_path"], + item["problem"], + item.get("kernel_name", "unknown"), + verify=verify, + verify_tol=verify_tol, + ) + else: + _run_one( + 0, + d["so_path"], + d["problem"], + d.get("kernel_name", "unknown"), + verify=verify, + verify_tol=verify_tol, + ) + + +if __name__ == "__main__": + main()