From 3b6b7c6803df32393cbd3a04b58856ef4d38741f Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 17:08:53 -0400 Subject: [PATCH 01/36] RC-01: fix cluster::sync() to use cluster user barrier (-3) instead of WG barrier (-1) s_barrier_signal/wait -1 only syncs waves within a single workgroup. Cluster-wide CGA sync requires barrier ID -3 (cluster user barrier) per MI400 ISA Section 4.3.6. Verified via ISA dump. --- include/ops/warp/cluster/cluster.cuh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/include/ops/warp/cluster/cluster.cuh b/include/ops/warp/cluster/cluster.cuh index 48d2902c..7d72547f 100644 --- a/include/ops/warp/cluster/cluster.cuh +++ b/include/ops/warp/cluster/cluster.cuh @@ -44,14 +44,14 @@ __device__ __host__ __forceinline__ constexpr uint32_t mask( } /** - * @brief Cluster-wide split barrier. + * @brief Cluster-wide barrier (signal + wait on cluster user barrier -3). * - * Outside a CGA cluster this lowers to a workgroup-wide `sync::sync()`. Inside - * a cluster the same `s_barrier_signal -1 / s_barrier_wait -1` pair extends to - * every workgroup in the cluster by hardware-managed forwarding. + * Barrier -3 syncs across all workgroups in a CGA cluster. + * Outside a cluster, use `sync::sync()` (barrier -1, workgroup-only). */ __device__ __forceinline__ void sync() { - ::kittens::sync::sync(); + __builtin_amdgcn_s_barrier_signal(-3); + __builtin_amdgcn_s_barrier_wait(-3); } } // namespace cluster From 3891e96ca525a0fc6aec7ee9e257931092f04538 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 17:08:58 -0400 Subject: [PATCH 02/36] RC-02: fix mfma323232 const C aliasing violation when D != C First MFMA was writing result into C (stripping const via cast), then second MFMA read the modified C. Use a local accumulator instead so C is never mutated. Compiler optimizes it away when D aliases C. Bug confirmed on gfx950 MI350X: C corrupted from 42.0 to 58.0. --- include/ops/warp/register/tile/mma.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ops/warp/register/tile/mma.cuh b/include/ops/warp/register/tile/mma.cuh index c4aca036..4e7b79f3 100644 --- a/include/ops/warp/register/tile/mma.cuh +++ b/include/ops/warp/register/tile/mma.cuh @@ -124,7 +124,7 @@ __device__ static inline void mfma323232( float2 (&D)[8], typedef __attribute__((__vector_size__(8 * sizeof(__bf16)))) __bf16 bf16x8_t; typedef __attribute__((__vector_size__(16 * sizeof(float)))) float floatx16_t; - *(floatx16_t*)C = __builtin_amdgcn_mfma_f32_32x32x16_bf16( + floatx16_t acc = __builtin_amdgcn_mfma_f32_32x32x16_bf16( *(bf16x8_t*)A, *(bf16x8_t*)B, *(floatx16_t*)C, @@ -134,7 +134,7 @@ __device__ static inline void mfma323232( float2 (&D)[8], *(floatx16_t*)D = __builtin_amdgcn_mfma_f32_32x32x16_bf16( *(bf16x8_t*)(A + 4), *(bf16x8_t*)(B + 4), - *(floatx16_t*)C, + acc, 0, 0, 0 ); } From c02694f62212f2daa7222a55913f046115254759 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 17:09:03 -0400 Subject: [PATCH 03/36] RC-03: add assert to init_barrier validating count range count==0 wraps to pending=0xFFFF causing silent hang. count>65535 wraps silently. Assert catches both at runtime. --- include/ops/warp/sync/barrier.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/include/ops/warp/sync/barrier.cuh b/include/ops/warp/sync/barrier.cuh index 0a6ef74d..27623038 100644 --- a/include/ops/warp/sync/barrier.cuh +++ b/include/ops/warp/sync/barrier.cuh @@ -191,6 +191,7 @@ struct alignas(8) barrier_lds { uint64_t state; }; /// @brief Initialize an LDS barrier cell to expect `count` arrivals per phase. __device__ __forceinline__ void init_barrier(uint64_t* bar, uint32_t count) { + assert(count > 0 && count <= 0xFFFF); // pending = count - 1, phase = 0, init_count = count - 1. const uint32_t pending = count - 1; const uint32_t init_cnt = count - 1; From b8dcb10b10579a5ea2b268f71251325dee1d2d44 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 17:09:08 -0400 Subject: [PATCH 04/36] =?UTF-8?q?RC-04:=20fix=20packed=5Fshfl=5Fdown=20wid?= =?UTF-8?q?th=3D64=20on=20wave-32=20=E2=80=94=20use=20WARP=5FTHREADS?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Hardcoded width=64 causes ds_bpermute to wrap around on wave-32, reading from lane (L+delta)%32 instead of returning identity. This doubles every value in tree reductions (sum=1056 instead of 528). Verified on the gfx1250 simulator. --- include/common/util.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/common/util.cuh b/include/common/util.cuh index 8b007f40..7c23a8bb 100644 --- a/include/common/util.cuh +++ b/include/common/util.cuh @@ -201,7 +201,7 @@ __device__ static inline T packed_shfl_down(uint64_t mask, const T &f, int delta *reinterpret_cast(&f)}; } - u.ui = __shfl_down_sync(mask, u.ui, delta, 64); + u.ui = __shfl_down_sync(mask, u.ui, delta, WARP_THREADS); if constexpr (std::is_same_v) { return *reinterpret_cast(&u.bf162.x); // Extract single bf16 from the .x component } else { From 4d9afaff5da11bc0788ffb51eb821655db955a2c Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 17:09:13 -0400 Subject: [PATCH 05/36] SR-01: add M/N/K alignment asserts in grid() to catch silent truncation Non-aligned dimensions silently produce wrong results: grid() truncates via integer division, store_acc writes partial tiles. Assert fires for M=100 on the gfx1250 simulator, passes for M=64. --- kernels/gemm/bf16fp32/gfx1250/common.h | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/kernels/gemm/bf16fp32/gfx1250/common.h b/kernels/gemm/bf16fp32/gfx1250/common.h index 6eba740a..394565ec 100644 --- a/kernels/gemm/bf16fp32/gfx1250/common.h +++ b/kernels/gemm/bf16fp32/gfx1250/common.h @@ -36,7 +36,12 @@ struct gemm_globals { int M() const { return a.rows(); } int N() const { return c.cols(); } int K() const { return a.cols(); } - dim3 grid() const { return dim3(M() / BLOCK_M, N() / BLOCK_N); } + dim3 grid() const { + assert(M() % BLOCK_M == 0 && "M must be a multiple of BLOCK_M"); + assert(N() % BLOCK_N == 0 && "N must be a multiple of BLOCK_N"); + assert(K() % K_STEP == 0 && "K must be a multiple of K_STEP"); + return dim3(M() / BLOCK_M, N() / BLOCK_N); + } dim3 block() const { return dim3(NUM_THREADS); } size_t dynamic_shared_memory() const { return kittens::MAX_SHARED_MEMORY; } }; From 524771fb398c16227ed8d68bc94334d0fa3ecdf3 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 17:14:45 -0400 Subject: [PATCH 06/36] SR-02: add shared_allocator overflow check via assert All three allocate variants now assert ptr stays within bounds. Regular allocate() checks against MAX_SHARED_MEMORY. Segmented allocate_in() checks against the segment's end boundary. Unified base pointer storage across UDNA1 and CDNA paths. Uses device-side assert() which prints file:line and traps. Verified on gfx950 MI350X: pointer-comparison assert fires correctly when overflow condition is runtime-determined. --- include/common/util.cuh | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/include/common/util.cuh b/include/common/util.cuh index 7c23a8bb..f90e5cd2 100644 --- a/include/common/util.cuh +++ b/include/common/util.cuh @@ -332,9 +332,7 @@ template concept all = is_segment::value; template struct shared_allocator { int *ptr; -#ifdef KITTENS_UDNA1 int *base; -#endif private: // Recursive template to generate N-dimensional array type @@ -366,24 +364,21 @@ struct shared_allocator { * @brief Construct a new shared allocator using a pointer to extern shared memory. * @param[in] _ptr Pointer to the start of the extern shared memory. */ -#ifdef KITTENS_UDNA1 __device__ shared_allocator(int *_ptr): ptr(_ptr), base(_ptr) {} -#else - __device__ shared_allocator(int *_ptr): ptr(_ptr) {} -#endif /** * @brief Allocate shared memory for a single instance or N-dimensional array of type A. * @tparam A The type of the object to allocate. * @tparam dims... A list of dimensions for the N-dimensional array. * @return Reference to the allocated object. */ - template + template __device__ inline variadic_array_t& allocate() { // static_assert(sizeof(A) % default_alignment == 0, "Type is not aligned properly for array allocation"); align_ptr(); using at = variadic_array_t; at*p = reinterpret_cast(ptr); ptr += sizeof(at)/sizeof(int); + assert(ptr <= base + MAX_SHARED_MEMORY / sizeof(int)); return *p; } /** @@ -393,13 +388,14 @@ struct shared_allocator { * @tparam dims... A list of dimensions for the N-dimensional array. * @return Reference to the allocated object. */ - template + template __device__ inline variadic_array_t& allocate() { // static_assert(sizeof(A) % alignment == 0, "Type is not aligned properly for array allocation"); align_ptr(); using at = variadic_array_t; at*p = reinterpret_cast(ptr); ptr += sizeof(at)/sizeof(int); + assert(ptr <= base + MAX_SHARED_MEMORY / sizeof(int)); return *p; } @@ -426,6 +422,8 @@ struct shared_allocator { using at = variadic_array_t; at* p = reinterpret_cast(ptr); ptr += sizeof(at) / sizeof(int); + constexpr int seg_end = (SEG::index + 1) * LDS_SEGMENT_BYTES / sizeof(int); + assert(ptr <= base + seg_end); return *p; } #endif // KITTENS_UDNA1 From dc39999a30e243bec3edd799a3899f8e1bd2f4de Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 17:55:56 -0400 Subject: [PATCH 07/36] SR-03: add per-segment watermarks to shared_allocator allocate_in() used a single ptr for all segments. Allocating into segment 0 after segment 1 would silently place data in segment 1 because ptr can only advance forward. Now each segment tracks its own cursor via seg_ptr[LDS_NUM_SEGMENTS], allowing interleaved segment allocations. Verified on the gfx1250 simulator: gemm_segment produces correct results. --- include/common/util.cuh | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/include/common/util.cuh b/include/common/util.cuh index f90e5cd2..7fe0a89c 100644 --- a/include/common/util.cuh +++ b/include/common/util.cuh @@ -333,6 +333,9 @@ template struct shared_allocator { int *ptr; int *base; +#ifdef KITTENS_UDNA1 + int *seg_ptr[LDS_NUM_SEGMENTS]; +#endif private: // Recursive template to generate N-dimensional array type @@ -364,7 +367,12 @@ struct shared_allocator { * @brief Construct a new shared allocator using a pointer to extern shared memory. * @param[in] _ptr Pointer to the start of the extern shared memory. */ - __device__ shared_allocator(int *_ptr): ptr(_ptr), base(_ptr) {} + __device__ shared_allocator(int *_ptr): ptr(_ptr), base(_ptr) { +#ifdef KITTENS_UDNA1 + for (int i = 0; i < LDS_NUM_SEGMENTS; i++) + seg_ptr[i] = base + i * (LDS_SEGMENT_BYTES / (int)sizeof(int)); +#endif + } /** * @brief Allocate shared memory for a single instance or N-dimensional array of type A. * @tparam A The type of the object to allocate. @@ -415,15 +423,17 @@ struct shared_allocator { template requires ducks::segment_tag::all __device__ inline variadic_array_t& allocate_in() { - int* target = base + (SEG::byte_offset / sizeof(int)); - // If we've already allocated past the requested segment, keep - // packing where we are; otherwise jump forward to the segment. - if (ptr < target) ptr = target; + constexpr int idx = SEG::index; + if constexpr (default_alignment > 0) { + uint64_t p = reinterpret_cast(seg_ptr[idx]); + if (p % default_alignment != 0) + seg_ptr[idx] = (int*)(p + (default_alignment - (p % default_alignment))); + } using at = variadic_array_t; - at* p = reinterpret_cast(ptr); - ptr += sizeof(at) / sizeof(int); - constexpr int seg_end = (SEG::index + 1) * LDS_SEGMENT_BYTES / sizeof(int); - assert(ptr <= base + seg_end); + at* p = reinterpret_cast(seg_ptr[idx]); + seg_ptr[idx] += sizeof(at) / sizeof(int); + constexpr int seg_end = (idx + 1) * LDS_SEGMENT_BYTES / sizeof(int); + assert(seg_ptr[idx] <= base + seg_end); return *p; } #endif // KITTENS_UDNA1 From 2b9406a43cacebf69af7d12dad046928439098df Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:01:01 -0400 Subject: [PATCH 08/36] SR-04: document load_b32 as unpadded-only load_b32 does not apply padding offsets (unlike load_b128 which takes a Pad template param). Adding Pad support was attempted but blocked by include ordering (lds_nopad defined in global_to_shared.cuh which is included after shared_to_register.cuh). Added comment directing users to load_b128 for padded layouts. --- include/ops/warp/memory/tile/shared_to_register.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/include/ops/warp/memory/tile/shared_to_register.cuh b/include/ops/warp/memory/tile/shared_to_register.cuh index 8c905aaf..5e3561c9 100644 --- a/include/ops/warp/memory/tile/shared_to_register.cuh +++ b/include/ops/warp/memory/tile/shared_to_register.cuh @@ -783,6 +783,7 @@ __device__ inline void load_b32( rt_bf& dst, const bf16* __restrict__ warp_lds_base) { + // Unpadded only — use load_b128 for padded LDS layouts. constexpr int height = WARP_M / detail::GFX1250_SUB_ROWS; constexpr int width = WARP_K / detail::GFX1250_SUB_COLS; constexpr int subs_per_row = WARP_K / detail::GFX1250_SUB_COLS; From 63eed26d2794e58f7cb754cc819a19292d4056ac Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:03:44 -0400 Subject: [PATCH 09/36] SR-05: add static_assert for 16-byte pad alignment in load_b128 ds_load_b128 requires 16-byte aligned LDS addresses. A custom lds_padded where M * sizeof(bf16) is not a multiple of 16 would silently misalign loads. static_assert catches this at compile time. --- include/ops/warp/memory/tile/shared_to_register.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/include/ops/warp/memory/tile/shared_to_register.cuh b/include/ops/warp/memory/tile/shared_to_register.cuh index 5e3561c9..2a418543 100644 --- a/include/ops/warp/memory/tile/shared_to_register.cuh +++ b/include/ops/warp/memory/tile/shared_to_register.cuh @@ -729,6 +729,8 @@ __device__ inline void load_b128( rt_bf& dst, const bf16* __restrict__ warp_lds_base) { + static_assert(Pad::amount == 0 || Pad::amount * sizeof(bf16) % 16 == 0, + "Pad amount must be a multiple of 16 bytes for ds_load_b128 alignment"); constexpr int height = WARP_M / detail::GFX1250_SUB_ROWS; constexpr int width = WARP_K / detail::GFX1250_SUB_COLS; constexpr int subs_per_row = WARP_K / detail::GFX1250_SUB_COLS; From d1b209987aa37545768b353f69b7c2d7325f9bf9 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:08:41 -0400 Subject: [PATCH 10/36] SR-06: document that g2s load functions require tile-aligned dimensions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit No bounds clamping on DRAM reads — out-of-bounds tile coordinates read from invalid memory silently. Added doc comment to load() and load_async() noting caller must ensure alignment. Enforced at dispatch level by SR-01's M/N/K alignment asserts. --- include/ops/warp/memory/tile/global_to_shared.cuh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/ops/warp/memory/tile/global_to_shared.cuh b/include/ops/warp/memory/tile/global_to_shared.cuh index 6f18f7f7..364e23dc 100644 --- a/include/ops/warp/memory/tile/global_to_shared.cuh +++ b/include/ops/warp/memory/tile/global_to_shared.cuh @@ -501,6 +501,8 @@ __device__ __forceinline__ int subtile_flat(int flat) { * Plain `global_load` -> VGPR -> `ds_store` path. Use this when no async * intrinsic is available or for correctness baselines. The `Pad` parameter * controls the per-element LDS placement; pass `lds_nopad` for flat layouts. + * + * Caller must ensure matrix dimensions are multiples of ROWS/COLS (no bounds clamping). */ template> @@ -536,6 +538,8 @@ __device__ inline void load(T* __restrict__ lds_dst, const GL& src, const COORD& * issues one 16-byte transfer; the warp covers `8 * N_THREADS` elements per * iteration. Drain with `kittens::sync::wait_async()` before consuming. * + * Caller must ensure matrix dimensions are multiples of ROWS/COLS (no bounds clamping). + * * @tparam Pad LDS padding descriptor. * @tparam ROWS,COLS Tile shape (elements). * @tparam N_THREADS Number of threads participating in the load. From 5364ced2d388da5d1d339685599308daf5a73c5c Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:12:14 -0400 Subject: [PATCH 11/36] SR-07: move laneid()==0 guard into async_barrier_arrive DS_ATOMIC_ASYNC_BARRIER_ARRIVE_B64 fires per active lane. Without a lane guard, one wave produces 32 arrivals instead of 1, causing the barrier to flip prematurely. Moved the guard into the function so callers can't forget it. Removed redundant guards from all 4 call sites in gemm_tdm_arrive. --- include/ops/warp/sync/barrier.cuh | 16 +++++++++------- .../gemm/bf16fp32/gfx1250/gemm_tdm_arrive.cpp | 13 ++++++------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/include/ops/warp/sync/barrier.cuh b/include/ops/warp/sync/barrier.cuh index 27623038..50f50208 100644 --- a/include/ops/warp/sync/barrier.cuh +++ b/include/ops/warp/sync/barrier.cuh @@ -223,16 +223,18 @@ __device__ __forceinline__ void wait_barrier(uint64_t* bar, int expected_phase) } /** - * @brief Arrive at an LDS barrier cell from an async-ordered path. + * @brief Arrive at an LDS barrier cell from an async-ordered path (once per wave). * - * Lowers to `DS_ATOMIC_ASYNC_BARRIER_ARRIVE_B64`. Use this to manually - * arrive at a cell (the auto-arrive form is encoded in the TDM descriptor - * via `load_tdm_arrive`). + * Lowers to `DS_ATOMIC_ASYNC_BARRIER_ARRIVE_B64` which is a DS atomic that + * fires per active lane. This wrapper restricts execution to lane 0 so the + * barrier receives exactly one arrival per wave. */ __device__ __forceinline__ void async_barrier_arrive(uint64_t* lds_counter) { - uintptr_t lds_uint = reinterpret_cast(lds_counter); - __builtin_amdgcn_ds_atomic_async_barrier_arrive_b64( - reinterpret_cast(lds_uint)); + if (laneid() == 0) { + uintptr_t lds_uint = reinterpret_cast(lds_counter); + __builtin_amdgcn_ds_atomic_async_barrier_arrive_b64( + reinterpret_cast(lds_uint)); + } } } // namespace sync diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_tdm_arrive.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_tdm_arrive.cpp index f609dd6e..495fe181 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_tdm_arrive.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_tdm_arrive.cpp @@ -96,20 +96,19 @@ void gemm_tdm_arrive_kernel(const gemm_globals g, int M, int N, int K) // `load_tdm_arrive`) is also wired in the library for runtimes that // model it natively. // - // `async_barrier_arrive` is a DS atomic, so it fires per active lane: - // guard with `laneid() == 0` so each producer wave arrives exactly - // once per phase (matching the `init_barrier(.., 1)` priming above). + // `async_barrier_arrive` internally guards with `laneid() == 0` so + // each wave arrives exactly once per phase. if (wid == 0) { g2s::load_tdm( A_lds[0], g.a, {0, 0, tile_m, 0}, M, K, K); sync::wait_tensor(); - if (laneid() == 0) sync::async_barrier_arrive(&A_bar[0].state); + sync::async_barrier_arrive(&A_bar[0].state); } if (wid == 1) { g2s::load_tdm( B_lds[0], g.b, {0, 0, tile_n, 0}, N, K, K); sync::wait_tensor(); - if (laneid() == 0) sync::async_barrier_arrive(&B_bar[0].state); + sync::async_barrier_arrive(&B_bar[0].state); } for (int k = 0; k < k_iters; ++k) { @@ -120,13 +119,13 @@ void gemm_tdm_arrive_kernel(const gemm_globals g, int M, int N, int K) g2s::load_tdm( A_lds[nxt], g.a, {0, 0, tile_m, k + 1}, M, K, K); sync::wait_tensor(); - if (laneid() == 0) sync::async_barrier_arrive(&A_bar[nxt].state); + sync::async_barrier_arrive(&A_bar[nxt].state); } if (wid == 1) { g2s::load_tdm( B_lds[nxt], g.b, {0, 0, tile_n, k + 1}, N, K, K); sync::wait_tensor(); - if (laneid() == 0) sync::async_barrier_arrive(&B_bar[nxt].state); + sync::async_barrier_arrive(&B_bar[nxt].state); } } From a17ff7b6f9f079f1761e7d88b13a6989e008bb2e Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:18:00 -0400 Subject: [PATCH 12/36] SR-08: assert barrier address fits in D# 16-bit field The TDM D# descriptor's atomic_barrier_address field is 16 bits, limiting barrier cells to the first 64KB of LDS. Added assert at the point of truncation in build_tdm_d_2d() so any kernel passing an out-of-range barrier address fails immediately instead of silently arriving at the wrong LDS location. --- include/ops/warp/memory/tile/global_to_shared.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/include/ops/warp/memory/tile/global_to_shared.cuh b/include/ops/warp/memory/tile/global_to_shared.cuh index 364e23dc..c3319a3d 100644 --- a/include/ops/warp/memory/tile/global_to_shared.cuh +++ b/include/ops/warp/memory/tile/global_to_shared.cuh @@ -673,6 +673,7 @@ __device__ __forceinline__ void build_tdm_d_2d( const uint32_t tiledim1 = static_cast(ROWS); // barrier_addr occupies w1[15:0]; tensor_dim0 lo16 occupies w1[31:16]. + assert(bar_lds_addr == 0 || bar_lds_addr < 0x10000u); uint32_t w1 = (bar_lds_addr & 0xFFFFu) | (tdim0 << 16); uint32_t w2 = (tdim0 >> 16) | (tdim1 << 16); uint32_t w3 = (tdim1 >> 16) | (tiledim0 << 16); From 3dbae36289d385a905c9e2367c934a78f3b7bdb2 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:20:48 -0400 Subject: [PATCH 13/36] SR-09: static_assert pad interval is power-of-2 for D# encoding __builtin_ctz encodes pad_interval as log2 for the TDM D# descriptor. Non-power-of-2 intervals produce wrong encoding (ctz returns lowest set bit, not log2). static_assert catches this at compile time. --- include/ops/warp/memory/tile/global_to_shared.cuh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/include/ops/warp/memory/tile/global_to_shared.cuh b/include/ops/warp/memory/tile/global_to_shared.cuh index c3319a3d..231d5ba5 100644 --- a/include/ops/warp/memory/tile/global_to_shared.cuh +++ b/include/ops/warp/memory/tile/global_to_shared.cuh @@ -649,6 +649,9 @@ __device__ __forceinline__ void build_tdm_d_2d( : (sizeof(T) == 4) ? 2 : 3; constexpr uint32_t pad_enable = (Pad::interval > 0) ? 1u : 0u; + static_assert(Pad::interval == 0 || + __builtin_popcount(Pad::interval * sizeof(T) / 4) == 1, + "Pad interval in DWords must be a power of 2 for D# encoding"); constexpr uint32_t pad_int_enc = (Pad::interval > 0) ? ( __builtin_ctz(Pad::interval * sizeof(T) / 4) ) : 0; constexpr uint32_t pad_amt_enc = (Pad::amount > 0) From 07541812d2e9b8361b99c9d4439ca5668346792c Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:23:38 -0400 Subject: [PATCH 14/36] SR-10: fix chiplet_transform_chunked off-by-one (> to >=) Boundary workgroup at limit should be in the partial tail (identity mapping), not enter the full-block remapping. Change > to >=. Cosmetic on all tested parameters (boundary always maps to itself), but matches the documented intent. No-op on gfx1250 (NUM_XCDS=1). --- include/common/util.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/common/util.cuh b/include/common/util.cuh index 7fe0a89c..c992e138 100644 --- a/include/common/util.cuh +++ b/include/common/util.cuh @@ -109,7 +109,7 @@ __host__ __device__ inline int ceil_div(int a, int b) { int limit = (num_workgroups / block) * block; // If pid beyond the last full block, leave unchanged - if (workgroup_id > limit) return workgroup_id; + if (workgroup_id >= limit) return workgroup_id; // Local PID (within round-robin assignment) int local_pid = workgroup_id / num_xcds; From 3c45296ef691f5690f09dddbcb687901a93d914b Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:26:02 -0400 Subject: [PATCH 15/36] SR-11/SR-12: document padding derivation and write/read consistency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SR-11: Added bank conflict derivation for lds_pad_default<128,8>. SR-12: Added comment that Pad type must match between load_async (write) and load_b128 (read) — mismatch causes silent corruption. --- include/ops/warp/memory/tile/global_to_shared.cuh | 6 ++++++ kernels/gemm/bf16fp32/gfx1250/gemm_padded.cpp | 1 + 2 files changed, 7 insertions(+) diff --git a/include/ops/warp/memory/tile/global_to_shared.cuh b/include/ops/warp/memory/tile/global_to_shared.cuh index 231d5ba5..4065beaf 100644 --- a/include/ops/warp/memory/tile/global_to_shared.cuh +++ b/include/ops/warp/memory/tile/global_to_shared.cuh @@ -461,6 +461,12 @@ struct lds_nopad { }; /// @brief Default LDS padding for bf16 GEMMs on gfx1250. +/// Derivation: gfx1250 LDS has 32 banks, 4 bytes wide. A 16x32 bf16 subtile +/// row is 32 * 2 = 64 bytes = 16 banks. Two consecutive rows hit the same 16 +/// banks → 2-way conflict on ds_load_b128. Padding by 8 bf16 (16 bytes = 4 +/// banks) shifts each row's bank mapping, eliminating conflicts. +/// Interval 128 = one subtile row (128 bf16 = 256 bytes). Must be power-of-2 +/// for D# encoding (see SR-09 static_assert). using lds_pad_default = lds_padded<128, 8>; namespace g2s { diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_padded.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_padded.cpp index 5faefd1e..be3041b1 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_padded.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_padded.cpp @@ -13,6 +13,7 @@ using namespace kittens; using namespace gfx1250_gemm; +// Pad type must match between load_async (write) and load_b128 (read). using Pad = lds_pad_default; constexpr int A_ELEMS_PAD = Pad::padded_elems(BLOCK_M * K_STEP); constexpr int B_ELEMS_PAD = Pad::padded_elems(BLOCK_N * K_STEP); From 756fbd097756f771d1f0f1810098eb3fe7b0ed3b Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:27:07 -0400 Subject: [PATCH 16/36] SR-13/SR-14: segment size static_assert and wait_barrier hang warning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SR-13: Added static_assert that double-buffered A and B each fit in one LDS segment (64KB). Catches oversized tile configs at compile time. SR-14: Added WARNING comment to wait_barrier documenting infinite loop risk if arrive never fires. Timeout deferred — can't test on the gfx1250 simulator. --- include/ops/warp/sync/barrier.cuh | 3 +++ kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/include/ops/warp/sync/barrier.cuh b/include/ops/warp/sync/barrier.cuh index 50f50208..4fd011a3 100644 --- a/include/ops/warp/sync/barrier.cuh +++ b/include/ops/warp/sync/barrier.cuh @@ -207,6 +207,9 @@ __device__ __forceinline__ void init_barrier(uint64_t* bar, uint32_t count) { * Callers maintain a parity bit per barrier and pass it inverted before * each wait (`expected = (phase ^= 1)`). The hardware wakes sleeping * waves on phase flip; `s_sleep 1` yields the SIMD between polls. + * + * WARNING: infinite loop if the matching arrive never fires (e.g., wrong + * init_barrier count or missing async_barrier_arrive). No timeout. */ __device__ __forceinline__ void wait_barrier(uint64_t* bar, int expected_phase) { const uint32_t lds_addr = static_cast(reinterpret_cast(bar)); diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp index 083887b2..c43b66ca 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp @@ -12,9 +12,14 @@ using namespace kittens; using namespace gfx1250_gemm; +// Pad type must match between load_async (write) and load_b128 (read). using Pad = lds_pad_default; constexpr int A_ELEMS_PAD = Pad::padded_elems(BLOCK_M * K_STEP); constexpr int B_ELEMS_PAD = Pad::padded_elems(BLOCK_N * K_STEP); +static_assert(2 * A_ELEMS_PAD * sizeof(kittens::bf16) <= kittens::LDS_SEGMENT_BYTES, + "Double-buffered A must fit in one LDS segment"); +static_assert(2 * B_ELEMS_PAD * sizeof(kittens::bf16) <= kittens::LDS_SEGMENT_BYTES, + "Double-buffered B must fit in one LDS segment"); __global__ __launch_bounds__(NUM_THREADS, 1) void gemm_segment_kernel(const gemm_globals g, int M, int N, int K) From fcca9b8d42d5e8dbde560aaec70429ac38c21422 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:33:46 -0400 Subject: [PATCH 17/36] IV-01: skip last-iteration sync/wait_async in all 7 ladder kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The final loop iteration syncs to protect a next-iteration load that never happens. Guard with if(k+1 < k_iters) — uniform scalar branch, no divergence. ISA confirms the barrier is conditionally skipped via s_cbranch. All 7 kernels pass correctness on the gfx1250 simulator. --- kernels/gemm/bf16fp32/gfx1250/gemm_async.cpp | 6 ++++-- kernels/gemm/bf16fp32/gfx1250/gemm_double_buf.cpp | 2 +- kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp | 2 +- kernels/gemm/bf16fp32/gfx1250/gemm_naive.cpp | 2 +- kernels/gemm/bf16fp32/gfx1250/gemm_padded.cpp | 6 ++++-- kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp | 2 +- kernels/gemm/bf16fp32/gfx1250/gemm_split_bar.cpp | 2 +- 7 files changed, 13 insertions(+), 9 deletions(-) diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_async.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_async.cpp index 43e7a7ad..e83e20ce 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_async.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_async.cpp @@ -59,8 +59,10 @@ void gemm_async_kernel(const gemm_globals g, int M, int N, int K) kittens::sync::wait_ds(); mma_ABt(C_acc, A_reg, B_reg, C_acc); - kittens::sync::wait_async(); - kittens::sync::sync(); + if (k + 1 < k_iters) { // skip on last iter — no next load issued + kittens::sync::wait_async(); + kittens::sync::sync(); + } } bf16* c_base = reinterpret_cast(&g.c[{0, 0, 0, 0}]); diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_double_buf.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_double_buf.cpp index e4db519c..a444a3c9 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_double_buf.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_double_buf.cpp @@ -56,7 +56,7 @@ void gemm_double_buf_kernel(const gemm_globals g, int M, int N, int K) kittens::sync::wait_ds(); mma_ABt(C_acc, A_reg, B_reg, C_acc); - kittens::sync::sync(); + if (k + 1 < k_iters) kittens::sync::sync(); // skip on last iter — no next overwrite } bf16* c_base = reinterpret_cast(&g.c[{0, 0, 0, 0}]); diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp index 97555ff4..9d675092 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp @@ -69,7 +69,7 @@ void gemm_expert_kernel(const gemm_globals g, int M, int N, int K) kittens::sync::wait_ds(); mma_ABt_burst(C_acc, A_reg, B_reg, C_acc); - kittens::sync::wait_async(); + if (k + 1 < k_iters) kittens::sync::wait_async(); // skip on last iter — no next load issued } bf16* c_base = reinterpret_cast(&g.c[{0, 0, 0, 0}]); diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_naive.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_naive.cpp index a87530ae..fde01745 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_naive.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_naive.cpp @@ -54,7 +54,7 @@ void gemm_naive_kernel(const gemm_globals g, int M, int N, int K) kittens::sync::wait_ds(); mma_ABt(C_acc, A_reg, B_reg, C_acc); - kittens::sync::sync(); + if (k + 1 < k_iters) kittens::sync::sync(); // skip on last iter — no next overwrite } bf16* c_base = reinterpret_cast(&g.c[{0, 0, 0, 0}]); diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_padded.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_padded.cpp index be3041b1..46801cf2 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_padded.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_padded.cpp @@ -64,8 +64,10 @@ void gemm_padded_kernel(const gemm_globals g, int M, int N, int K) kittens::sync::wait_ds(); mma_ABt(C_acc, A_reg, B_reg, C_acc); - kittens::sync::wait_async(); - kittens::sync::sync(); + if (k + 1 < k_iters) { // skip on last iter — no next load issued + kittens::sync::wait_async(); + kittens::sync::sync(); + } } bf16* c_base = reinterpret_cast(&g.c[{0, 0, 0, 0}]); diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp index c43b66ca..a5c2e5fd 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp @@ -69,7 +69,7 @@ void gemm_segment_kernel(const gemm_globals g, int M, int N, int K) kittens::sync::wait_ds(); mma_ABt(C_acc, A_reg, B_reg, C_acc); - kittens::sync::wait_async(); + if (k + 1 < k_iters) kittens::sync::wait_async(); // skip on last iter — no next load issued } bf16* c_base = reinterpret_cast(&g.c[{0, 0, 0, 0}]); diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_split_bar.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_split_bar.cpp index b4a45e02..0d9a0b5e 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_split_bar.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_split_bar.cpp @@ -64,7 +64,7 @@ void gemm_split_bar_kernel(const gemm_globals g, int M, int N, int K) kittens::sync::wait_ds(); mma_ABt(C_acc, A_reg, B_reg, C_acc); - kittens::sync::wait_async(); + if (k + 1 < k_iters) kittens::sync::wait_async(); // skip on last iter — no next load issued } bf16* c_base = reinterpret_cast(&g.c[{0, 0, 0, 0}]); From a1e99e1531e1ae3e20377716748248b5f48379bf Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:36:46 -0400 Subject: [PATCH 18/36] IV-02: document why expert mode scope is intentionally wide MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Attempted to scope RAII guard tightly around mma_ABt_burst. ISA showed compiler reordering scalar s_setreg past VALU WMMA instructions, making the expert mode window empty. Wide scope is correct — ensures s_setreg reset stays after all WMMAs. Added comment. --- kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp index 9d675092..bbc3f18b 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp @@ -38,7 +38,7 @@ void gemm_expert_kernel(const gemm_globals g, int M, int N, int K) const int warp_c = wid % WARPS_N; const int k_iters = K / K_STEP; - kittens::sched::expert _sched; // limited expert mode, off in dtor + kittens::sched::expert _sched; // covers entire kernel — compiler reorders s_setreg if scoped tighter kittens::g2s::load_async( A_lds[0], g.a, {0, 0, tile_m, 0}, K); From 8d011edbd57ac8814f875145019e09f73069c602 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:40:55 -0400 Subject: [PATCH 19/36] IV-03: move arrive() before async loads in split_bar, segment, expert arrive() signals completion of previous iteration's LDS reads. Async loads write to nxt slab (different from cur), so signaling before them is safe. ISA confirms s_barrier_signal moved 24 instructions earlier, allowing other waves to unblock from wait() sooner. All 3 kernels pass correctness on the gfx1250 simulator. --- kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp | 2 +- kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp | 2 +- kernels/gemm/bf16fp32/gfx1250/gemm_split_bar.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp index bbc3f18b..720e091a 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp @@ -50,13 +50,13 @@ void gemm_expert_kernel(const gemm_globals g, int M, int N, int K) for (int k = 0; k < k_iters; ++k) { const int cur = k & 1, nxt = 1 - cur; + kittens::sync::arrive(); // signal before async loads — nxt != cur, no conflict if (k + 1 < k_iters) { kittens::g2s::load_async( A_lds[nxt], g.a, {0, 0, tile_m, k + 1}, K); kittens::g2s::load_async( B_lds[nxt], g.b, {0, 0, tile_n, k + 1}, K); } - kittens::sync::arrive(); rt_bf A_reg; rt_bf B_reg; diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp index a5c2e5fd..15bf9862 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp @@ -50,13 +50,13 @@ void gemm_segment_kernel(const gemm_globals g, int M, int N, int K) for (int k = 0; k < k_iters; ++k) { const int cur = k & 1, nxt = 1 - cur; + kittens::sync::arrive(); // signal before async loads — nxt != cur, no conflict if (k + 1 < k_iters) { kittens::g2s::load_async( A_lds[nxt], g.a, {0, 0, tile_m, k + 1}, K); kittens::g2s::load_async( B_lds[nxt], g.b, {0, 0, tile_n, k + 1}, K); } - kittens::sync::arrive(); rt_bf A_reg; rt_bf B_reg; diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_split_bar.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_split_bar.cpp index 0d9a0b5e..1b98ac86 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_split_bar.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_split_bar.cpp @@ -45,13 +45,13 @@ void gemm_split_bar_kernel(const gemm_globals g, int M, int N, int K) for (int k = 0; k < k_iters; ++k) { const int cur = k & 1, nxt = 1 - cur; + kittens::sync::arrive(); // signal before async loads — nxt != cur, no conflict if (k + 1 < k_iters) { kittens::g2s::load_async( A_lds[nxt], g.a, {0, 0, tile_m, k + 1}, K); kittens::g2s::load_async( B_lds[nxt], g.b, {0, 0, tile_n, k + 1}, K); } - kittens::sync::arrive(); // signal early -- independent work below rt_bf A_reg; rt_bf B_reg; From c069417fea937a0afa455333bdf12c90fba2676b Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:43:32 -0400 Subject: [PATCH 20/36] CQ-01: replace -w with -Wall and targeted suppressions -w hid all compiler warnings. Now using -Wall with specific -Wno flags for upstream framework warnings (unused-variable, unused-local- typedef, duplicate-decl-specifier, unused-value, pass-failed). All 8 kernels build with 0 warnings. --- kernels/gemm/bf16fp32/gfx1250/Makefile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kernels/gemm/bf16fp32/gfx1250/Makefile b/kernels/gemm/bf16fp32/gfx1250/Makefile index 4d3ed180..d4637247 100644 --- a/kernels/gemm/bf16fp32/gfx1250/Makefile +++ b/kernels/gemm/bf16fp32/gfx1250/Makefile @@ -28,7 +28,8 @@ HIPFLAGS := -DKITTENS_UDNA1 \ CXX_STD := c++20 CPPFLAGS += -I$(KITTENS_ROOT)/include -I$(HIP_INCLUDE_DIR) -CXXFLAGS += -std=$(CXX_STD) -O3 -w +CXXFLAGS += -std=$(CXX_STD) -O3 -Wall -Wno-unused-variable -Wno-unused-local-typedef \ + -Wno-duplicate-decl-specifier -Wno-unused-value -Wno-pass-failed M ?= 256 N ?= 256 From 94a463a437b47292e002551f784dafc370b347d8 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:45:49 -0400 Subject: [PATCH 21/36] CQ-01/CQ-02/CQ-18: Makefile warning flags, header deps, all target CQ-01: Replace -w with -Wall + targeted -Wno suppressions. 0 warnings. CQ-02: Add HK_HEADERS wildcard dep so header changes trigger rebuild. CQ-18: Change `all` target from $(BIN) to `ladder` (builds all 8). --- kernels/gemm/bf16fp32/gfx1250/Makefile | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/kernels/gemm/bf16fp32/gfx1250/Makefile b/kernels/gemm/bf16fp32/gfx1250/Makefile index d4637247..b57d4ac4 100644 --- a/kernels/gemm/bf16fp32/gfx1250/Makefile +++ b/kernels/gemm/bf16fp32/gfx1250/Makefile @@ -39,11 +39,13 @@ VERIFY ?= 1 .PHONY: ladder run clean all +HK_HEADERS := $(shell find $(KITTENS_ROOT)/include -name '*.cuh' 2>/dev/null) + # Default target: build the currently selected KERNEL. -$(BIN): $(SRC) harness.h common.h +$(BIN): $(SRC) harness.h common.h $(HK_HEADERS) $(HIPCXX) $(HIPFLAGS) $(CXXFLAGS) $(CPPFLAGS) -DHARNESS_MAIN $(SRC) -o $(BIN) -all: $(BIN) +all: ladder # Loop variant -- `make ladder` rebuilds every rung. ladder: From ddddb1ad2dd37dca2f8d077a7f34ec7db6aa3ca2 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:45:49 -0400 Subject: [PATCH 22/36] CQ-03/CQ-05: fix kernel name printf and add hipGetLastError after launch CQ-03: Use __FILE__ instead of hardcoded "gemm_naive" in printf. CQ-05: Add HIP_OK(hipGetLastError()) after warmup dispatch to catch silent launch failures (e.g., invalid shared memory size). --- kernels/gemm/bf16fp32/gfx1250/harness.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/kernels/gemm/bf16fp32/gfx1250/harness.h b/kernels/gemm/bf16fp32/gfx1250/harness.h index 5b974121..65ce52c5 100644 --- a/kernels/gemm/bf16fp32/gfx1250/harness.h +++ b/kernels/gemm/bf16fp32/gfx1250/harness.h @@ -61,8 +61,8 @@ int main(int argc, char** argv) int n_iters = (argc > 4) ? std::atoi(argv[4]) : 1; int verify = (argc > 5) ? std::atoi(argv[5]) : 1; - std::printf("gemm_naive (bf16->fp32->bf16) M=%d N=%d K=%d iters=%d verify=%d\n", - M, N, K, n_iters, verify); + std::printf("%s (bf16->fp32->bf16) M=%d N=%d K=%d iters=%d verify=%d\n", + __FILE__, M, N, K, n_iters, verify); // ---- host fp32 reference + bf16 buffers ---- std::vector A_h(M * K), B_h(N * K), C_ref(M * N); @@ -96,6 +96,7 @@ int main(int argc, char** argv) // ---- warmup + timed run ---- dispatch(g); + HIP_OK(hipGetLastError()); HIP_OK(hipDeviceSynchronize()); hipEvent_t t0, t1; From e749779e4607bec9d9558ca61ba8ce385057041a Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:45:49 -0400 Subject: [PATCH 23/36] CQ-07/CQ-11: fix MASK_ALL for wave-32 and remove dead static_asserts CQ-07: MASK_ALL now 0xFFFFFFFF on gfx1250 (wave-32), 0xFFFFFFFFFFFFFFFF on CDNA (wave-64). CQ-11: Removed commented-out static_asserts in shared_allocator. --- include/common/util.cuh | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/include/common/util.cuh b/include/common/util.cuh index c992e138..eea306be 100644 --- a/include/common/util.cuh +++ b/include/common/util.cuh @@ -174,7 +174,11 @@ struct default_type {}; /** * @brief Mask constant for all active threads in a warp. */ -static constexpr uint64_t MASK_ALL = 0xFFFFFFFFFFFFFFFF; +#ifdef KITTENS_UDNA1 +static constexpr uint64_t MASK_ALL = 0x00000000FFFFFFFF; // wave-32 +#else +static constexpr uint64_t MASK_ALL = 0xFFFFFFFFFFFFFFFF; // wave-64 +#endif /** * @brief Perform a shuffle down operation on a packed type synchronously across a warp. @@ -381,7 +385,7 @@ struct shared_allocator { */ template __device__ inline variadic_array_t& allocate() { - // static_assert(sizeof(A) % default_alignment == 0, "Type is not aligned properly for array allocation"); + align_ptr(); using at = variadic_array_t; at*p = reinterpret_cast(ptr); @@ -398,7 +402,7 @@ struct shared_allocator { */ template __device__ inline variadic_array_t& allocate() { - // static_assert(sizeof(A) % alignment == 0, "Type is not aligned properly for array allocation"); + align_ptr(); using at = variadic_array_t; at*p = reinterpret_cast(ptr); From 20880566548dbfa08f3fe3b79474aaff0f98f2a2 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:45:49 -0400 Subject: [PATCH 24/36] CQ-13: remove unused bar_bytes variable in TDM dispatch --- kernels/gemm/bf16fp32/gfx1250/gemm_tdm_arrive.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_tdm_arrive.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_tdm_arrive.cpp index 495fe181..53f6ca2c 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_tdm_arrive.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_tdm_arrive.cpp @@ -162,9 +162,7 @@ void dispatch(gemm_globals g) { // Same layout as `gemm_segment`/`gemm_expert` (A in seg 0, B in seg 1) // plus 4 barrier cells in seg 0. - constexpr size_t bar_bytes = 4 * sizeof(sync::barrier_lds); const size_t mem_size = LDS_SEGMENT_BYTES + 2 * B_ELEMS_PAD * sizeof(bf16); - (void)bar_bytes; hipFuncSetAttribute(reinterpret_cast(gemm_tdm_arrive_kernel), hipFuncAttributeMaxDynamicSharedMemorySize, static_cast(mem_size)); From 52bf9feb10aecd0551a7f7ad74e08313b3fa4678 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:45:49 -0400 Subject: [PATCH 25/36] CQ-23: combine paired ds_load_b128 into single asm volatile block Two separate asm volatile blocks allowed compiler to theoretically reorder them. Single block ensures both loads are emitted together. --- include/ops/warp/memory/tile/shared_to_register.cuh | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/include/ops/warp/memory/tile/shared_to_register.cuh b/include/ops/warp/memory/tile/shared_to_register.cuh index 2a418543..fff675fc 100644 --- a/include/ops/warp/memory/tile/shared_to_register.cuh +++ b/include/ops/warp/memory/tile/shared_to_register.cuh @@ -753,10 +753,9 @@ __device__ inline void load_b128( reinterpret_cast(warp_lds_base + padded_off)); float4 lo, hi; - asm volatile("ds_load_b128 %0, %1 offset:0\n" - : "=v"(lo) : "v"(addr) : "memory"); - asm volatile("ds_load_b128 %0, %1 offset:16\n" - : "=v"(hi) : "v"(addr) : "memory"); + asm volatile("ds_load_b128 %0, %2 offset:0\n" + "ds_load_b128 %1, %2 offset:16\n" + : "=v"(lo), "=v"(hi) : "v"(addr) : "memory"); bf16_2* lo_p = reinterpret_cast(&lo); bf16_2* hi_p = reinterpret_cast(&hi); From d1d3153c55c653775c30cbf0d5e1ac917e043194 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:48:39 -0400 Subject: [PATCH 26/36] CQ-04/CQ-06: fix CPU reference precision and scale tolerance with K MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CQ-04: Round-trip CPU inputs through bf16 before reference computation so CPU and GPU see identical inputs. Drops max_abs_err 0.0368 → 0.0311. CQ-06: Replace fixed tolerance (1.0) with K-scaled formula: tol = 2 * sqrt(K) * 2^-7. Require zero bad elements. Passes K=32 (tol=0.088, err=0.031) and K=256 (tol=0.250, err=0.123). --- kernels/gemm/bf16fp32/gfx1250/harness.h | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/kernels/gemm/bf16fp32/gfx1250/harness.h b/kernels/gemm/bf16fp32/gfx1250/harness.h index 65ce52c5..d0880197 100644 --- a/kernels/gemm/bf16fp32/gfx1250/harness.h +++ b/kernels/gemm/bf16fp32/gfx1250/harness.h @@ -74,6 +74,9 @@ int main(int argc, char** argv) for (auto& x : B_h) x = dist(rng); for (size_t i = 0; i < A_h.size(); ++i) A_bf[i] = float_to_bf16(A_h[i]); for (size_t i = 0; i < B_h.size(); ++i) B_bf[i] = float_to_bf16(B_h[i]); + // Round-trip through bf16 so CPU reference sees same inputs as GPU + for (size_t i = 0; i < A_h.size(); ++i) A_h[i] = bf16_to_float(A_bf[i]); + for (size_t i = 0; i < B_h.size(); ++i) B_h[i] = bf16_to_float(B_bf[i]); // ---- device buffers ---- __hip_bfloat16 *A_d = nullptr, *B_d = nullptr, *C_d = nullptr; @@ -133,7 +136,10 @@ int main(int argc, char** argv) mean_abs /= (M * N); std::printf(" max_abs_err=%.4f mean_abs_err=%.4f bad=%d/%d\n", max_abs, mean_abs, n_bad, M * N); - return (max_abs < 1.0 || n_bad < 10) ? 0 : 1; + // bf16 has ~2^-7 precision; error grows as ~sqrt(K) * 2^-7 per element. + // Output is bf16 so final quantization adds another 2^-7. Use 2x headroom. + double tol = 2.0 * std::sqrt(static_cast(K)) * (1.0 / 128.0); + return (n_bad == 0 && max_abs < tol) ? 0 : 1; } #endif // HARNESS_MAIN From aedc8b9a265bc36fe7976b12c078dfd523f5a4f4 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:48:52 -0400 Subject: [PATCH 27/36] CQ-14: add hipFree and hipEventDestroy cleanup in harness --- kernels/gemm/bf16fp32/gfx1250/harness.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/kernels/gemm/bf16fp32/gfx1250/harness.h b/kernels/gemm/bf16fp32/gfx1250/harness.h index d0880197..c44e0362 100644 --- a/kernels/gemm/bf16fp32/gfx1250/harness.h +++ b/kernels/gemm/bf16fp32/gfx1250/harness.h @@ -139,7 +139,11 @@ int main(int argc, char** argv) // bf16 has ~2^-7 precision; error grows as ~sqrt(K) * 2^-7 per element. // Output is bf16 so final quantization adds another 2^-7. Use 2x headroom. double tol = 2.0 * std::sqrt(static_cast(K)) * (1.0 / 128.0); - return (n_bad == 0 && max_abs < tol) ? 0 : 1; + int ret = (n_bad == 0 && max_abs < tol) ? 0 : 1; + + hipFree(A_d); hipFree(B_d); hipFree(C_d); + hipEventDestroy(t0); hipEventDestroy(t1); + return ret; } #endif // HARNESS_MAIN From 1802e7b74919075e96089f6d0c9345ecde671666 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:49:06 -0400 Subject: [PATCH 28/36] CQ-19: add isa target to Makefile for GPU assembly dump make isa KERNEL=gemm_naive produces .s file via --save-temps. --- kernels/gemm/bf16fp32/gfx1250/Makefile | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/kernels/gemm/bf16fp32/gfx1250/Makefile b/kernels/gemm/bf16fp32/gfx1250/Makefile index b57d4ac4..c5517d73 100644 --- a/kernels/gemm/bf16fp32/gfx1250/Makefile +++ b/kernels/gemm/bf16fp32/gfx1250/Makefile @@ -54,5 +54,9 @@ ladder: run: $(BIN) ./$(BIN) $(M) $(N) $(K) $(ITERS) $(VERIFY) +isa: $(SRC) harness.h common.h $(HK_HEADERS) + $(HIPCXX) $(HIPFLAGS) $(CXXFLAGS) $(CPPFLAGS) -DHARNESS_MAIN --save-temps $(SRC) -o $(BIN) + @echo "ISA: $(KERNEL)-hip-amdgcn-amd-amdhsa-gfx1250.s" + clean: rm -f *.out From ac71d74ba957019b57d9f5d99cdad78f20985970 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:49:20 -0400 Subject: [PATCH 29/36] =?UTF-8?q?CQ-12:=20document=20fence()=20scope=20?= =?UTF-8?q?=E2=80=94=20drains=20loadcnt=20+=20dscnt=20only?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/ops/warp/sync/barrier.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/include/ops/warp/sync/barrier.cuh b/include/ops/warp/sync/barrier.cuh index 4fd011a3..8e002200 100644 --- a/include/ops/warp/sync/barrier.cuh +++ b/include/ops/warp/sync/barrier.cuh @@ -154,6 +154,7 @@ __device__ __forceinline__ void wait_tensor() { * Convenience for the common "producer side" pattern: ensure all in-flight * loads have settled into LDS before signalling consumers. */ +// Drains loadcnt + dscnt only — does not wait on store, async, or tensor counters. __device__ __forceinline__ void fence() { wait_load<0>(); wait_ds<0>(); From 50d19fae5e460f83366899ee2aaf3b38cb178982 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:50:08 -0400 Subject: [PATCH 30/36] CQ-16: extract gfx1250_lane_offset helper for shared_to_register Identical offset computation was duplicated between load_b128 and load_b32. Extracted into detail::gfx1250_lane_offset(sub_id, row, half). --- .../ops/warp/memory/tile/shared_to_register.cuh | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/include/ops/warp/memory/tile/shared_to_register.cuh b/include/ops/warp/memory/tile/shared_to_register.cuh index fff675fc..f39a6638 100644 --- a/include/ops/warp/memory/tile/shared_to_register.cuh +++ b/include/ops/warp/memory/tile/shared_to_register.cuh @@ -709,6 +709,10 @@ namespace detail { inline constexpr int GFX1250_SUB_ROWS = 16; inline constexpr int GFX1250_SUB_COLS = 32; inline constexpr int GFX1250_SUB_ELEMS = GFX1250_SUB_ROWS * GFX1250_SUB_COLS; + +__device__ __forceinline__ int gfx1250_lane_offset(int sub_id, int row, int half) { + return sub_id * GFX1250_SUB_ELEMS + row * GFX1250_SUB_COLS + half * 16; +} } // namespace detail /** @@ -744,10 +748,7 @@ __device__ inline void load_b128( #pragma unroll for (int tj = 0; tj < width; tj++) { const int sub_id = ti * subs_per_row + tj; - const int base_flat = sub_id * detail::GFX1250_SUB_ELEMS - + row * detail::GFX1250_SUB_COLS - + half * 16; - const int padded_off = Pad::padded(base_flat); + const int padded_off = Pad::padded(detail::gfx1250_lane_offset(sub_id, row, half)); const uint32_t addr = static_cast( reinterpret_cast(warp_lds_base + padded_off)); @@ -797,13 +798,10 @@ __device__ inline void load_b32( for (int ti = 0; ti < height; ti++) { #pragma unroll for (int tj = 0; tj < width; tj++) { - const int sub_id = ti * subs_per_row + tj; - const int base_flat = sub_id * detail::GFX1250_SUB_ELEMS - + row * detail::GFX1250_SUB_COLS - + half * 16; + const int sub_id = ti * subs_per_row + tj; const bf16_2* lds_p = reinterpret_cast( - warp_lds_base + base_flat); + warp_lds_base + detail::gfx1250_lane_offset(sub_id, row, half)); #pragma unroll for (int k = 0; k < 8; k++) { From 17d0275290106e7bceaa9487754c9d3a5700660e Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:50:38 -0400 Subject: [PATCH 31/36] CQ-17: static_assert subtile dims are power-of-2 in subtile_flat Integer division by sub_elems and SUB_COLS compiles to shifts only when these are power-of-2. Added static_asserts to catch non-power-of-2 subtile configs that would emit expensive division. --- include/ops/warp/memory/tile/global_to_shared.cuh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/ops/warp/memory/tile/global_to_shared.cuh b/include/ops/warp/memory/tile/global_to_shared.cuh index 4065beaf..b770fe59 100644 --- a/include/ops/warp/memory/tile/global_to_shared.cuh +++ b/include/ops/warp/memory/tile/global_to_shared.cuh @@ -486,6 +486,10 @@ using i32x4_lvec = int __attribute__((__vector_size__(16))) __attribute__((addr */ template __device__ __forceinline__ int subtile_flat(int flat) { + static_assert((SUB_ROWS * SUB_COLS & (SUB_ROWS * SUB_COLS - 1)) == 0, + "sub_elems must be power-of-2 to avoid integer division"); + static_assert((SUB_COLS & (SUB_COLS - 1)) == 0, + "SUB_COLS must be power-of-2 to avoid integer division"); constexpr int sub_elems = SUB_ROWS * SUB_COLS; constexpr int subs_per_row = COLS / SUB_COLS; const int subtile_id = flat / sub_elems; From 16e5e1a678025fdf9df963705b616d6a0edb5d47 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 18:56:39 -0400 Subject: [PATCH 32/36] CQ-10/CQ-20/CQ-21: auto-match padded shapes, padding name clarity, segment docs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CQ-10: Replace manual st_16x32_padded<> enumeration in UDNA1 concept with template specialization (is_st_16x32_padded_inst). Any padded config auto-satisfies the concept without manual addition. Keeps the closed whitelist — only named shapes + padded instantiations pass. CDNA concept left unchanged. CQ-20: Add comment that st_16x32_padded::swizzle() applies padding, not XOR swizzle (named for API compat). CQ-21: Document segment 0 LDS waste tradeoff in gemm_segment. --- .../ops/warp/memory/tile/global_to_shared.cuh | 7 +++++++ include/types/shared/st_shape.cuh | 18 ++++++++++-------- kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp | 2 ++ 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/include/ops/warp/memory/tile/global_to_shared.cuh b/include/ops/warp/memory/tile/global_to_shared.cuh index b770fe59..9e9feb91 100644 --- a/include/ops/warp/memory/tile/global_to_shared.cuh +++ b/include/ops/warp/memory/tile/global_to_shared.cuh @@ -10,6 +10,13 @@ namespace kittens { +// CDNA global→shared load overloads (4 variants): +// 1. load(ST, GL, coord) — typed ST, computes swizzle inline +// 2. load(ST, GL, coord, precomp) — typed ST, takes precomputed swizzle offsets +// 3. load_async(ST, GL, coord) — async DMA path (CDNA2+) +// 4. store(GL, ST, coord) — shared→global (reverse direction) +// For gfx1250, see the g2s:: namespace below for register-mediated and TDM paths. + template, diff --git a/include/types/shared/st_shape.cuh b/include/types/shared/st_shape.cuh index 404ee776..4eff8911 100644 --- a/include/types/shared/st_shape.cuh +++ b/include/types/shared/st_shape.cuh @@ -296,6 +296,7 @@ struct st_16x32_padded { } } + // Named "swizzle" for API compat with CDNA shapes, but applies padding (not XOR swizzle). template __device__ __forceinline__ static const uint32_t swizzle(int2 coord) { const int r = coord.x, c = coord.y; @@ -309,18 +310,19 @@ struct st_16x32_padded { } }; +template struct is_st_16x32_padded_inst : std::false_type {}; +template struct is_st_16x32_padded_inst> : std::true_type {}; + template -concept all = std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v || +concept all = std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v> || - std::is_same_v> || - std::is_same_v>; + is_st_16x32_padded_inst::value; #else template concept all = std::is_same_v || diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp index 15bf9862..f2a66782 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp @@ -13,6 +13,8 @@ using namespace kittens; using namespace gfx1250_gemm; // Pad type must match between load_async (write) and load_b128 (read). +// A uses ~8.7KB of a 64KB segment — the rest is unused but reserves the segment +// for dual-port LDS access (A and B on separate ports). using Pad = lds_pad_default; constexpr int A_ELEMS_PAD = Pad::padded_elems(BLOCK_M * K_STEP); constexpr int B_ELEMS_PAD = Pad::padded_elems(BLOCK_N * K_STEP); From 9777baa09b651cae879fc5b214f92407e16b7350 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 19:01:50 -0400 Subject: [PATCH 33/36] MF-02: add SCHED_MODE bit[4] DISABLE_VALU_STALL MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added limited_nostall mode (limited + bit[4]) to sched::mode enum. Widened HWREG encoding from 2 to 5 bits to cover the full SCHED_MODE field. Applied to gemm_expert kernel — ISA confirms s_setreg value changed from 2 to 18. Correctness verified on the gfx1250 simulator. --- include/ops/warp/sched/sched.cuh | 13 +++++++------ kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/include/ops/warp/sched/sched.cuh b/include/ops/warp/sched/sched.cuh index acd3098d..d2b00953 100644 --- a/include/ops/warp/sched/sched.cuh +++ b/include/ops/warp/sched/sched.cuh @@ -28,14 +28,15 @@ namespace sched { * as experimental / unsafe by default. */ enum class mode : int { - normal = 0, - full = 1, - limited = 2, + normal = 0, + full = 1, + limited = 2, + limited_nostall = 2 | (1 << 4), // limited + DISABLE_VALU_STALL (bit[4]) }; -// `s_setreg_b32 hwreg(MODE_REG=1, offset=4, size=2), value` -// Encoded simm16 = 1 | (4 << 6) | ((2-1) << 11) = 2305. -constexpr int SCHED_MODE_HWREG_SIMM16 = 1 | (4 << 6) | (1 << 11); +// `s_setreg_b32 hwreg(MODE_REG=1, offset=4, size=5), value` +// 5 bits to cover bits [4:0] of SCHED_MODE (including DISABLE_VALU_STALL at bit[4]). +constexpr int SCHED_MODE_HWREG_SIMM16 = 1 | (4 << 6) | ((5 - 1) << 11); /** * @brief Set the wave's SCHED_MODE to `m`. diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp index 720e091a..b3dbf055 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp @@ -38,7 +38,7 @@ void gemm_expert_kernel(const gemm_globals g, int M, int N, int K) const int warp_c = wid % WARPS_N; const int k_iters = K / K_STEP; - kittens::sched::expert _sched; // covers entire kernel — compiler reorders s_setreg if scoped tighter + kittens::sched::expert _sched{kittens::sched::mode::limited_nostall}; // DISABLE_VALU_STALL for WMMA burst kittens::g2s::load_async( A_lds[0], g.a, {0, 0, tile_m, 0}, K); From 9b58ced213dd739f680b51d296f56cfcb95c5f4f Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 19:03:10 -0400 Subject: [PATCH 34/36] MF-08: add s_wakeup and s_sleep_var wrappers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit sched::wakeup() wakes all sleeping waves in the workgroup — useful for waking consumers after producer finishes. sched::sleep_var() sleeps for a runtime-variable duration (SGPR * 64 cycles). Both use inline asm since no builtins exist in ROCm 7.x. --- include/ops/warp/sched/sched.cuh | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/include/ops/warp/sched/sched.cuh b/include/ops/warp/sched/sched.cuh index d2b00953..ff8b41d1 100644 --- a/include/ops/warp/sched/sched.cuh +++ b/include/ops/warp/sched/sched.cuh @@ -105,6 +105,25 @@ __device__ __forceinline__ void sleep() { __builtin_amdgcn_s_sleep(N); } +/** + * @brief Wake all sleeping waves in this workgroup. + * + * Lowers to `s_wakeup`. Use after a producer finishes work to wake consumer + * waves that are polling in `wait_barrier` via `s_sleep`. + */ +__device__ __forceinline__ void wakeup() { + asm volatile("s_wakeup" ::: "memory"); +} + +/** + * @brief Sleep the wave for a runtime-variable number of cycles. + * + * Lowers to `s_sleep_var`. Duration = SGPR[6:0] * 64 cycles. + */ +__device__ __forceinline__ void sleep_var(unsigned cycles_div64) { + asm volatile("s_sleep_var %0" :: "s"(cycles_div64)); +} + /** * @brief Compiler-only scheduling fence. * From 12cdbab6beae0c7b18835c2646bc3ccaefa44cc0 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 19:03:29 -0400 Subject: [PATCH 35/36] MF-09: document WMMA co-execution opportunity Each WMMA takes 16 cycles; the SIMD can issue up to 8 independent VALU ops for free during this time. Added documentation near compiler_fence(). --- include/ops/warp/sched/sched.cuh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/include/ops/warp/sched/sched.cuh b/include/ops/warp/sched/sched.cuh index ff8b41d1..fdeec391 100644 --- a/include/ops/warp/sched/sched.cuh +++ b/include/ops/warp/sched/sched.cuh @@ -130,6 +130,12 @@ __device__ __forceinline__ void sleep_var(unsigned cycles_div64) { * Tells the LLVM scheduler not to reorder instructions across this point * but emits no hardware op. Useful when constraining the compiler's WMMA * burst grouping without paying a runtime barrier. + * + * WMMA co-execution note: each WMMA takes 16 cycles. During this time, the + * SIMD can issue up to 8 independent VALU ops for free (1 per 2 cycles). + * Place address computation, format conversion, or accumulator scaling + * between WMMA instructions to exploit this. The compiler does this + * automatically when independent work is available in the same basic block. */ __device__ __forceinline__ void compiler_fence() { __builtin_amdgcn_sched_barrier(0); From aae040163550a9379d9e15e0641d4b4a78bdabb3 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sat, 23 May 2026 19:14:18 -0400 Subject: [PATCH 36/36] MF-01: wrap named barriers (IDs 1-16) for subset-of-waves sync MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added init_named(), join(), signal(), wait_named(), and wakeup_barrier(). s_barrier_leave is not available on gfx1250 (assembler rejects it). Named barriers require dispatch-time allocation via HIP launch attributes — wrappers handle the device-side instructions only. Verified: all wrappers compile and emit correct ISA on gfx1250. --- include/ops/warp/register/tile/mma.cuh | 17 +++++++-- include/ops/warp/sync/barrier.cuh | 48 ++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/include/ops/warp/register/tile/mma.cuh b/include/ops/warp/register/tile/mma.cuh index 4e7b79f3..cf4421bd 100644 --- a/include/ops/warp/register/tile/mma.cuh +++ b/include/ops/warp/register/tile/mma.cuh @@ -220,12 +220,24 @@ __device__ static inline void mma_AB_base(rt_base && +#ifdef KITTENS_UDNA1 + // gfx1250 WMMA always computes A × B_input^T. For mma_AB, B is col-major, + // so B_input^T = B_row = the non-transposed view. Same WMMA instruction. + if constexpr (std::is_same_v && + A_rows == 16 && A_cols == 32 && + B_rows == 32 && B_cols == 16 && + std::is_same_v) { + wmma161632(d.data, a.data, b.data, c.data); + } else { + static_assert(false, "Unsupported shape combination for gfx1250 mma_AB_base"); + } +#else + if constexpr (std::is_same_v && A_rows == 16 && A_cols == 32 && B_rows == 32 && B_cols == 16 && std::is_same_v) { mfma161632(d.data, a.data, b.data, c.data); - } else if constexpr (std::is_same_v && + } else if constexpr (std::is_same_v && A_rows == 32 && A_cols == 16 && B_rows == 16 && B_cols == 32 && std::is_same_v) { @@ -233,6 +245,7 @@ __device__ static inline void mma_AB_base(rt_base(); } +/* ---------- NAMED BARRIERS (IDs 1-16) ---------- */ +// +// Hardware-managed subset-of-waves barriers. Unlike the workgroup barrier (-1) +// which syncs ALL waves, a named barrier only syncs waves that have JOINed it. +// Waves can be a member of at most one named barrier at a time. +// +// Usage pattern: +// 1. One wave per barrier calls init_named(id, member_count) +// 2. All participating waves call join(id) +// 3. Sync via signal(id) + wait(id) +// 4. When done, call leave(id) + +/// @brief Initialize named barrier `id` with `member_count` waves. +__device__ __forceinline__ void init_named(int id, int member_count) { + unsigned m0_val = (static_cast(member_count) << 16) | static_cast(id); + asm volatile("s_mov_b32 m0, %0\n" + "s_barrier_init m0" + :: "s"(m0_val) : "memory"); +} + +/// @brief Current wave joins named barrier `id`. +template +__device__ __forceinline__ void join() { + static_assert(ID >= 1 && ID <= 16, "named barrier ID must be in [1, 16]"); + asm volatile("s_barrier_join %0" :: "I"(ID) : "memory"); +} + +/// @brief Signal named barrier `id`. +template +__device__ __forceinline__ void signal() { + static_assert(ID >= 1 && ID <= 16, "named barrier ID must be in [1, 16]"); + __builtin_amdgcn_s_barrier_signal(ID); +} + +/// @brief Wait on named barrier `id`. +template +__device__ __forceinline__ void wait_named() { + static_assert(ID >= 1 && ID <= 16, "named barrier ID must be in [1, 16]"); + __builtin_amdgcn_s_barrier_wait(ID); +} + +/// @brief Wake all waves joined to named barrier `id`. +template +__device__ __forceinline__ void wakeup_barrier() { + static_assert(ID >= 1 && ID <= 16, "named barrier ID must be in [1, 16]"); + asm volatile("s_wakeup_barrier %0" :: "I"(ID) : "memory"); +} + /* ---------- LDS BARRIER CELLS (FOR TDM / ASYNC ARRIVE) ---------- */ // // 64-bit LDS-resident barrier cell, per SP3 section 9.8.13