diff --git a/include/common/util.cuh b/include/common/util.cuh index 8b007f40..eea306be 100644 --- a/include/common/util.cuh +++ b/include/common/util.cuh @@ -109,7 +109,7 @@ __host__ __device__ inline int ceil_div(int a, int b) { int limit = (num_workgroups / block) * block; // If pid beyond the last full block, leave unchanged - if (workgroup_id > limit) return workgroup_id; + if (workgroup_id >= limit) return workgroup_id; // Local PID (within round-robin assignment) int local_pid = workgroup_id / num_xcds; @@ -174,7 +174,11 @@ struct default_type {}; /** * @brief Mask constant for all active threads in a warp. */ -static constexpr uint64_t MASK_ALL = 0xFFFFFFFFFFFFFFFF; +#ifdef KITTENS_UDNA1 +static constexpr uint64_t MASK_ALL = 0x00000000FFFFFFFF; // wave-32 +#else +static constexpr uint64_t MASK_ALL = 0xFFFFFFFFFFFFFFFF; // wave-64 +#endif /** * @brief Perform a shuffle down operation on a packed type synchronously across a warp. @@ -201,7 +205,7 @@ __device__ static inline T packed_shfl_down(uint64_t mask, const T &f, int delta *reinterpret_cast(&f)}; } - u.ui = __shfl_down_sync(mask, u.ui, delta, 64); + u.ui = __shfl_down_sync(mask, u.ui, delta, WARP_THREADS); if constexpr (std::is_same_v) { return *reinterpret_cast(&u.bf162.x); // Extract single bf16 from the .x component } else { @@ -332,8 +336,9 @@ template concept all = is_segment::value; template struct shared_allocator { int *ptr; -#ifdef KITTENS_UDNA1 int *base; +#ifdef KITTENS_UDNA1 + int *seg_ptr[LDS_NUM_SEGMENTS]; #endif private: @@ -366,24 +371,26 @@ struct shared_allocator { * @brief Construct a new shared allocator using a pointer to extern shared memory. * @param[in] _ptr Pointer to the start of the extern shared memory. */ + __device__ shared_allocator(int *_ptr): ptr(_ptr), base(_ptr) { #ifdef KITTENS_UDNA1 - __device__ shared_allocator(int *_ptr): ptr(_ptr), base(_ptr) {} -#else - __device__ shared_allocator(int *_ptr): ptr(_ptr) {} + for (int i = 0; i < LDS_NUM_SEGMENTS; i++) + seg_ptr[i] = base + i * (LDS_SEGMENT_BYTES / (int)sizeof(int)); #endif + } /** * @brief Allocate shared memory for a single instance or N-dimensional array of type A. * @tparam A The type of the object to allocate. * @tparam dims... A list of dimensions for the N-dimensional array. * @return Reference to the allocated object. */ - template + template __device__ inline variadic_array_t& allocate() { - // static_assert(sizeof(A) % default_alignment == 0, "Type is not aligned properly for array allocation"); + align_ptr(); using at = variadic_array_t; at*p = reinterpret_cast(ptr); ptr += sizeof(at)/sizeof(int); + assert(ptr <= base + MAX_SHARED_MEMORY / sizeof(int)); return *p; } /** @@ -393,13 +400,14 @@ struct shared_allocator { * @tparam dims... A list of dimensions for the N-dimensional array. * @return Reference to the allocated object. */ - template + template __device__ inline variadic_array_t& allocate() { - // static_assert(sizeof(A) % alignment == 0, "Type is not aligned properly for array allocation"); + align_ptr(); using at = variadic_array_t; at*p = reinterpret_cast(ptr); ptr += sizeof(at)/sizeof(int); + assert(ptr <= base + MAX_SHARED_MEMORY / sizeof(int)); return *p; } @@ -419,13 +427,17 @@ struct shared_allocator { template requires ducks::segment_tag::all __device__ inline variadic_array_t& allocate_in() { - int* target = base + (SEG::byte_offset / sizeof(int)); - // If we've already allocated past the requested segment, keep - // packing where we are; otherwise jump forward to the segment. - if (ptr < target) ptr = target; + constexpr int idx = SEG::index; + if constexpr (default_alignment > 0) { + uint64_t p = reinterpret_cast(seg_ptr[idx]); + if (p % default_alignment != 0) + seg_ptr[idx] = (int*)(p + (default_alignment - (p % default_alignment))); + } using at = variadic_array_t; - at* p = reinterpret_cast(ptr); - ptr += sizeof(at) / sizeof(int); + at* p = reinterpret_cast(seg_ptr[idx]); + seg_ptr[idx] += sizeof(at) / sizeof(int); + constexpr int seg_end = (idx + 1) * LDS_SEGMENT_BYTES / sizeof(int); + assert(seg_ptr[idx] <= base + seg_end); return *p; } #endif // KITTENS_UDNA1 diff --git a/include/ops/warp/cluster/cluster.cuh b/include/ops/warp/cluster/cluster.cuh index 48d2902c..7d72547f 100644 --- a/include/ops/warp/cluster/cluster.cuh +++ b/include/ops/warp/cluster/cluster.cuh @@ -44,14 +44,14 @@ __device__ __host__ __forceinline__ constexpr uint32_t mask( } /** - * @brief Cluster-wide split barrier. + * @brief Cluster-wide barrier (signal + wait on cluster user barrier -3). * - * Outside a CGA cluster this lowers to a workgroup-wide `sync::sync()`. Inside - * a cluster the same `s_barrier_signal -1 / s_barrier_wait -1` pair extends to - * every workgroup in the cluster by hardware-managed forwarding. + * Barrier -3 syncs across all workgroups in a CGA cluster. + * Outside a cluster, use `sync::sync()` (barrier -1, workgroup-only). */ __device__ __forceinline__ void sync() { - ::kittens::sync::sync(); + __builtin_amdgcn_s_barrier_signal(-3); + __builtin_amdgcn_s_barrier_wait(-3); } } // namespace cluster diff --git a/include/ops/warp/memory/tile/global_to_shared.cuh b/include/ops/warp/memory/tile/global_to_shared.cuh index 6f18f7f7..9e9feb91 100644 --- a/include/ops/warp/memory/tile/global_to_shared.cuh +++ b/include/ops/warp/memory/tile/global_to_shared.cuh @@ -10,6 +10,13 @@ namespace kittens { +// CDNA global→shared load overloads (4 variants): +// 1. load(ST, GL, coord) — typed ST, computes swizzle inline +// 2. load(ST, GL, coord, precomp) — typed ST, takes precomputed swizzle offsets +// 3. load_async(ST, GL, coord) — async DMA path (CDNA2+) +// 4. store(GL, ST, coord) — shared→global (reverse direction) +// For gfx1250, see the g2s:: namespace below for register-mediated and TDM paths. + template, @@ -461,6 +468,12 @@ struct lds_nopad { }; /// @brief Default LDS padding for bf16 GEMMs on gfx1250. +/// Derivation: gfx1250 LDS has 32 banks, 4 bytes wide. A 16x32 bf16 subtile +/// row is 32 * 2 = 64 bytes = 16 banks. Two consecutive rows hit the same 16 +/// banks → 2-way conflict on ds_load_b128. Padding by 8 bf16 (16 bytes = 4 +/// banks) shifts each row's bank mapping, eliminating conflicts. +/// Interval 128 = one subtile row (128 bf16 = 256 bytes). Must be power-of-2 +/// for D# encoding (see SR-09 static_assert). using lds_pad_default = lds_padded<128, 8>; namespace g2s { @@ -480,6 +493,10 @@ using i32x4_lvec = int __attribute__((__vector_size__(16))) __attribute__((addr */ template __device__ __forceinline__ int subtile_flat(int flat) { + static_assert((SUB_ROWS * SUB_COLS & (SUB_ROWS * SUB_COLS - 1)) == 0, + "sub_elems must be power-of-2 to avoid integer division"); + static_assert((SUB_COLS & (SUB_COLS - 1)) == 0, + "SUB_COLS must be power-of-2 to avoid integer division"); constexpr int sub_elems = SUB_ROWS * SUB_COLS; constexpr int subs_per_row = COLS / SUB_COLS; const int subtile_id = flat / sub_elems; @@ -501,6 +518,8 @@ __device__ __forceinline__ int subtile_flat(int flat) { * Plain `global_load` -> VGPR -> `ds_store` path. Use this when no async * intrinsic is available or for correctness baselines. The `Pad` parameter * controls the per-element LDS placement; pass `lds_nopad` for flat layouts. + * + * Caller must ensure matrix dimensions are multiples of ROWS/COLS (no bounds clamping). */ template> @@ -536,6 +555,8 @@ __device__ inline void load(T* __restrict__ lds_dst, const GL& src, const COORD& * issues one 16-byte transfer; the warp covers `8 * N_THREADS` elements per * iteration. Drain with `kittens::sync::wait_async()` before consuming. * + * Caller must ensure matrix dimensions are multiples of ROWS/COLS (no bounds clamping). + * * @tparam Pad LDS padding descriptor. * @tparam ROWS,COLS Tile shape (elements). * @tparam N_THREADS Number of threads participating in the load. @@ -645,6 +666,9 @@ __device__ __forceinline__ void build_tdm_d_2d( : (sizeof(T) == 4) ? 2 : 3; constexpr uint32_t pad_enable = (Pad::interval > 0) ? 1u : 0u; + static_assert(Pad::interval == 0 || + __builtin_popcount(Pad::interval * sizeof(T) / 4) == 1, + "Pad interval in DWords must be a power of 2 for D# encoding"); constexpr uint32_t pad_int_enc = (Pad::interval > 0) ? ( __builtin_ctz(Pad::interval * sizeof(T) / 4) ) : 0; constexpr uint32_t pad_amt_enc = (Pad::amount > 0) @@ -669,6 +693,7 @@ __device__ __forceinline__ void build_tdm_d_2d( const uint32_t tiledim1 = static_cast(ROWS); // barrier_addr occupies w1[15:0]; tensor_dim0 lo16 occupies w1[31:16]. + assert(bar_lds_addr == 0 || bar_lds_addr < 0x10000u); uint32_t w1 = (bar_lds_addr & 0xFFFFu) | (tdim0 << 16); uint32_t w2 = (tdim0 >> 16) | (tdim1 << 16); uint32_t w3 = (tdim1 >> 16) | (tiledim0 << 16); diff --git a/include/ops/warp/memory/tile/shared_to_register.cuh b/include/ops/warp/memory/tile/shared_to_register.cuh index 8c905aaf..f39a6638 100644 --- a/include/ops/warp/memory/tile/shared_to_register.cuh +++ b/include/ops/warp/memory/tile/shared_to_register.cuh @@ -709,6 +709,10 @@ namespace detail { inline constexpr int GFX1250_SUB_ROWS = 16; inline constexpr int GFX1250_SUB_COLS = 32; inline constexpr int GFX1250_SUB_ELEMS = GFX1250_SUB_ROWS * GFX1250_SUB_COLS; + +__device__ __forceinline__ int gfx1250_lane_offset(int sub_id, int row, int half) { + return sub_id * GFX1250_SUB_ELEMS + row * GFX1250_SUB_COLS + half * 16; +} } // namespace detail /** @@ -729,6 +733,8 @@ __device__ inline void load_b128( rt_bf& dst, const bf16* __restrict__ warp_lds_base) { + static_assert(Pad::amount == 0 || Pad::amount * sizeof(bf16) % 16 == 0, + "Pad amount must be a multiple of 16 bytes for ds_load_b128 alignment"); constexpr int height = WARP_M / detail::GFX1250_SUB_ROWS; constexpr int width = WARP_K / detail::GFX1250_SUB_COLS; constexpr int subs_per_row = WARP_K / detail::GFX1250_SUB_COLS; @@ -742,19 +748,15 @@ __device__ inline void load_b128( #pragma unroll for (int tj = 0; tj < width; tj++) { const int sub_id = ti * subs_per_row + tj; - const int base_flat = sub_id * detail::GFX1250_SUB_ELEMS - + row * detail::GFX1250_SUB_COLS - + half * 16; - const int padded_off = Pad::padded(base_flat); + const int padded_off = Pad::padded(detail::gfx1250_lane_offset(sub_id, row, half)); const uint32_t addr = static_cast( reinterpret_cast(warp_lds_base + padded_off)); float4 lo, hi; - asm volatile("ds_load_b128 %0, %1 offset:0\n" - : "=v"(lo) : "v"(addr) : "memory"); - asm volatile("ds_load_b128 %0, %1 offset:16\n" - : "=v"(hi) : "v"(addr) : "memory"); + asm volatile("ds_load_b128 %0, %2 offset:0\n" + "ds_load_b128 %1, %2 offset:16\n" + : "=v"(lo), "=v"(hi) : "v"(addr) : "memory"); bf16_2* lo_p = reinterpret_cast(&lo); bf16_2* hi_p = reinterpret_cast(&hi); @@ -783,6 +785,7 @@ __device__ inline void load_b32( rt_bf& dst, const bf16* __restrict__ warp_lds_base) { + // Unpadded only — use load_b128 for padded LDS layouts. constexpr int height = WARP_M / detail::GFX1250_SUB_ROWS; constexpr int width = WARP_K / detail::GFX1250_SUB_COLS; constexpr int subs_per_row = WARP_K / detail::GFX1250_SUB_COLS; @@ -795,13 +798,10 @@ __device__ inline void load_b32( for (int ti = 0; ti < height; ti++) { #pragma unroll for (int tj = 0; tj < width; tj++) { - const int sub_id = ti * subs_per_row + tj; - const int base_flat = sub_id * detail::GFX1250_SUB_ELEMS - + row * detail::GFX1250_SUB_COLS - + half * 16; + const int sub_id = ti * subs_per_row + tj; const bf16_2* lds_p = reinterpret_cast( - warp_lds_base + base_flat); + warp_lds_base + detail::gfx1250_lane_offset(sub_id, row, half)); #pragma unroll for (int k = 0; k < 8; k++) { diff --git a/include/ops/warp/register/tile/mma.cuh b/include/ops/warp/register/tile/mma.cuh index c4aca036..cf4421bd 100644 --- a/include/ops/warp/register/tile/mma.cuh +++ b/include/ops/warp/register/tile/mma.cuh @@ -124,7 +124,7 @@ __device__ static inline void mfma323232( float2 (&D)[8], typedef __attribute__((__vector_size__(8 * sizeof(__bf16)))) __bf16 bf16x8_t; typedef __attribute__((__vector_size__(16 * sizeof(float)))) float floatx16_t; - *(floatx16_t*)C = __builtin_amdgcn_mfma_f32_32x32x16_bf16( + floatx16_t acc = __builtin_amdgcn_mfma_f32_32x32x16_bf16( *(bf16x8_t*)A, *(bf16x8_t*)B, *(floatx16_t*)C, @@ -134,7 +134,7 @@ __device__ static inline void mfma323232( float2 (&D)[8], *(floatx16_t*)D = __builtin_amdgcn_mfma_f32_32x32x16_bf16( *(bf16x8_t*)(A + 4), *(bf16x8_t*)(B + 4), - *(floatx16_t*)C, + acc, 0, 0, 0 ); } @@ -220,12 +220,24 @@ __device__ static inline void mma_AB_base(rt_base && +#ifdef KITTENS_UDNA1 + // gfx1250 WMMA always computes A × B_input^T. For mma_AB, B is col-major, + // so B_input^T = B_row = the non-transposed view. Same WMMA instruction. + if constexpr (std::is_same_v && + A_rows == 16 && A_cols == 32 && + B_rows == 32 && B_cols == 16 && + std::is_same_v) { + wmma161632(d.data, a.data, b.data, c.data); + } else { + static_assert(false, "Unsupported shape combination for gfx1250 mma_AB_base"); + } +#else + if constexpr (std::is_same_v && A_rows == 16 && A_cols == 32 && B_rows == 32 && B_cols == 16 && std::is_same_v) { mfma161632(d.data, a.data, b.data, c.data); - } else if constexpr (std::is_same_v && + } else if constexpr (std::is_same_v && A_rows == 32 && A_cols == 16 && B_rows == 16 && B_cols == 32 && std::is_same_v) { @@ -233,6 +245,7 @@ __device__ static inline void mma_AB_base(rt_base(); wait_ds<0>(); } +/* ---------- NAMED BARRIERS (IDs 1-16) ---------- */ +// +// Hardware-managed subset-of-waves barriers. Unlike the workgroup barrier (-1) +// which syncs ALL waves, a named barrier only syncs waves that have JOINed it. +// Waves can be a member of at most one named barrier at a time. +// +// Usage pattern: +// 1. One wave per barrier calls init_named(id, member_count) +// 2. All participating waves call join(id) +// 3. Sync via signal(id) + wait(id) +// 4. When done, call leave(id) + +/// @brief Initialize named barrier `id` with `member_count` waves. +__device__ __forceinline__ void init_named(int id, int member_count) { + unsigned m0_val = (static_cast(member_count) << 16) | static_cast(id); + asm volatile("s_mov_b32 m0, %0\n" + "s_barrier_init m0" + :: "s"(m0_val) : "memory"); +} + +/// @brief Current wave joins named barrier `id`. +template +__device__ __forceinline__ void join() { + static_assert(ID >= 1 && ID <= 16, "named barrier ID must be in [1, 16]"); + asm volatile("s_barrier_join %0" :: "I"(ID) : "memory"); +} + +/// @brief Signal named barrier `id`. +template +__device__ __forceinline__ void signal() { + static_assert(ID >= 1 && ID <= 16, "named barrier ID must be in [1, 16]"); + __builtin_amdgcn_s_barrier_signal(ID); +} + +/// @brief Wait on named barrier `id`. +template +__device__ __forceinline__ void wait_named() { + static_assert(ID >= 1 && ID <= 16, "named barrier ID must be in [1, 16]"); + __builtin_amdgcn_s_barrier_wait(ID); +} + +/// @brief Wake all waves joined to named barrier `id`. +template +__device__ __forceinline__ void wakeup_barrier() { + static_assert(ID >= 1 && ID <= 16, "named barrier ID must be in [1, 16]"); + asm volatile("s_wakeup_barrier %0" :: "I"(ID) : "memory"); +} + /* ---------- LDS BARRIER CELLS (FOR TDM / ASYNC ARRIVE) ---------- */ // // 64-bit LDS-resident barrier cell, per SP3 section 9.8.13 @@ -191,6 +240,7 @@ struct alignas(8) barrier_lds { uint64_t state; }; /// @brief Initialize an LDS barrier cell to expect `count` arrivals per phase. __device__ __forceinline__ void init_barrier(uint64_t* bar, uint32_t count) { + assert(count > 0 && count <= 0xFFFF); // pending = count - 1, phase = 0, init_count = count - 1. const uint32_t pending = count - 1; const uint32_t init_cnt = count - 1; @@ -206,6 +256,9 @@ __device__ __forceinline__ void init_barrier(uint64_t* bar, uint32_t count) { * Callers maintain a parity bit per barrier and pass it inverted before * each wait (`expected = (phase ^= 1)`). The hardware wakes sleeping * waves on phase flip; `s_sleep 1` yields the SIMD between polls. + * + * WARNING: infinite loop if the matching arrive never fires (e.g., wrong + * init_barrier count or missing async_barrier_arrive). No timeout. */ __device__ __forceinline__ void wait_barrier(uint64_t* bar, int expected_phase) { const uint32_t lds_addr = static_cast(reinterpret_cast(bar)); @@ -222,16 +275,18 @@ __device__ __forceinline__ void wait_barrier(uint64_t* bar, int expected_phase) } /** - * @brief Arrive at an LDS barrier cell from an async-ordered path. + * @brief Arrive at an LDS barrier cell from an async-ordered path (once per wave). * - * Lowers to `DS_ATOMIC_ASYNC_BARRIER_ARRIVE_B64`. Use this to manually - * arrive at a cell (the auto-arrive form is encoded in the TDM descriptor - * via `load_tdm_arrive`). + * Lowers to `DS_ATOMIC_ASYNC_BARRIER_ARRIVE_B64` which is a DS atomic that + * fires per active lane. This wrapper restricts execution to lane 0 so the + * barrier receives exactly one arrival per wave. */ __device__ __forceinline__ void async_barrier_arrive(uint64_t* lds_counter) { - uintptr_t lds_uint = reinterpret_cast(lds_counter); - __builtin_amdgcn_ds_atomic_async_barrier_arrive_b64( - reinterpret_cast(lds_uint)); + if (laneid() == 0) { + uintptr_t lds_uint = reinterpret_cast(lds_counter); + __builtin_amdgcn_ds_atomic_async_barrier_arrive_b64( + reinterpret_cast(lds_uint)); + } } } // namespace sync diff --git a/include/types/shared/st_shape.cuh b/include/types/shared/st_shape.cuh index 404ee776..4eff8911 100644 --- a/include/types/shared/st_shape.cuh +++ b/include/types/shared/st_shape.cuh @@ -296,6 +296,7 @@ struct st_16x32_padded { } } + // Named "swizzle" for API compat with CDNA shapes, but applies padding (not XOR swizzle). template __device__ __forceinline__ static const uint32_t swizzle(int2 coord) { const int r = coord.x, c = coord.y; @@ -309,18 +310,19 @@ struct st_16x32_padded { } }; +template struct is_st_16x32_padded_inst : std::false_type {}; +template struct is_st_16x32_padded_inst> : std::true_type {}; + template -concept all = std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v || +concept all = std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v> || - std::is_same_v> || - std::is_same_v>; + is_st_16x32_padded_inst::value; #else template concept all = std::is_same_v || diff --git a/kernels/gemm/bf16fp32/gfx1250/Makefile b/kernels/gemm/bf16fp32/gfx1250/Makefile index 4d3ed180..c5517d73 100644 --- a/kernels/gemm/bf16fp32/gfx1250/Makefile +++ b/kernels/gemm/bf16fp32/gfx1250/Makefile @@ -28,7 +28,8 @@ HIPFLAGS := -DKITTENS_UDNA1 \ CXX_STD := c++20 CPPFLAGS += -I$(KITTENS_ROOT)/include -I$(HIP_INCLUDE_DIR) -CXXFLAGS += -std=$(CXX_STD) -O3 -w +CXXFLAGS += -std=$(CXX_STD) -O3 -Wall -Wno-unused-variable -Wno-unused-local-typedef \ + -Wno-duplicate-decl-specifier -Wno-unused-value -Wno-pass-failed M ?= 256 N ?= 256 @@ -38,11 +39,13 @@ VERIFY ?= 1 .PHONY: ladder run clean all +HK_HEADERS := $(shell find $(KITTENS_ROOT)/include -name '*.cuh' 2>/dev/null) + # Default target: build the currently selected KERNEL. -$(BIN): $(SRC) harness.h common.h +$(BIN): $(SRC) harness.h common.h $(HK_HEADERS) $(HIPCXX) $(HIPFLAGS) $(CXXFLAGS) $(CPPFLAGS) -DHARNESS_MAIN $(SRC) -o $(BIN) -all: $(BIN) +all: ladder # Loop variant -- `make ladder` rebuilds every rung. ladder: @@ -51,5 +54,9 @@ ladder: run: $(BIN) ./$(BIN) $(M) $(N) $(K) $(ITERS) $(VERIFY) +isa: $(SRC) harness.h common.h $(HK_HEADERS) + $(HIPCXX) $(HIPFLAGS) $(CXXFLAGS) $(CPPFLAGS) -DHARNESS_MAIN --save-temps $(SRC) -o $(BIN) + @echo "ISA: $(KERNEL)-hip-amdgcn-amd-amdhsa-gfx1250.s" + clean: rm -f *.out diff --git a/kernels/gemm/bf16fp32/gfx1250/common.h b/kernels/gemm/bf16fp32/gfx1250/common.h index 6eba740a..394565ec 100644 --- a/kernels/gemm/bf16fp32/gfx1250/common.h +++ b/kernels/gemm/bf16fp32/gfx1250/common.h @@ -36,7 +36,12 @@ struct gemm_globals { int M() const { return a.rows(); } int N() const { return c.cols(); } int K() const { return a.cols(); } - dim3 grid() const { return dim3(M() / BLOCK_M, N() / BLOCK_N); } + dim3 grid() const { + assert(M() % BLOCK_M == 0 && "M must be a multiple of BLOCK_M"); + assert(N() % BLOCK_N == 0 && "N must be a multiple of BLOCK_N"); + assert(K() % K_STEP == 0 && "K must be a multiple of K_STEP"); + return dim3(M() / BLOCK_M, N() / BLOCK_N); + } dim3 block() const { return dim3(NUM_THREADS); } size_t dynamic_shared_memory() const { return kittens::MAX_SHARED_MEMORY; } }; diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_async.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_async.cpp index 43e7a7ad..e83e20ce 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_async.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_async.cpp @@ -59,8 +59,10 @@ void gemm_async_kernel(const gemm_globals g, int M, int N, int K) kittens::sync::wait_ds(); mma_ABt(C_acc, A_reg, B_reg, C_acc); - kittens::sync::wait_async(); - kittens::sync::sync(); + if (k + 1 < k_iters) { // skip on last iter — no next load issued + kittens::sync::wait_async(); + kittens::sync::sync(); + } } bf16* c_base = reinterpret_cast(&g.c[{0, 0, 0, 0}]); diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_double_buf.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_double_buf.cpp index e4db519c..a444a3c9 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_double_buf.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_double_buf.cpp @@ -56,7 +56,7 @@ void gemm_double_buf_kernel(const gemm_globals g, int M, int N, int K) kittens::sync::wait_ds(); mma_ABt(C_acc, A_reg, B_reg, C_acc); - kittens::sync::sync(); + if (k + 1 < k_iters) kittens::sync::sync(); // skip on last iter — no next overwrite } bf16* c_base = reinterpret_cast(&g.c[{0, 0, 0, 0}]); diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp index 97555ff4..b3dbf055 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_expert.cpp @@ -38,7 +38,7 @@ void gemm_expert_kernel(const gemm_globals g, int M, int N, int K) const int warp_c = wid % WARPS_N; const int k_iters = K / K_STEP; - kittens::sched::expert _sched; // limited expert mode, off in dtor + kittens::sched::expert _sched{kittens::sched::mode::limited_nostall}; // DISABLE_VALU_STALL for WMMA burst kittens::g2s::load_async( A_lds[0], g.a, {0, 0, tile_m, 0}, K); @@ -50,13 +50,13 @@ void gemm_expert_kernel(const gemm_globals g, int M, int N, int K) for (int k = 0; k < k_iters; ++k) { const int cur = k & 1, nxt = 1 - cur; + kittens::sync::arrive(); // signal before async loads — nxt != cur, no conflict if (k + 1 < k_iters) { kittens::g2s::load_async( A_lds[nxt], g.a, {0, 0, tile_m, k + 1}, K); kittens::g2s::load_async( B_lds[nxt], g.b, {0, 0, tile_n, k + 1}, K); } - kittens::sync::arrive(); rt_bf A_reg; rt_bf B_reg; @@ -69,7 +69,7 @@ void gemm_expert_kernel(const gemm_globals g, int M, int N, int K) kittens::sync::wait_ds(); mma_ABt_burst(C_acc, A_reg, B_reg, C_acc); - kittens::sync::wait_async(); + if (k + 1 < k_iters) kittens::sync::wait_async(); // skip on last iter — no next load issued } bf16* c_base = reinterpret_cast(&g.c[{0, 0, 0, 0}]); diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_naive.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_naive.cpp index a87530ae..fde01745 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_naive.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_naive.cpp @@ -54,7 +54,7 @@ void gemm_naive_kernel(const gemm_globals g, int M, int N, int K) kittens::sync::wait_ds(); mma_ABt(C_acc, A_reg, B_reg, C_acc); - kittens::sync::sync(); + if (k + 1 < k_iters) kittens::sync::sync(); // skip on last iter — no next overwrite } bf16* c_base = reinterpret_cast(&g.c[{0, 0, 0, 0}]); diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_padded.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_padded.cpp index 5faefd1e..46801cf2 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_padded.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_padded.cpp @@ -13,6 +13,7 @@ using namespace kittens; using namespace gfx1250_gemm; +// Pad type must match between load_async (write) and load_b128 (read). using Pad = lds_pad_default; constexpr int A_ELEMS_PAD = Pad::padded_elems(BLOCK_M * K_STEP); constexpr int B_ELEMS_PAD = Pad::padded_elems(BLOCK_N * K_STEP); @@ -63,8 +64,10 @@ void gemm_padded_kernel(const gemm_globals g, int M, int N, int K) kittens::sync::wait_ds(); mma_ABt(C_acc, A_reg, B_reg, C_acc); - kittens::sync::wait_async(); - kittens::sync::sync(); + if (k + 1 < k_iters) { // skip on last iter — no next load issued + kittens::sync::wait_async(); + kittens::sync::sync(); + } } bf16* c_base = reinterpret_cast(&g.c[{0, 0, 0, 0}]); diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp index 083887b2..f2a66782 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_segment.cpp @@ -12,9 +12,16 @@ using namespace kittens; using namespace gfx1250_gemm; +// Pad type must match between load_async (write) and load_b128 (read). +// A uses ~8.7KB of a 64KB segment — the rest is unused but reserves the segment +// for dual-port LDS access (A and B on separate ports). using Pad = lds_pad_default; constexpr int A_ELEMS_PAD = Pad::padded_elems(BLOCK_M * K_STEP); constexpr int B_ELEMS_PAD = Pad::padded_elems(BLOCK_N * K_STEP); +static_assert(2 * A_ELEMS_PAD * sizeof(kittens::bf16) <= kittens::LDS_SEGMENT_BYTES, + "Double-buffered A must fit in one LDS segment"); +static_assert(2 * B_ELEMS_PAD * sizeof(kittens::bf16) <= kittens::LDS_SEGMENT_BYTES, + "Double-buffered B must fit in one LDS segment"); __global__ __launch_bounds__(NUM_THREADS, 1) void gemm_segment_kernel(const gemm_globals g, int M, int N, int K) @@ -45,13 +52,13 @@ void gemm_segment_kernel(const gemm_globals g, int M, int N, int K) for (int k = 0; k < k_iters; ++k) { const int cur = k & 1, nxt = 1 - cur; + kittens::sync::arrive(); // signal before async loads — nxt != cur, no conflict if (k + 1 < k_iters) { kittens::g2s::load_async( A_lds[nxt], g.a, {0, 0, tile_m, k + 1}, K); kittens::g2s::load_async( B_lds[nxt], g.b, {0, 0, tile_n, k + 1}, K); } - kittens::sync::arrive(); rt_bf A_reg; rt_bf B_reg; @@ -64,7 +71,7 @@ void gemm_segment_kernel(const gemm_globals g, int M, int N, int K) kittens::sync::wait_ds(); mma_ABt(C_acc, A_reg, B_reg, C_acc); - kittens::sync::wait_async(); + if (k + 1 < k_iters) kittens::sync::wait_async(); // skip on last iter — no next load issued } bf16* c_base = reinterpret_cast(&g.c[{0, 0, 0, 0}]); diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_split_bar.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_split_bar.cpp index b4a45e02..1b98ac86 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_split_bar.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_split_bar.cpp @@ -45,13 +45,13 @@ void gemm_split_bar_kernel(const gemm_globals g, int M, int N, int K) for (int k = 0; k < k_iters; ++k) { const int cur = k & 1, nxt = 1 - cur; + kittens::sync::arrive(); // signal before async loads — nxt != cur, no conflict if (k + 1 < k_iters) { kittens::g2s::load_async( A_lds[nxt], g.a, {0, 0, tile_m, k + 1}, K); kittens::g2s::load_async( B_lds[nxt], g.b, {0, 0, tile_n, k + 1}, K); } - kittens::sync::arrive(); // signal early -- independent work below rt_bf A_reg; rt_bf B_reg; @@ -64,7 +64,7 @@ void gemm_split_bar_kernel(const gemm_globals g, int M, int N, int K) kittens::sync::wait_ds(); mma_ABt(C_acc, A_reg, B_reg, C_acc); - kittens::sync::wait_async(); + if (k + 1 < k_iters) kittens::sync::wait_async(); // skip on last iter — no next load issued } bf16* c_base = reinterpret_cast(&g.c[{0, 0, 0, 0}]); diff --git a/kernels/gemm/bf16fp32/gfx1250/gemm_tdm_arrive.cpp b/kernels/gemm/bf16fp32/gfx1250/gemm_tdm_arrive.cpp index f609dd6e..53f6ca2c 100644 --- a/kernels/gemm/bf16fp32/gfx1250/gemm_tdm_arrive.cpp +++ b/kernels/gemm/bf16fp32/gfx1250/gemm_tdm_arrive.cpp @@ -96,20 +96,19 @@ void gemm_tdm_arrive_kernel(const gemm_globals g, int M, int N, int K) // `load_tdm_arrive`) is also wired in the library for runtimes that // model it natively. // - // `async_barrier_arrive` is a DS atomic, so it fires per active lane: - // guard with `laneid() == 0` so each producer wave arrives exactly - // once per phase (matching the `init_barrier(.., 1)` priming above). + // `async_barrier_arrive` internally guards with `laneid() == 0` so + // each wave arrives exactly once per phase. if (wid == 0) { g2s::load_tdm( A_lds[0], g.a, {0, 0, tile_m, 0}, M, K, K); sync::wait_tensor(); - if (laneid() == 0) sync::async_barrier_arrive(&A_bar[0].state); + sync::async_barrier_arrive(&A_bar[0].state); } if (wid == 1) { g2s::load_tdm( B_lds[0], g.b, {0, 0, tile_n, 0}, N, K, K); sync::wait_tensor(); - if (laneid() == 0) sync::async_barrier_arrive(&B_bar[0].state); + sync::async_barrier_arrive(&B_bar[0].state); } for (int k = 0; k < k_iters; ++k) { @@ -120,13 +119,13 @@ void gemm_tdm_arrive_kernel(const gemm_globals g, int M, int N, int K) g2s::load_tdm( A_lds[nxt], g.a, {0, 0, tile_m, k + 1}, M, K, K); sync::wait_tensor(); - if (laneid() == 0) sync::async_barrier_arrive(&A_bar[nxt].state); + sync::async_barrier_arrive(&A_bar[nxt].state); } if (wid == 1) { g2s::load_tdm( B_lds[nxt], g.b, {0, 0, tile_n, k + 1}, N, K, K); sync::wait_tensor(); - if (laneid() == 0) sync::async_barrier_arrive(&B_bar[nxt].state); + sync::async_barrier_arrive(&B_bar[nxt].state); } } @@ -163,9 +162,7 @@ void dispatch(gemm_globals g) { // Same layout as `gemm_segment`/`gemm_expert` (A in seg 0, B in seg 1) // plus 4 barrier cells in seg 0. - constexpr size_t bar_bytes = 4 * sizeof(sync::barrier_lds); const size_t mem_size = LDS_SEGMENT_BYTES + 2 * B_ELEMS_PAD * sizeof(bf16); - (void)bar_bytes; hipFuncSetAttribute(reinterpret_cast(gemm_tdm_arrive_kernel), hipFuncAttributeMaxDynamicSharedMemorySize, static_cast(mem_size)); diff --git a/kernels/gemm/bf16fp32/gfx1250/harness.h b/kernels/gemm/bf16fp32/gfx1250/harness.h index 5b974121..c44e0362 100644 --- a/kernels/gemm/bf16fp32/gfx1250/harness.h +++ b/kernels/gemm/bf16fp32/gfx1250/harness.h @@ -61,8 +61,8 @@ int main(int argc, char** argv) int n_iters = (argc > 4) ? std::atoi(argv[4]) : 1; int verify = (argc > 5) ? std::atoi(argv[5]) : 1; - std::printf("gemm_naive (bf16->fp32->bf16) M=%d N=%d K=%d iters=%d verify=%d\n", - M, N, K, n_iters, verify); + std::printf("%s (bf16->fp32->bf16) M=%d N=%d K=%d iters=%d verify=%d\n", + __FILE__, M, N, K, n_iters, verify); // ---- host fp32 reference + bf16 buffers ---- std::vector A_h(M * K), B_h(N * K), C_ref(M * N); @@ -74,6 +74,9 @@ int main(int argc, char** argv) for (auto& x : B_h) x = dist(rng); for (size_t i = 0; i < A_h.size(); ++i) A_bf[i] = float_to_bf16(A_h[i]); for (size_t i = 0; i < B_h.size(); ++i) B_bf[i] = float_to_bf16(B_h[i]); + // Round-trip through bf16 so CPU reference sees same inputs as GPU + for (size_t i = 0; i < A_h.size(); ++i) A_h[i] = bf16_to_float(A_bf[i]); + for (size_t i = 0; i < B_h.size(); ++i) B_h[i] = bf16_to_float(B_bf[i]); // ---- device buffers ---- __hip_bfloat16 *A_d = nullptr, *B_d = nullptr, *C_d = nullptr; @@ -96,6 +99,7 @@ int main(int argc, char** argv) // ---- warmup + timed run ---- dispatch(g); + HIP_OK(hipGetLastError()); HIP_OK(hipDeviceSynchronize()); hipEvent_t t0, t1; @@ -132,7 +136,14 @@ int main(int argc, char** argv) mean_abs /= (M * N); std::printf(" max_abs_err=%.4f mean_abs_err=%.4f bad=%d/%d\n", max_abs, mean_abs, n_bad, M * N); - return (max_abs < 1.0 || n_bad < 10) ? 0 : 1; + // bf16 has ~2^-7 precision; error grows as ~sqrt(K) * 2^-7 per element. + // Output is bf16 so final quantization adds another 2^-7. Use 2x headroom. + double tol = 2.0 * std::sqrt(static_cast(K)) * (1.0 / 128.0); + int ret = (n_bad == 0 && max_abs < tol) ? 0 : 1; + + hipFree(A_d); hipFree(B_d); hipFree(C_d); + hipEventDestroy(t0); hipEventDestroy(t1); + return ret; } #endif // HARNESS_MAIN