diff --git a/include/common/base_types.cuh b/include/common/base_types.cuh index a56d207de..4fa02fe29 100644 --- a/include/common/base_types.cuh +++ b/include/common/base_types.cuh @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -34,6 +35,18 @@ using bf16_2 = __hip_bfloat162; * @brief Packed word of two half-precision floating-point values. */ using half_2 = __half2; +/** + * @brief float8 floating-point type. + */ +using fp8e4m3 = __hip_fp8_e4m3_fnuz; +/** + * @brief Packed word of two float8 floating-point values. + */ +using fp8e4m3_2 = __hip_fp8x2_e4m3_fnuz; +/** + * @brief Packed word of four float8 floating-point values. + */ +using fp8e4m3_4 = __hip_fp8x4_e4m3_fnuz; namespace ducks { /** @@ -44,9 +57,9 @@ namespace ducks { namespace base_types { template -concept T2 = std::is_same_v || std::is_same_v || std::is_same_v; +concept T2 = std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v; template -concept T1 = std::is_same_v || std::is_same_v || std::is_same_v; +concept T1 = std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v; } // namespace base_types } // namespace ducks @@ -115,6 +128,18 @@ template<> struct constants { static __device__ inline constexpr half_2 pos_infty() { return std::bit_cast(uint32_t(0x7C007C00)); } static __device__ inline constexpr half_2 neg_infty() { return std::bit_cast(uint32_t(0xFC00FC00)); } }; +template<> struct constants { + static __device__ inline constexpr fp8e4m3 zero() { return std::bit_cast(uint8_t(0x00)); } + static __device__ inline constexpr fp8e4m3 one() { return std::bit_cast(uint8_t(0x38)); } +}; +template<> struct constants { + static __device__ inline constexpr fp8e4m3_2 zero() { return std::bit_cast(uint16_t(0x0000)); } + static __device__ inline constexpr fp8e4m3_2 one() { return std::bit_cast(uint16_t(0x3838)); } +}; +template<> struct constants { + static __device__ inline constexpr fp8e4m3_4 zero() { return std::bit_cast(uint32_t(0x00000000)); } + static __device__ inline constexpr fp8e4m3_4 one() { return std::bit_cast(uint32_t(0x38383838)); } +}; template<> struct constants { static __device__ inline constexpr int zero() { return 0; } static __device__ inline constexpr int one() { return 1; } @@ -198,6 +223,30 @@ template<> struct packing { template<> struct packing { static __device__ inline constexpr int num() { return 4; } }; +template<> struct packing { + static __device__ inline constexpr int num() { return 1; } + using unpacked_type = fp8e4m3; + using packed_type = fp8e4m3_4; +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 4; } + using unpacked_type = fp8e4m3; + using packed_type = fp8e4m3_4; +}; + +/** + * @brief Pack four float8 into 32-bits. + */ +static __host__ __device__ inline fp8e4m3_4 make_fp8e4m3_4(const fp8e4m3 & x, const fp8e4m3 & y, const fp8e4m3 & z, const fp8e4m3 & w) { + return std::bit_cast( + static_cast( + std::bit_cast(x) | + (std::bit_cast(y) << 8) | + (std::bit_cast(z) << 16) | + (std::bit_cast(w) << 24) + ) + ); +} /** * @brief Provides templated functionality to convert between different types. @@ -300,5 +349,37 @@ template<> struct convertor { return __float22half2_rn(__bfloat1622float2(u)); } }; +template<> struct convertor { + static __host__ __device__ inline fp8e4m3_4 convert(const float4& u) { + return fp8e4m3_4(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline float4 convert(const fp8e4m3_4& u) { + fp8e4m3 *vals = reinterpret_cast(const_cast(&u)); + return make_float4(float(vals[0]), float(vals[1]), float(vals[2]), float(vals[3])); + } +}; +template<> struct convertor { + static __host__ __device__ inline fp8e4m3_2 convert(const float2& u) { + return fp8e4m3_2(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline float2 convert(const fp8e4m3_2& u) { + fp8e4m3 *vals = reinterpret_cast(const_cast(&u)); + return make_float2(float(vals[0]), float(vals[1])); + } +}; +template<> struct convertor { + static __host__ __device__ inline fp8e4m3 convert(const float & u) { + return fp8e4m3(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline float convert(const fp8e4m3 & u) { + return float(u); + } +}; } -} +} \ No newline at end of file diff --git a/include/common/util.cuh b/include/common/util.cuh index 6d96cfeca..3926cfca3 100644 --- a/include/common/util.cuh +++ b/include/common/util.cuh @@ -13,6 +13,8 @@ #include #include "base_types.cuh" +#include "../types/register/rt_layout.cuh" +#include "../types/shared/st_layout.cuh" #ifndef __forceinline__ #define __forceinline__ __attribute__((always_inline)) @@ -30,8 +32,17 @@ namespace kittens { /** * @brief Tile dimension constant. */ -template constexpr int TILE_COL_DIM = sizeof(T) == 1 ? 32 : 16; -template constexpr int TILE_ROW_DIM = 16; +template +concept all_layouts = ducks::rt_layout::all || ducks::st_layout::all; + +template +constexpr bool is_col_lt = std::is_same_v || std::is_same_v; + +template +constexpr int TILE_ROW_DIM = is_col_lt && std::is_same_v ? 32 : 16; + +template +constexpr int TILE_COL_DIM = (!is_col_lt && std::is_same_v) ? 32 : 16; /** * @brief Tile num elements constant calculated as TILE_DIM squared. diff --git a/include/ops/warp/memory/tile/shared_to_register.cuh b/include/ops/warp/memory/tile/shared_to_register.cuh index 400fe4a39..ced270ee6 100644 --- a/include/ops/warp/memory/tile/shared_to_register.cuh +++ b/include/ops/warp/memory/tile/shared_to_register.cuh @@ -41,10 +41,10 @@ __device__ inline static void load(RT &dst, const ST &src) { int row_offset, col_offset; if constexpr (std::is_same_v) { row_offset = laneid%16; - col_offset = 4*(laneid/16); + col_offset = std::is_same_v ? 8*(laneid/16) : 4*(laneid/16); } else { - row_offset = 4*(laneid/16); + row_offset = std::is_same_v ? 8*(laneid/16) : 4*(laneid/16); col_offset = laneid%16; } @@ -77,7 +77,7 @@ __device__ inline static void load(RT &dst, const ST &src) { asm volatile( "ds_read_b64 %0, %1 offset:%2\n" : "=v"(*reinterpret_cast(&dst.tiles[i][j].data[0])) - : "v"(addr), "i"(i * ST::underlying_cols * kittens::TILE_ROW_DIM * sizeof(U)) + : "v"(addr), "i"(i * ST::underlying_cols * kittens::TILE_ROW_DIM * sizeof(U)) : "memory" ); } else { @@ -91,8 +91,22 @@ __device__ inline static void load(RT &dst, const ST &src) { } } else { // handle the column-major layout - dst.tiles[i][j].data[0] = base_types::convertor::convert(U2{src[{row, col}], src[{row+1, col}]}); - dst.tiles[i][j].data[1] = base_types::convertor::convert(U2{src[{row+2, col}], src[{row+3, col}]}); + if constexpr (std::is_same_v) { + dst.tiles[i][j].data[0] = base_types::convertor::convert( + base_types::make_fp8e4m3_4( + src[{row, col}], src[{row+1, col}], src[{row+2, col}], src[{row+3, col}] + ) + ); + dst.tiles[i][j].data[1] = base_types::convertor::convert( + base_types::make_fp8e4m3_4( + src[{row+4, col}], src[{row+5 , col}], src[{row+6, col}], src[{row+7, col}] + ) + ); + } + else { + dst.tiles[i][j].data[0] = base_types::convertor::convert(U2{src[{row, col}], src[{row+1, col}]}); + dst.tiles[i][j].data[1] = base_types::convertor::convert(U2{src[{row+2, col}], src[{row+3, col}]}); + } } } } diff --git a/include/ops/warp/register/tile/mma.cuh b/include/ops/warp/register/tile/mma.cuh index fd2443b38..f9088c3fd 100644 --- a/include/ops/warp/register/tile/mma.cuh +++ b/include/ops/warp/register/tile/mma.cuh @@ -39,6 +39,19 @@ __device__ static inline void mfma161616(float2 (&D)[2], ); } +__device__ static inline void mfma161632(float2 (&D)[2], + const fp8e4m3_4 (&A)[2], + const fp8e4m3_4 (&B)[2], + const float2 (&C)[2]) { + typedef __attribute__((__vector_size__(4 * sizeof(float)))) float float4_t; + *(float4_t*)D = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( + *(long*)A, + *(long*)B, + *(float4_t*)C, + 0, 0, 0 + ); +} + /** * @brief Base matrix multiply-accumulate operation for row layout. @@ -63,6 +76,13 @@ __device__ static inline void mma_AB_base(rt_base const rt_base &c) { mfma161616(d.data, a.data, b.data, c.data); } +__device__ static inline void mma_AB_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in col-major mode + const rt_base &c) { + mfma161632(d.data, a.data, b.data, c.data); +} + /** * @brief Base dot product operation for row layout. * @@ -86,6 +106,12 @@ __device__ static inline void mma_ABt_base(rt_base const rt_base &c) { mfma161616(d.data, a.data, b.data, c.data); } +__device__ static inline void mma_ABt_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in row-major mode + const rt_base &c) { + mfma161632(d.data, a.data, b.data, c.data); +} /** * @brief Base matrix multiply-accumulate operation for row layout with transposed A. * @@ -109,6 +135,12 @@ __device__ static inline void mma_AtB_base(rt_base const rt_base &c) { mfma161616(d.data, a.data, b.data, c.data); } +__device__ static inline void mma_AtB_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in col-major mode + const rt_base &c) { + mfma161632(d.data, a.data, b.data, c.data); +} /** * @brief Base matrix multiply-accumulate operation for row layout with transposed A and B. * @@ -132,6 +164,12 @@ __device__ static inline void mma_AtBt_base(rt_base &c) { mfma161616(d.data, a.data, b.data, c.data); } +__device__ static inline void mma_AtBt_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in col-major mode + const rt_base &c) { + mfma161632(d.data, a.data, b.data, c.data); +} /** * @brief Matrix multiply-accumulate operation. * @@ -159,7 +197,9 @@ __device__ static inline void mma_AB(D &d, (std::is_same_v && std::is_same_v && std::is_same_v && std::is_same_v) || (std::is_same_v && std::is_same_v && - std::is_same_v && std::is_same_v) + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) ); #pragma unroll @@ -212,7 +252,9 @@ __device__ static inline void mma_ABt(D &d, (std::is_same_v && std::is_same_v && std::is_same_v && std::is_same_v) || (std::is_same_v && std::is_same_v && - std::is_same_v && std::is_same_v) + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) ); #pragma unroll @@ -264,7 +306,9 @@ __device__ static inline void mma_AtB(D &d, (std::is_same_v && std::is_same_v && std::is_same_v && std::is_same_v) || (std::is_same_v && std::is_same_v && - std::is_same_v && std::is_same_v) + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) ); #pragma unroll @@ -317,7 +361,9 @@ __device__ static inline void mma_AtBt(D &d, (std::is_same_v && std::is_same_v && std::is_same_v && std::is_same_v) || (std::is_same_v && std::is_same_v && - std::is_same_v && std::is_same_v) + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) ); #pragma unroll diff --git a/include/ops/warp/shared/tile/conversions.cuh b/include/ops/warp/shared/tile/conversions.cuh index a87c22661..5343a38ac 100644 --- a/include/ops/warp/shared/tile/conversions.cuh +++ b/include/ops/warp/shared/tile/conversions.cuh @@ -51,12 +51,13 @@ __device__ static inline void copy(st &dst, const st __device__ inline st_subtile subtile_inplace(ST &src, int2 rowcol, bool unformatted = false) { using T = typename ST::dtype; - static_assert(subtile_rows % TILE_ROW_DIM == 0); - static_assert(subtile_cols % TILE_COL_DIM == 0); - static_assert(ST::height % (subtile_rows/TILE_ROW_DIM) == 0); - static_assert(ST::width % (subtile_cols/TILE_COL_DIM) == 0); + using layout = typename ST::layout; + static_assert(subtile_rows % TILE_ROW_DIM == 0); + static_assert(subtile_cols % TILE_COL_DIM == 0); + static_assert(ST::height % (subtile_rows/TILE_ROW_DIM) == 0); + static_assert(ST::width % (subtile_cols/TILE_COL_DIM) == 0); static_assert(ST::height == ST::underlying_height && ST::width == ST::underlying_width); // must be a real ST, no recursive subtiles. return st_subtile(src, rowcol); } -} // namespace kittens \ No newline at end of file +} // namespace kittens diff --git a/include/types/register/rt_base.cuh b/include/types/register/rt_base.cuh index 44c816f4c..56b75019b 100644 --- a/include/types/register/rt_base.cuh +++ b/include/types/register/rt_base.cuh @@ -52,12 +52,12 @@ template struct rt_base { using dtype = T2; ///< Data type of the matrix elements static_assert( - std::is_same_v || std::is_same_v || std::is_same_v, + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v, "rt_base was provided an unsupported type." ); - static constexpr int tile_size_row = kittens::TILE_ROW_DIM; - static constexpr int tile_size_col = kittens::TILE_COL_DIM; + static constexpr int tile_size_row = kittens::TILE_ROW_DIM; + static constexpr int tile_size_col = kittens::TILE_COL_DIM; static constexpr int rows = tile_size_row; ///< Number of rows. static constexpr int cols = tile_size_col; ///< Number of cols. static constexpr int num_elements = rows*cols; diff --git a/include/types/shared/st.cuh b/include/types/shared/st.cuh index 9b8825796..e316dd728 100644 --- a/include/types/shared/st.cuh +++ b/include/types/shared/st.cuh @@ -36,7 +36,8 @@ namespace kittens { template< typename ST, int _subtile_height, - int _subtile_width + int _subtile_width, + ducks::st_layout::all _layout > struct st_subtile; @@ -47,32 +48,37 @@ namespace kittens { * @tparam _rows The height of the tile. * @tparam _cols The width of the tile. */ - template + template struct KITTENS_DEFAULT_ALIGN st { using identifier = ducks::st::identifier; ///< Type identifier for shared memory tile. using T = base_types::packing<_T>::unpacked_type; using T2 = base_types::packing<_T>::packed_type; using dtype = T; ///< Data type of the elements in the tile. + using layout = _layout; // define underlying data as same as that projected, to make clear that this is *not* a subtile. static constexpr int underlying_rows = _rows; static constexpr int underlying_cols = _cols; - static constexpr int underlying_height = _rows / kittens::TILE_ROW_DIM; - static constexpr int underlying_width = _cols / kittens::TILE_COL_DIM; + static constexpr int underlying_height = _rows / kittens::TILE_ROW_DIM; + static constexpr int underlying_width = _cols / kittens::TILE_COL_DIM; static constexpr int underlying_num_elements = underlying_rows * underlying_cols; static constexpr int rows = _rows; ///< Total number of rows in the tile. - static_assert(rows % kittens::TILE_ROW_DIM == 0, "Rows must be divisible by the tile dimension"); + static_assert(rows % kittens::TILE_ROW_DIM == 0, "Rows must be divisible by the tile dimension"); static constexpr int cols = _cols; ///< Total number of cols in the tile. - static_assert(cols % kittens::TILE_COL_DIM == 0, "Cols must be divisible by the tile dimension"); - static constexpr int height = _rows / kittens::TILE_ROW_DIM; ///< Height of the tile in terms of 16-element subtiles. - static constexpr int width = _cols / kittens::TILE_COL_DIM; ///< Width of the tile in terms of 16-element subtiles. + static_assert(cols % kittens::TILE_COL_DIM == 0, "Cols must be divisible by the tile dimension"); + static constexpr int height = _rows / kittens::TILE_ROW_DIM; ///< Height of the tile in terms of 16-element subtiles. + static constexpr int width = _cols / kittens::TILE_COL_DIM; ///< Width of the tile in terms of 16-element subtiles. static constexpr int num_elements = rows * cols; ///< Total number of elements in the tile. static_assert(base_types::packing::num() == 1); // must be a 1-packed type (e.g. float, bf16, etc) static constexpr int swizzle_bytes = ( sizeof(dtype) == 1 ? ( + std::is_same_v ? ( + underlying_width%4 == 0 ? 64 : + underlying_width%2 == 0 ? 32 : 16 + ) : underlying_width%4 == 0 ? 128 : underlying_width%2 == 0 ? 64 : 32 ) : @@ -128,7 +134,7 @@ namespace kittens { using col_vec = sv; ///< Column vector type for this tile using row_vec = sv; ///< Row vector type for this tile template using subtile = st_subtile< - st, subtile_rows, subtile_cols + st, subtile_rows, subtile_cols, _layout >; ///< A templated subtile type wrapper for this tile. }; @@ -146,7 +152,8 @@ namespace kittens { template< typename _ST, int _subtile_rows, - int _subtile_cols + int _subtile_cols, + ducks::st_layout::all _layout=_ST::layout > struct st_subtile { using identifier = ducks::st::identifier; // i quack like an st, gcc will never know the difference @@ -154,21 +161,22 @@ namespace kittens { using T = ST::T; using T2 = ST::T2; using dtype = T; ///< Data type of the elements in the tile. + using layout = _layout; static constexpr int underlying_rows = ST::underlying_rows; - static_assert(underlying_rows % kittens::TILE_ROW_DIM == 0, "Underlying rows must be divisible by the tile dimension"); + static_assert(underlying_rows % kittens::TILE_ROW_DIM == 0, "Underlying rows must be divisible by the tile dimension"); static constexpr int underlying_cols = ST::underlying_cols; - static_assert(underlying_cols % kittens::TILE_COL_DIM == 0, "Underlying cols must be divisible by the tile dimension"); + static_assert(underlying_cols % kittens::TILE_COL_DIM == 0, "Underlying cols must be divisible by the tile dimension"); static constexpr int underlying_height = ST::underlying_height; static constexpr int underlying_width = ST::underlying_width; static constexpr int underlying_num_elements = ST::underlying_num_elements; static constexpr int rows = _subtile_rows; - static_assert(rows % kittens::TILE_ROW_DIM == 0, "Rows must be divisible by the tile dimension"); + static_assert(rows % kittens::TILE_ROW_DIM == 0, "Rows must be divisible by the tile dimension"); static constexpr int cols = _subtile_cols; - static_assert(cols % kittens::TILE_COL_DIM == 0, "Cols must be divisible by the tile dimension"); - static constexpr int height = rows / kittens::TILE_ROW_DIM; - static constexpr int width = cols / kittens::TILE_COL_DIM; + static_assert(cols % kittens::TILE_COL_DIM == 0, "Cols must be divisible by the tile dimension"); + static constexpr int height = rows / kittens::TILE_ROW_DIM; + static constexpr int width = cols / kittens::TILE_COL_DIM; static constexpr int num_elements = rows * cols; static constexpr int swizzle_bytes = ST::swizzle_bytes; diff --git a/kernels/gemm/fp8fp32/mi300x/8192_256_256_64_32/256_256_64_32.cpp b/kernels/gemm/fp8fp32/mi300x/8192_256_256_64_32/256_256_64_32.cpp new file mode 100644 index 000000000..84f5c1b97 --- /dev/null +++ b/kernels/gemm/fp8fp32/mi300x/8192_256_256_64_32/256_256_64_32.cpp @@ -0,0 +1,212 @@ +#include "kittens.cuh" +#include "pyutils/pyutils.cuh" +using namespace kittens; + +#define NUM_WARPS 8 +#define M 8192 +#define N 8192 +#define K 8192 + +constexpr int BLOCK_M = 256; +constexpr int BLOCK_N = 256; +constexpr int BLOCK_K = 128; +constexpr int REG_MN = 64; +constexpr int REG_K = 32; + +using G = kittens::group; +using _gl_A = gl; +using _gl_B = gl; +using _gl_C = gl; + +struct micro_globals { + _gl_A A; + _gl_B B; + _gl_C C; + hipStream_t stream; + dim3 grid() { return dim3((N / BLOCK_N) * (M / BLOCK_M)); } + dim3 block() { return dim3(NUM_WARPS * WARP_THREADS); } + size_t dynamic_shared_memory() { return 65536; } +}; + +__global__ __launch_bounds__(NUM_WARPS * WARP_THREADS, 2) // launch_bounds(max_threads_per_block, min_warps_per_simd) +void micro_tk(const micro_globals g) { + constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; + extern __shared__ alignment_dummy __shm[]; + shared_allocator al((int*)&__shm[0]); + + auto (&As) = al.allocate>(); + auto (&Bs) = al.allocate>(); + rt tiles[8]; + rt_fl C_accum[2]; + for (int i = 0; i < 2; i++) { zero(C_accum[i]); } + + int wgid = (blockIdx.y * gridDim.x) + blockIdx.x; + const int NUM_WGS = gridDim.x * gridDim.y; + constexpr int WGM = 4; + wgid = chiplet_transform_chunked(wgid, NUM_WGS, NUM_XCDS, WGM*WGM); + const int num_pid_m = ceil_div(M, BLOCK_M); + const int num_pid_n = ceil_div(N, BLOCK_N); + int num_wgid_in_group = WGM * num_pid_n; + int group_id = wgid / num_wgid_in_group; + int first_pid_m = group_id * WGM; + int group_size_m = min(num_pid_m - first_pid_m, WGM); + int output_m = first_pid_m + ((wgid % num_wgid_in_group) % group_size_m); + int output_n = (wgid % num_wgid_in_group) / group_size_m; + + const int warp_id = warpid(); + const int warp_row = warp_id / 4, warp_col = warp_id % 4; + const int k_iters = g.A.cols() / BLOCK_K; + + G::load(As, g.A, {0, 0, output_m, 0}); + G::load(Bs, g.B, {0, 0, output_n, 0}); + __builtin_amdgcn_s_barrier(); + + if (warp_row == 1) { + __builtin_amdgcn_s_barrier(); + } + + for (int K_TILE = 0; K_TILE < k_iters - 1; ++K_TILE) { + constexpr int BUFFER_SIZE_A = (BLOCK_M * BLOCK_K) / NUM_THREADS / sizeof(float4) / sizeof(fp8e4m3); + constexpr int BUFFER_SIZE_B = (BLOCK_N * BLOCK_K) / NUM_THREADS / sizeof(float4) / sizeof(fp8e4m3); + float4 a_buffer_next[BUFFER_SIZE_A]; + float4 b_buffer_next[BUFFER_SIZE_B]; + + // Cluster 0 + load_global_to_register_buffer<2, false, NUM_THREADS>(a_buffer_next, BUFFER_SIZE_A, g.A, {0, 0, output_m, K_TILE + 1}, As); + load(tiles[1], subtile_inplace(As, {warp_row, 0})); + load(tiles[2], subtile_inplace(As, {warp_row + 2, 0})); + load(tiles[0], subtile_inplace(Bs, {warp_col, 0})); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + + // Cluster 1 + asm volatile("s_waitcnt lgkmcnt(0)"); + __builtin_amdgcn_s_setprio(1); + mma_ABt(C_accum[0], tiles[1], tiles[0], C_accum[0]); + mma_ABt(C_accum[1], tiles[2], tiles[0], C_accum[1]); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + + // Cluster 2 + load(tiles[3], subtile_inplace(Bs, {warp_col, 1})); + load(tiles[4], subtile_inplace(As, {warp_row, 1})); + load(tiles[5], subtile_inplace(As, {warp_row + 2, 1})); + load(tiles[0], subtile_inplace(Bs, {warp_col, 2})); + load(tiles[1], subtile_inplace(As, {warp_row, 2})); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + + // Cluster 3 + asm volatile("s_waitcnt lgkmcnt(0)"); + __builtin_amdgcn_s_setprio(1); + mma_ABt(C_accum[0], tiles[4], tiles[3], C_accum[0]); + mma_ABt(C_accum[1], tiles[5], tiles[3], C_accum[1]); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + + // Cluster 4 + load_global_to_register_buffer<2, false, NUM_THREADS>(b_buffer_next, BUFFER_SIZE_B, g.B, {0, 0, output_n, K_TILE + 1}, Bs); + load(tiles[2], subtile_inplace(As, {warp_row + 2, 2})); + load(tiles[6], subtile_inplace(Bs, {warp_col, 3})); + load(tiles[7], subtile_inplace(As, {warp_row, 3})); + load(tiles[5], subtile_inplace(As, {warp_row + 2, 3})); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + + // Cluster 5 + __builtin_amdgcn_s_setprio(1); + mma_ABt(C_accum[0], tiles[1], tiles[0], C_accum[0]); + mma_ABt(C_accum[1], tiles[2], tiles[0], C_accum[1]); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + + // Cluster 6 + asm volatile("s_waitcnt lgkmcnt(0)"); + store_register_buffer_to_shared(As, a_buffer_next); + store_register_buffer_to_shared(Bs, b_buffer_next); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + + // Cluster 7 + __builtin_amdgcn_s_setprio(1); + mma_ABt(C_accum[0], tiles[7], tiles[6], C_accum[0]); + mma_ABt(C_accum[1], tiles[5], tiles[6], C_accum[1]); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + // epilogue + __builtin_amdgcn_sched_barrier(0); + load(tiles[0], subtile_inplace(Bs, {warp_col, 0})); + load(tiles[1], subtile_inplace(As, {warp_row, 0})); + load(tiles[2], subtile_inplace(As, {warp_row + 2, 0})); + asm volatile("s_waitcnt lgkmcnt(0)"); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + + __builtin_amdgcn_s_setprio(1); + mma_ABt(C_accum[0], tiles[1], tiles[0], C_accum[0]); + mma_ABt(C_accum[1], tiles[2], tiles[0], C_accum[1]); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + + load(tiles[3], subtile_inplace(Bs, {warp_col, 1})); + load(tiles[4], subtile_inplace(As, {warp_row, 1})); + load(tiles[5], subtile_inplace(As, {warp_row + 2, 1})); + asm volatile("s_waitcnt lgkmcnt(0)"); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + + __builtin_amdgcn_s_setprio(1); + mma_ABt(C_accum[0], tiles[4], tiles[3], C_accum[0]); + mma_ABt(C_accum[1], tiles[5], tiles[3], C_accum[1]); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + + load(tiles[0], subtile_inplace(Bs, {warp_col, 2})); + load(tiles[1], subtile_inplace(As, {warp_row, 2})); + load(tiles[2], subtile_inplace(As, {warp_row + 2, 2})); + load(tiles[3], subtile_inplace(Bs, {warp_col, 3})); + load(tiles[4], subtile_inplace(As, {warp_row, 3})); + load(tiles[5], subtile_inplace(As, {warp_row + 2, 3})); + asm volatile("s_waitcnt lgkmcnt(0)"); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + + __builtin_amdgcn_s_setprio(1); + mma_ABt(C_accum[0], tiles[1], tiles[0], C_accum[0]); + mma_ABt(C_accum[1], tiles[2], tiles[0], C_accum[1]); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + + __builtin_amdgcn_s_setprio(1); + mma_ABt(C_accum[0], tiles[4], tiles[3], C_accum[0]); + mma_ABt(C_accum[1], tiles[5], tiles[3], C_accum[1]); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + + if (warp_row == 0) { + __builtin_amdgcn_s_barrier(); + } + store(g.C, C_accum[0], {0, 0, output_m * 4 + warp_row, output_n * 4 + warp_col}); + store(g.C, C_accum[1], {0, 0, output_m * 4 + warp_row + 2, output_n * 4 + warp_col}); +} + +void dispatch_micro(micro_globals g) { + unsigned long mem_size = g.dynamic_shared_memory(); + hipFuncSetAttribute((void*)micro_tk, hipFuncAttributeMaxDynamicSharedMemorySize, mem_size); + micro_tk<<>>(g); +} + +PYBIND11_MODULE(tk_kernel, m) { + m.doc() = "tk_kernel python module"; + py::bind_kernel(m, "micro_tk", µ_globals::A, µ_globals::B, µ_globals::C); + py::bind_function(m, "dispatch_micro", µ_globals::A, µ_globals::B, µ_globals::C); +} diff --git a/kernels/gemm/fp8fp32/mi300x/8192_256_256_64_32/Makefile b/kernels/gemm/fp8fp32/mi300x/8192_256_256_64_32/Makefile new file mode 100644 index 000000000..e8d39cf16 --- /dev/null +++ b/kernels/gemm/fp8fp32/mi300x/8192_256_256_64_32/Makefile @@ -0,0 +1,56 @@ +# Compiler +GPU_TARGET=CDNA3 + +TARGET=tk_kernel + +# ThunderKittens root (override with env) +THUNDERKITTENS_ROOT ?= $(abspath ../../../../../) + +SRC=256_256_64_32.cpp + +# HIP variables +ROCM_INSTALL_DIR := $(ROCM_PATH) +HIP_INCLUDE_DIR := $(ROCM_INSTALL_DIR)/include/hip +ROCM_INCLUDE_DIR := $(ROCM_INSTALL_DIR)/include + +HIPCXX ?= $(ROCM_INSTALL_DIR)/bin/hipcc + +# Compiler flags based on GPU target +ifeq ($(GPU_TARGET),CDNA2) +HIPFLAGS+= -DKITTENS_CDNA2 --offload-arch=gfx90a +else ifeq ($(GPU_TARGET),CDNA3) +HIPFLAGS+= -DKITTENS_CDNA3 --offload-arch=gfx942 +endif + +# Common variables and flags +CXX_STD := c++20 +ICXXFLAGS := -std=$(CXX_STD) +ICPPFLAGS := -I${THUNDERKITTENS_ROOT}/include -I$(ROCM_INCLUDE_DIR) -I$(HIP_INCLUDE_DIR) +ILDFLAGS := +ILDLIBS := + +CXXFLAGS ?= -Wall -Wextra +CXXFLAGS := -w + +ICXXFLAGS += $(CXXFLAGS) +ICPPFLAGS += $(CPPFLAGS) +ILDFLAGS += $(LDFLAGS) +ILDLIBS += $(LDLIBS) + +ICXXFLAGS+= -I${THUNDERKITTENS_ROOT}/include -I${THUNDERKITTENS_ROOT}/prototype $(shell python3 -m pybind11 --includes) $(shell python3-config --ldflags) -shared -fPIC -Rpass-analysis=kernel-resource-usage + + +# Default target +all: $(TARGET) + +# LOGDIR := /workdir/data_logs/$(shell date +%m%d_%H%M%S)_outputs +# LOGFILE := $(LOGDIR)/make_build.log + +$(TARGET): $(SRC) + # @mkdir -p $(LOGDIR) + $(HIPCXX) $(SRC) $(HIPFLAGS) $(ICXXFLAGS) $(ICPPFLAGS) $(ILDFLAGS) \ + -o $(TARGET)$(shell python3-config --extension-suffix) 2>&1 # | tee $(LOGFILE) + +# Clean target +clean: + rm -f $(TARGET) diff --git a/kernels/gemm/fp8fp32/mi300x/8192_256_256_64_32/test_python.py b/kernels/gemm/fp8fp32/mi300x/8192_256_256_64_32/test_python.py new file mode 100644 index 000000000..959faa287 --- /dev/null +++ b/kernels/gemm/fp8fp32/mi300x/8192_256_256_64_32/test_python.py @@ -0,0 +1,150 @@ +import torch +import tk_kernel +import random +import time + +profiling = True + +torch.manual_seed(0) +random.seed(0) + +# Inputs +N = 8192 +A = (torch.randn(N, N, dtype=torch.float32, device='cuda') / 10.0).to(torch.float8_e4m3fnuz) +B = (torch.randn(N, N, dtype=torch.float32, device='cuda') / 10.0).to(torch.float8_e4m3fnuz) +Bt = B.t().contiguous() + + +if profiling: + ############### LOGGING STUFF ############### + + import os + import time + import shutil + import re + + def parse_makefile_targets(makefile_path): + src = None + with open(makefile_path, "r") as f: + for line in f: + if match := re.match(r"^SRC\s*=\s*(\S+)", line): + src = match.group(1) + return src + + base_dir = os.path.dirname(os.path.realpath(__file__)) + + # Set destination directory + dirpath = "/workdir/data_logs/" + timestamp = time.strftime("%m%d_%H%M%S") + new_dir = os.path.join(dirpath, f"{timestamp}_outputs") + os.makedirs(new_dir, exist_ok=True) + + # Files to copy (relative to base_dir) + src_name = parse_makefile_targets(os.path.join(base_dir, "Makefile")) + print(f"src: {src_name}") + files_to_copy = [ + "Makefile", + src_name, + "tk_kernel.cpython-313-x86_64-linux-gnu.so", + "tk_kernel.cpython-312-x86_64-linux-gnu.so" + ] + + for filename in files_to_copy: + src = os.path.join(base_dir, filename) + dst = os.path.join(new_dir, filename) + if os.path.exists(src): + shutil.copy2(src, dst) + else: + print(f"Warning: {filename} not found at {src}, skipping.") + + ################ END LOGGING STUFF ############### + +if profiling: + num_warmup = 500 + num_iters = 100 +else: + num_warmup = 1 + num_iters = 0 + +start_event = torch.cuda.Event(enable_timing=True) # in milliseconds +end_event = torch.cuda.Event(enable_timing=True) +flops_ref = (2 * N**3) # FLOPs for reference + +if profiling: + # Reference matmul using PyTorch + for _ in range(num_warmup): + C_ref = torch.matmul(A.to(torch.float32), Bt.to(torch.float32)) + + timings_ref = [] + for _ in range(num_iters): + torch.cuda.synchronize() + start_event.record() + C_ref = torch.matmul(A.to(torch.float32), Bt.to(torch.float32)) + + end_event.record() + torch.cuda.synchronize() + elapsed_time = start_event.elapsed_time(end_event) + timings_ref.append(elapsed_time) + # if profiling: + # print(f"{C_ref.dtype=}") + # avg_time_ref = sum(timings_ref) / len(timings_ref) + # tflops_ref = flops_ref / (avg_time_ref * 1e9) + # print(f"PyTorch reference average execution time: {avg_time_ref:.4f} ms") + # print(f"PyTorch reference performance: {tflops_ref:.2f} TFLOPS for {N}x{N} matrix multiplication.\n") + + +# Kernel matmul +C = torch.zeros(N, N, dtype=torch.float32, device='cuda') +for _ in range(num_warmup): + tk_kernel.dispatch_micro(A, B, C) +timings = [] +for _ in range(num_iters): + torch.cuda.synchronize() + start_event.record() + tk_kernel.dispatch_micro(A, B, C) + end_event.record() + torch.cuda.synchronize() + elapsed_time = start_event.elapsed_time(end_event) + timings.append(elapsed_time) +if profiling: + print(f"{C.dtype=}") + avg_time = sum(timings) / len(timings) + tflops = flops_ref / (avg_time * 1e9) + print(f"Average execution time: {avg_time:.4f} ms") + print(f"Performance: {tflops:.2f} TFLOPS for {N}x{N} matrix multiplication.\n") + + +# Compare against reference +if profiling: + C_float = C.float() + C_ref_float = C_ref.float() + diff = (C_float - C_ref_float).abs() + max_error = diff.max().item() + mean_error = diff.mean().item() + error_count = (diff > 0.1).sum().item() + + print(f"Max error between kernel and reference: {max_error}") + print(f"Max error: {max_error}") + print(f"Mean error: {mean_error}") + print(f"Number of large errors (>0.1): {error_count}\n") + + ############### LOGGING OUTPUTS #################### + + data_to_log = { + "N": N, + # "avg_time_ref": avg_time_ref, + # "tflops_ref": tflops_ref, + "avg_time": avg_time, + "tflops": tflops, + "max_error": max_error, + "mean_error": mean_error, + "error_count": error_count, + } + + import json + with open(os.path.join(new_dir, "data_to_log.json"), "w") as f: + json.dump(data_to_log, f, indent=4) + + ################ END LOGGING OUTPUTS ############### + + \ No newline at end of file