Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>

#include "ck_tile/dispatcher/dispatcher.hpp"
#include "ck_tile/dispatcher/registry.hpp"
Expand Down Expand Up @@ -65,15 +66,74 @@ int dispatcher_initialize()
return 0; // Already initialized
}

// Create kernel key from the force-included kernel header
// Create kernel key from the force-included kernel header.
//
// The GEMM_KEY_* macros are emitted by the codegen into the force-included
// header (see unified_gemm_codegen.py, CK_TILE_SINGLE_KERNEL_INCLUDE block).
// Building the key from them makes the registry entry truthful: it reflects
// THIS kernel's real dtypes/layouts/tile/traits instead of a hard-coded
// fp16/rcr/128x128x32 default. Enum fields use the string_to_* helpers from
// kernel_key.hpp, whose accepted strings match the codegen's emitted values
// byte-for-byte.
KernelKey key;
#ifdef GEMM_KEY_DTYPE_A
key.signature.dtype_a = string_to_dtype(GEMM_KEY_DTYPE_A);
key.signature.dtype_b = string_to_dtype(GEMM_KEY_DTYPE_B);
key.signature.dtype_c = string_to_dtype(GEMM_KEY_DTYPE_C);
key.signature.dtype_acc = string_to_dtype(GEMM_KEY_DTYPE_ACC);
key.signature.layout_a = string_to_layout(GEMM_KEY_LAYOUT_A);
key.signature.layout_b = string_to_layout(GEMM_KEY_LAYOUT_B);
key.signature.layout_c = string_to_layout(GEMM_KEY_LAYOUT_C);
key.signature.transpose_a = false;
key.signature.transpose_b = false;
key.signature.grouped = (GEMM_KEY_GROUPED != 0);
key.signature.split_k = GEMM_KEY_SPLIT_K;
key.signature.elementwise_op = "PassThrough";
key.signature.num_d_tensors = 0;
key.signature.structured_sparsity = false;

key.algorithm.tile_shape = {GEMM_KEY_TILE_M, GEMM_KEY_TILE_N, GEMM_KEY_TILE_K};
key.algorithm.wave_shape = {GEMM_KEY_WAVE_M, GEMM_KEY_WAVE_N, GEMM_KEY_WAVE_K};
key.algorithm.warp_tile_shape = {GEMM_KEY_WARP_TILE_M, GEMM_KEY_WARP_TILE_N, GEMM_KEY_WARP_TILE_K};
key.algorithm.pipeline = string_to_pipeline(GEMM_KEY_PIPELINE);
key.algorithm.scheduler = string_to_scheduler(GEMM_KEY_SCHEDULER);
key.algorithm.epilogue = string_to_epilogue(GEMM_KEY_EPILOGUE);
key.algorithm.block_size = GEMM_KEY_BLOCK_SIZE;
key.algorithm.double_buffer = (GEMM_KEY_DOUBLE_BUFFER != 0);
key.algorithm.persistent = (GEMM_KEY_PERSISTENT != 0);
key.algorithm.preshuffle = (GEMM_KEY_PRESHUFFLE != 0);
key.algorithm.transpose_c = (GEMM_KEY_TRANSPOSE_C != 0);
key.algorithm.num_wave_groups = GEMM_KEY_NUM_WAVE_GROUPS;
// pad_m/n/k participate in both the key's hash/equality and the kernel
// name, so they must be derived from the codegen macros too -- otherwise a
// kernel built with padding disabled would register under a key claiming
// pad=true and disagree with its own name.
key.algorithm.pad_m = (GEMM_KEY_PAD_M != 0);
key.algorithm.pad_n = (GEMM_KEY_PAD_N != 0);
key.algorithm.pad_k = (GEMM_KEY_PAD_K != 0);
key.gfx_arch = GFX_ARCH;
#else
// Fallback default for headers generated before GEMM_KEY_* macros existed
// (fp16 / rcr / compv4-cshuffle-intrawave, 128x128x32). The macro path
// above is the source of truth for any freshly generated kernel.
key.signature.dtype_a = DataType::FP16;
key.signature.dtype_b = DataType::FP16;
key.signature.dtype_c = DataType::FP16;
key.signature.dtype_acc = DataType::FP32;
key.signature.layout_a = LayoutTag::RowMajor;
key.signature.layout_b = LayoutTag::ColMajor;
key.signature.layout_c = LayoutTag::RowMajor;
// Derive A/B/C layouts from the force-included kernel's own layout types
// instead of hardcoding rcr. The dispatcher's supports() gate is layout-aware
// (it only constrains a dimension that an operand's inner axis maps to), so a
// wrong key layout makes it reject valid problems -- e.g. a crr kernel does not
// gate K, but with a hardcoded rcr key supports() would apply rcr's K-gate and
// reject TileK=192 problems that Old-TE runs. ALayout/BLayout/CLayout are the
// global aliases exported by the kernel header under CK_TILE_SINGLE_KERNEL_INCLUDE.
using RowMajorLayout = ck_tile::tensor_layout::gemm::RowMajor;
key.signature.layout_a =
std::is_same_v<ALayout, RowMajorLayout> ? LayoutTag::RowMajor : LayoutTag::ColMajor;
key.signature.layout_b =
std::is_same_v<BLayout, RowMajorLayout> ? LayoutTag::RowMajor : LayoutTag::ColMajor;
key.signature.layout_c =
std::is_same_v<CLayout, RowMajorLayout> ? LayoutTag::RowMajor : LayoutTag::ColMajor;
key.signature.transpose_a = false;
key.signature.transpose_b = false;
key.signature.grouped = false;
Expand All @@ -95,6 +155,7 @@ int dispatcher_initialize()
key.algorithm.transpose_c = false;
key.algorithm.num_wave_groups = 1;
key.gfx_arch = GFX_ARCH;
#endif // GEMM_KEY_DTYPE_A

// Register kernel using types from force-included header
auto kernel =
Expand Down Expand Up @@ -310,10 +371,40 @@ int dispatcher_run_gemm(
}

/**
 * Report the kernel name through the legacy single-kernel ABI.
 *
 * Kept for backward compatibility with one-kernel-per-.so callers.
 *
 * @return The compile-time KERNEL_NAME of the force-included kernel header.
 */
const char* dispatcher_get_kernel_name()
{
    return KERNEL_NAME;
}

/**
 * Copy the name of the index-th registered kernel into a caller buffer
 * (multi-kernel ABI).
 *
 * Mirrors the conv/fmha ctypes libs, so one .so can report a whole batch of
 * kernels and be selected by name at runtime.
 *
 * @param index        Position of the kernel in Registry::instance().get_all().
 * @param buffer       Destination for the NUL-terminated name; must not be null.
 * @param buffer_size  Capacity of buffer in bytes; must be positive. Names
 *                     longer than buffer_size - 1 characters are truncated.
 * @return 0 on success, -1 on a null buffer, non-positive buffer_size, or an
 *         out-of-range index.
 */
int dispatcher_get_kernel_name_at(int index, char* buffer, int buffer_size)
{
    // Validate the destination first so the registry is never queried for a
    // call that cannot receive a result.
    const bool has_destination = (buffer != nullptr) && (buffer_size > 0);
    if(!has_destination)
    {
        return -1;
    }

    auto registered        = Registry::instance().get_all();
    const int kernel_count = static_cast<int>(registered.size());
    const bool in_range    = (index >= 0) && (index < kernel_count);
    if(!in_range)
    {
        return -1;
    }

    const std::string kernel_name = registered[index]->get_name();

    // Reserve the final byte for the terminator: strncpy does not write one
    // when the source fills the requested length.
    const size_t last = static_cast<size_t>(buffer_size) - 1;
    std::strncpy(buffer, kernel_name.c_str(), last);
    buffer[last] = '\0';
    return 0;
}

/**
* Initialize dispatcher (alias)
*/
Expand Down
25 changes: 23 additions & 2 deletions projects/composablekernel/dispatcher/codegen/codegen_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class CommonTypeMappings:
"fp8": "fp8_t",
"bf8": "bf8_t",
"int8": "int8_t",
"int32": "int32_t",
}

DTYPE_TO_CK_QUALIFIED = {
Expand All @@ -127,6 +128,7 @@ class CommonTypeMappings:
"fp8": "ck_tile::fp8_t",
"bf8": "ck_tile::bf8_t",
"int8": "int8_t",
"int32": "int32_t",
}

DTYPE_TO_DISPATCHER = {
Expand All @@ -136,6 +138,7 @@ class CommonTypeMappings:
"fp8": "DataType::FP8",
"bf8": "DataType::BF8",
"int8": "DataType::INT8",
"int32": "DataType::INT32",
}

# GEMM-specific layout mappings ("r"/"c" for row/column major).
Expand Down Expand Up @@ -202,8 +205,26 @@ class CommonTypeMappings:

@staticmethod
def get_output_dtype(dtype: str) -> str:
"""Get output datatype (fp8/bf8 -> fp16)."""
return "fp16" if dtype in ("fp8", "bf8") else dtype
"""Get output (C) datatype for an A/B element dtype.

Low-precision float inputs accumulate into and store as fp16
(fp8/bf8 -> fp16); int8 stores its int32 accumulator (int8 -> int32).
Everything else stores in its own dtype.
"""
if dtype in ("fp8", "bf8"):
return "fp16"
if dtype == "int8":
return "int32"
return dtype

@staticmethod
def get_acc_dtype(dtype: str) -> str:
"""Get accumulator datatype for an A/B element dtype.

Integer GEMM accumulates in int32; every float dtype accumulates in
fp32.
"""
return "int32" if dtype == "int8" else "fp32"


# ============================================================================
Expand Down
Loading
Loading