Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 84 additions & 3 deletions include/common/base_types.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <hip_bf16.h>
#include <hip_fp16.h>
#include <hip_fp8.h>
#include <string>
#include <bit>

Expand All @@ -34,6 +35,18 @@ using bf16_2 = __hip_bfloat162;
* @brief Packed word of two half-precision floating-point values.
*/
using half_2 = __half2;
/**
 * @brief 8-bit floating-point type; alias of HIP's `__hip_fp8_e4m3_fnuz`.
 */
using fp8e4m3 = __hip_fp8_e4m3_fnuz;
/**
 * @brief Packed word of two float8 values; alias of HIP's `__hip_fp8x2_e4m3_fnuz`.
 */
using fp8e4m3_2 = __hip_fp8x2_e4m3_fnuz;
/**
 * @brief Packed word of four float8 values; alias of HIP's `__hip_fp8x4_e4m3_fnuz`.
 */
using fp8e4m3_4 = __hip_fp8x4_e4m3_fnuz;

namespace ducks {
/**
Expand All @@ -44,9 +57,9 @@ namespace ducks {
namespace base_types {

template<typename T>
concept T2 = std::is_same_v<T, float2> || std::is_same_v<T, bf16_2> || std::is_same_v<T, half_2>;
concept T2 = std::is_same_v<T, float2> || std::is_same_v<T, bf16_2> || std::is_same_v<T, half_2> || std::is_same_v<T, fp8e4m3_4>;
template<typename T>
concept T1 = std::is_same_v<T, float> || std::is_same_v<T, bf16 > || std::is_same_v<T, half>;
concept T1 = std::is_same_v<T, float> || std::is_same_v<T, bf16 > || std::is_same_v<T, half> || std::is_same_v<T, fp8e4m3>;

} // namespace base_types
} // namespace ducks
Expand Down Expand Up @@ -115,6 +128,18 @@ template<> struct constants<half_2> {
static __device__ inline constexpr half_2 pos_infty() { return std::bit_cast<half_2>(uint32_t(0x7C007C00)); }
static __device__ inline constexpr half_2 neg_infty() { return std::bit_cast<half_2>(uint32_t(0xFC00FC00)); }
};
/**
 * @brief Bit-pattern constants for a single fp8e4m3 value.
 *
 * fp8e4m3 aliases the FNUZ e4m3 encoding (`__hip_fp8_e4m3_fnuz`), whose
 * exponent bias is 8 rather than the 7 used by the OCP e4m3 encoding.
 * 1.0 is therefore sign=0, exponent=0b1000, mantissa=0b000, i.e. 0x40.
 * The OCP pattern 0x38 would decode to 0.5 in this format.
 */
template<> struct constants<fp8e4m3> {
    /// @return +0.0 (all bits clear).
    static __device__ inline constexpr fp8e4m3 zero() { return std::bit_cast<fp8e4m3>(uint8_t(0x00)); }
    /// @return 1.0 in the FNUZ e4m3 encoding.
    static __device__ inline constexpr fp8e4m3 one()  { return std::bit_cast<fp8e4m3>(uint8_t(0x40)); }
};
/**
 * @brief Bit-pattern constants for a packed pair of fp8e4m3 values.
 *
 * Each byte uses the FNUZ e4m3 encoding (`__hip_fp8x2_e4m3_fnuz`), whose
 * exponent bias is 8, so 1.0 is the byte 0x40 (0x38 would decode to 0.5).
 */
template<> struct constants<fp8e4m3_2> {
    /// @return Both lanes set to +0.0.
    static __device__ inline constexpr fp8e4m3_2 zero() { return std::bit_cast<fp8e4m3_2>(uint16_t(0x0000)); }
    /// @return Both lanes set to 1.0 in the FNUZ e4m3 encoding.
    static __device__ inline constexpr fp8e4m3_2 one()  { return std::bit_cast<fp8e4m3_2>(uint16_t(0x4040)); }
};
/**
 * @brief Bit-pattern constants for a packed quad of fp8e4m3 values.
 *
 * Each byte uses the FNUZ e4m3 encoding (`__hip_fp8x4_e4m3_fnuz`), whose
 * exponent bias is 8, so 1.0 is the byte 0x40 (0x38 would decode to 0.5).
 */
template<> struct constants<fp8e4m3_4> {
    /// @return All four lanes set to +0.0.
    static __device__ inline constexpr fp8e4m3_4 zero() { return std::bit_cast<fp8e4m3_4>(uint32_t(0x00000000)); }
    /// @return All four lanes set to 1.0 in the FNUZ e4m3 encoding.
    static __device__ inline constexpr fp8e4m3_4 one()  { return std::bit_cast<fp8e4m3_4>(uint32_t(0x40404040)); }
};
template<> struct constants<int> {
static __device__ inline constexpr int zero() { return 0; }
static __device__ inline constexpr int one() { return 1; }
Expand Down Expand Up @@ -198,6 +223,30 @@ template<> struct packing<float4> {
template<> struct packing<int4> {
static __device__ inline constexpr int num() { return 4; }
};
/**
 * @brief Packing traits for a single fp8e4m3 value.
 *
 * Reports one element per value and names fp8e4m3_4 as the packed counterpart.
 */
template<> struct packing<fp8e4m3> {
static __device__ inline constexpr int num() { return 1; } // one element per fp8e4m3
using unpacked_type = fp8e4m3;
using packed_type = fp8e4m3_4; // the packed form is the 4-wide word, not fp8e4m3_2
};
/**
 * @brief Packing traits for the 4-wide fp8e4m3_4 word.
 *
 * Reports four elements per value; the unpacked counterpart is fp8e4m3.
 */
template<> struct packing<fp8e4m3_4> {
static __device__ inline constexpr int num() { return 4; } // four elements per fp8e4m3_4
using unpacked_type = fp8e4m3;
using packed_type = fp8e4m3_4;
};

/**
 * @brief Pack four float8 values into one 32-bit fp8e4m3_4 word.
 *
 * The raw byte of each argument is copied unchanged: @p x occupies bits 0-7,
 * @p y bits 8-15, @p z bits 16-23 and @p w bits 24-31 of the result.
 *
 * @param x Value placed in the lowest byte.
 * @param y Value placed in the second byte.
 * @param z Value placed in the third byte.
 * @param w Value placed in the highest byte.
 * @return The fp8e4m3_4 whose bit pattern is the four bytes combined.
 */
static __host__ __device__ inline fp8e4m3_4 make_fp8e4m3_4(const fp8e4m3 & x, const fp8e4m3 & y, const fp8e4m3 & z, const fp8e4m3 & w) {
    // Widen each raw byte to an unsigned 32-bit lane before shifting it into place.
    const uint32_t lane0 = std::bit_cast<uint8_t>(x);
    const uint32_t lane1 = std::bit_cast<uint8_t>(y);
    const uint32_t lane2 = std::bit_cast<uint8_t>(z);
    const uint32_t lane3 = std::bit_cast<uint8_t>(w);
    const uint32_t word  = lane0 | (lane1 << 8) | (lane2 << 16) | (lane3 << 24);
    return std::bit_cast<fp8e4m3_4>(word);
}

/**
* @brief Provides templated functionality to convert between different types.
Expand Down Expand Up @@ -300,5 +349,37 @@ template<> struct convertor<half_2, bf16_2> {
return __float22half2_rn(__bfloat1622float2(u));
}
};
/**
 * @brief Converts a float4 to a packed fp8e4m3_4.
 *
 * Delegates entirely to fp8e4m3_4's construction from float4.
 */
template<> struct convertor<fp8e4m3_4, float4> {
static __host__ __device__ inline fp8e4m3_4 convert(const float4& u) {
return fp8e4m3_4(u);
}
};
/**
 * @brief Converts a packed fp8e4m3_4 to a float4, one lane at a time.
 */
template<> struct convertor<float4, fp8e4m3_4> {
    static __host__ __device__ inline float4 convert(const fp8e4m3_4& u) {
        // Read-only view of the packed word as four consecutive fp8e4m3 lanes.
        const fp8e4m3 *lanes = reinterpret_cast<const fp8e4m3*>(&u);
        const float lane0 = float(lanes[0]);
        const float lane1 = float(lanes[1]);
        const float lane2 = float(lanes[2]);
        const float lane3 = float(lanes[3]);
        return make_float4(lane0, lane1, lane2, lane3);
    }
};
/**
 * @brief Converts a float2 to a packed fp8e4m3_2.
 *
 * Delegates entirely to fp8e4m3_2's construction from float2.
 */
template<> struct convertor<fp8e4m3_2, float2> {
static __host__ __device__ inline fp8e4m3_2 convert(const float2& u) {
return fp8e4m3_2(u);
}
};
/**
 * @brief Converts a packed fp8e4m3_2 to a float2, one lane at a time.
 */
template<> struct convertor<float2, fp8e4m3_2> {
    static __host__ __device__ inline float2 convert(const fp8e4m3_2& u) {
        // Read-only view of the packed word as two consecutive fp8e4m3 lanes.
        const fp8e4m3 *lanes = reinterpret_cast<const fp8e4m3*>(&u);
        const float lane0 = float(lanes[0]);
        const float lane1 = float(lanes[1]);
        return make_float2(lane0, lane1);
    }
};
/**
 * @brief Converts a float to a single fp8e4m3 via fp8e4m3's construction from float.
 */
template<> struct convertor<fp8e4m3, float> {
static __host__ __device__ inline fp8e4m3 convert(const float & u) {
return fp8e4m3(u);
}
};
/**
 * @brief Converts a single fp8e4m3 to float via fp8e4m3's conversion to float.
 */
template<> struct convertor<float, fp8e4m3> {
static __host__ __device__ inline float convert(const fp8e4m3 & u) {
return float(u);
}
};
}
}
}
15 changes: 13 additions & 2 deletions include/common/util.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include <hip/hip_runtime.h>

#include "base_types.cuh"
#include "../types/register/rt_layout.cuh"
#include "../types/shared/st_layout.cuh"

#ifndef __forceinline__
#define __forceinline__ __attribute__((always_inline))
Expand All @@ -30,8 +32,17 @@ namespace kittens {
/**
* @brief Tile dimension constant.
*/
template<typename T> constexpr int TILE_COL_DIM = sizeof(T) == 1 ? 32 : 16;
template<typename T> constexpr int TILE_ROW_DIM = 16;
/**
 * @brief Satisfied by any type that models ducks::rt_layout::all or ducks::st_layout::all.
 */
template<typename T>
concept all_layouts = ducks::rt_layout::all<T> || ducks::st_layout::all<T>;

/**
 * @brief True when `layout` is exactly ducks::rt_layout::col or ducks::st_layout::col.
 */
template<all_layouts layout>
constexpr bool is_col_lt = std::is_same_v<layout, ducks::rt_layout::col> || std::is_same_v<layout, ducks::st_layout::col>;

template<typename T, all_layouts layout=ducks::st_layout::row>
constexpr int TILE_ROW_DIM = is_col_lt<layout> && std::is_same_v<T, fp8e4m3> ? 32 : 16;

template<typename T, all_layouts layout=ducks::st_layout::row>
constexpr int TILE_COL_DIM = (!is_col_lt<layout> && std::is_same_v<T, fp8e4m3>) ? 32 : 16;

/**
* @brief Tile num elements constant calculated as TILE_DIM squared.
Expand Down
24 changes: 19 additions & 5 deletions include/ops/warp/memory/tile/shared_to_register.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ __device__ inline static void load(RT &dst, const ST &src) {
int row_offset, col_offset;
if constexpr (std::is_same_v<typename RT::layout, ducks::rt_layout::row>) {
row_offset = laneid%16;
col_offset = 4*(laneid/16);
col_offset = std::is_same_v<T, fp8e4m3> ? 8*(laneid/16) : 4*(laneid/16);
}
else {
row_offset = 4*(laneid/16);
row_offset = std::is_same_v<T, fp8e4m3> ? 8*(laneid/16) : 4*(laneid/16);
col_offset = laneid%16;
}

Expand Down Expand Up @@ -77,7 +77,7 @@ __device__ inline static void load(RT &dst, const ST &src) {
asm volatile(
"ds_read_b64 %0, %1 offset:%2\n"
: "=v"(*reinterpret_cast<uint64_t*>(&dst.tiles[i][j].data[0]))
: "v"(addr), "i"(i * ST::underlying_cols * kittens::TILE_ROW_DIM<U> * sizeof(U))
: "v"(addr), "i"(i * ST::underlying_cols * kittens::TILE_ROW_DIM<U, typename ST::layout> * sizeof(U))
: "memory"
);
} else {
Expand All @@ -91,8 +91,22 @@ __device__ inline static void load(RT &dst, const ST &src) {
}
}
else { // handle the column-major layout
dst.tiles[i][j].data[0] = base_types::convertor<T2, U2>::convert(U2{src[{row, col}], src[{row+1, col}]});
dst.tiles[i][j].data[1] = base_types::convertor<T2, U2>::convert(U2{src[{row+2, col}], src[{row+3, col}]});
if constexpr (std::is_same_v<T, fp8e4m3>) {
dst.tiles[i][j].data[0] = base_types::convertor<T2, U2>::convert(
base_types::make_fp8e4m3_4(
src[{row, col}], src[{row+1, col}], src[{row+2, col}], src[{row+3, col}]
)
);
dst.tiles[i][j].data[1] = base_types::convertor<T2, U2>::convert(
base_types::make_fp8e4m3_4(
src[{row+4, col}], src[{row+5 , col}], src[{row+6, col}], src[{row+7, col}]
)
);
}
else {
dst.tiles[i][j].data[0] = base_types::convertor<T2, U2>::convert(U2{src[{row, col}], src[{row+1, col}]});
dst.tiles[i][j].data[1] = base_types::convertor<T2, U2>::convert(U2{src[{row+2, col}], src[{row+3, col}]});
}
}
}
}
Expand Down
54 changes: 50 additions & 4 deletions include/ops/warp/register/tile/mma.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ __device__ static inline void mfma161616(float2 (&D)[2],
);
}

/**
 * @brief Issues one `__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8` on fp8 operands.
 *
 * The two fp8e4m3_4 words of @p A and of @p B are each reinterpreted as one
 * 64-bit operand; the two float2 of @p C and @p D are reinterpreted as one
 * 4-float vector. The three trailing builtin arguments are passed as 0.
 * NOTE(review): presumably requires a target that provides this MFMA builtin — confirm.
 *
 * @param[out] D Result registers (4 floats).
 * @param[in]  A First fp8 operand (8 packed fp8 values).
 * @param[in]  B Second fp8 operand (8 packed fp8 values).
 * @param[in]  C Accumulator input (4 floats).
 */
__device__ static inline void mfma161632(float2 (&D)[2],
                                         const fp8e4m3_4 (&A)[2],
                                         const fp8e4m3_4 (&B)[2],
                                         const float2 (&C)[2]) {
    typedef __attribute__((__vector_size__(4 * sizeof(float)))) float float4_t;
    // Reinterpret the register arrays as the operand types the builtin expects.
    const long     a_bits = *(long*)A;
    const long     b_bits = *(long*)B;
    const float4_t acc_in = *(float4_t*)C;
    *(float4_t*)D = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(a_bits, b_bits, acc_in, 0, 0, 0);
}


/**
* @brief Base matrix multiply-accumulate operation for row layout.
Expand All @@ -63,6 +76,13 @@ __device__ static inline void mma_AB_base(rt_base<float, ducks::rt_layout::col>
const rt_base<float, ducks::rt_layout::col> &c) {
mfma161616(d.data, a.data, b.data, c.data);
}
/**
 * @brief fp8 overload of the base AB multiply-accumulate.
 *
 * Forwards the register data of the four base tiles to mfma161632 in
 * (d, a, b, c) order.
 *
 * @param[out] d Float accumulator output tile (col layout).
 * @param[in]  a fp8e4m3 tile in row layout.
 * @param[in]  b fp8e4m3 tile in col layout.
 * @param[in]  c Float accumulator input tile (col layout).
 */
__device__ static inline void mma_AB_base(rt_base<float, ducks::rt_layout::col> &d,
const rt_base<fp8e4m3, ducks::rt_layout::row> &a,
const rt_base<fp8e4m3, ducks::rt_layout::col> &b, // in col-major mode
const rt_base<float, ducks::rt_layout::col> &c) {
mfma161632(d.data, a.data, b.data, c.data);
}

/**
* @brief Base dot product operation for row layout.
*
Expand All @@ -86,6 +106,12 @@ __device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col>
const rt_base<float, ducks::rt_layout::col> &c) {
mfma161616(d.data, a.data, b.data, c.data);
}
/**
 * @brief fp8 overload of the base ABt multiply-accumulate.
 *
 * Forwards the register data of the four base tiles to mfma161632 in
 * (d, a, b, c) order.
 *
 * @param[out] d Float accumulator output tile (col layout).
 * @param[in]  a fp8e4m3 tile in row layout.
 * @param[in]  b fp8e4m3 tile in row layout.
 * @param[in]  c Float accumulator input tile (col layout).
 */
__device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col> &d,
const rt_base<fp8e4m3, ducks::rt_layout::row> &a,
const rt_base<fp8e4m3, ducks::rt_layout::row> &b, // in row-major mode
const rt_base<float, ducks::rt_layout::col> &c) {
mfma161632(d.data, a.data, b.data, c.data);
}
/**
* @brief Base matrix multiply-accumulate operation for row layout with transposed A.
*
Expand All @@ -109,6 +135,12 @@ __device__ static inline void mma_AtB_base(rt_base<float, ducks::rt_layout::col>
const rt_base<float, ducks::rt_layout::col> &c) {
mfma161616(d.data, a.data, b.data, c.data);
}
/**
 * @brief fp8 overload of the base AtB multiply-accumulate.
 *
 * Forwards the register data of the four base tiles to mfma161632 in
 * (d, a, b, c) order.
 *
 * @param[out] d Float accumulator output tile (col layout).
 * @param[in]  a fp8e4m3 tile in col layout.
 * @param[in]  b fp8e4m3 tile in col layout.
 * @param[in]  c Float accumulator input tile (col layout).
 */
__device__ static inline void mma_AtB_base(rt_base<float, ducks::rt_layout::col> &d,
const rt_base<fp8e4m3, ducks::rt_layout::col> &a,
const rt_base<fp8e4m3, ducks::rt_layout::col> &b, // in col-major mode
const rt_base<float, ducks::rt_layout::col> &c) {
mfma161632(d.data, a.data, b.data, c.data);
}
/**
* @brief Base matrix multiply-accumulate operation for row layout with transposed A and B.
*
Expand All @@ -132,6 +164,12 @@ __device__ static inline void mma_AtBt_base(rt_base<float, ducks::rt_layout::col
const rt_base<float, ducks::rt_layout::col> &c) {
mfma161616(d.data, a.data, b.data, c.data);
}
/**
 * @brief fp8 overload of the base AtBt multiply-accumulate.
 *
 * Forwards the register data of the four base tiles to mfma161632 in
 * (d, a, b, c) order.
 *
 * @param[out] d Float accumulator output tile (col layout).
 * @param[in]  a fp8e4m3 tile in col layout.
 * @param[in]  b fp8e4m3 tile in row layout.
 * @param[in]  c Float accumulator input tile (col layout).
 */
__device__ static inline void mma_AtBt_base(rt_base<float, ducks::rt_layout::col> &d,
const rt_base<fp8e4m3, ducks::rt_layout::col> &a,
const rt_base<fp8e4m3, ducks::rt_layout::row> &b, // NOTE(review): declared with row layout; an earlier "in col-major mode" remark here looked copy-pasted — confirm
const rt_base<float, ducks::rt_layout::col> &c) {
mfma161632(d.data, a.data, b.data, c.data);
}
/**
* @brief Matrix multiply-accumulate operation.
*
Expand Down Expand Up @@ -159,7 +197,9 @@ __device__ static inline void mma_AB(D &d,
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
std::is_same_v<typename B::T, bf16> && std::is_same_v<typename C::T, float>) ||
(std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, half>)
std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, half>) ||
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, fp8e4m3> &&
std::is_same_v<typename B::T, fp8e4m3> && std::is_same_v<typename C::T, float>)
);

#pragma unroll
Expand Down Expand Up @@ -212,7 +252,9 @@ __device__ static inline void mma_ABt(D &d,
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
std::is_same_v<typename B::T, bf16> && std::is_same_v<typename C::T, float>) ||
(std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, half>)
std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, half>) ||
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, fp8e4m3> &&
std::is_same_v<typename B::T, fp8e4m3> && std::is_same_v<typename C::T, float>)
);

#pragma unroll
Expand Down Expand Up @@ -264,7 +306,9 @@ __device__ static inline void mma_AtB(D &d,
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
std::is_same_v<typename B::T, bf16> && std::is_same_v<typename C::T, float>) ||
(std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, half>)
std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, half>) ||
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, fp8e4m3> &&
std::is_same_v<typename B::T, fp8e4m3> && std::is_same_v<typename C::T, float>)
);

#pragma unroll
Expand Down Expand Up @@ -317,7 +361,9 @@ __device__ static inline void mma_AtBt(D &d,
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
std::is_same_v<typename B::T, bf16> && std::is_same_v<typename C::T, float>) ||
(std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, half>)
std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, half>) ||
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, fp8e4m3> &&
std::is_same_v<typename B::T, fp8e4m3> && std::is_same_v<typename C::T, float>)
);

#pragma unroll
Expand Down
11 changes: 6 additions & 5 deletions include/ops/warp/shared/tile/conversions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,13 @@ __device__ static inline void copy(st<T, _height, _width> &dst, const st<U, _hei
// Builds an st_subtile<ST, subtile_rows, subtile_cols> over `src` at `rowcol`, after
// compile-time checks that the subtile shape is compatible with the tile dimensions.
// The `unformatted` parameter is not read in this body.
template<int subtile_rows, int subtile_cols, ducks::st::all ST>
__device__ inline st_subtile<ST, subtile_rows, subtile_cols> subtile_inplace(ST &src, int2 rowcol, bool unformatted = false) {
using T = typename ST::dtype;
// Checks against the tile dimensions for T with the default layout argument.
// NOTE(review): these look superseded by the layout-aware checks below — confirm whether both sets are intended.
static_assert(subtile_rows % TILE_ROW_DIM<T> == 0);
static_assert(subtile_cols % TILE_COL_DIM<T> == 0);
static_assert(ST::height % (subtile_rows/TILE_ROW_DIM<T>) == 0);
static_assert(ST::width % (subtile_cols/TILE_COL_DIM<T>) == 0);
using layout = typename ST::layout;
// Same checks, using the tile dimensions for ST's own layout.
static_assert(subtile_rows % TILE_ROW_DIM<T, layout> == 0);
static_assert(subtile_cols % TILE_COL_DIM<T, layout> == 0);
static_assert(ST::height % (subtile_rows/TILE_ROW_DIM<T, layout>) == 0);
static_assert(ST::width % (subtile_cols/TILE_COL_DIM<T, layout>) == 0);
static_assert(ST::height == ST::underlying_height && ST::width == ST::underlying_width); // must be a real ST, no recursive subtiles.
return st_subtile<ST, subtile_rows, subtile_cols>(src, rowcol);
}

} // namespace kittens
} // namespace kittens
6 changes: 3 additions & 3 deletions include/types/register/rt_base.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ template<typename _T, ducks::rt_layout::all _layout> struct rt_base {
using dtype = T2; ///< Data type of the matrix elements

static_assert(
std::is_same_v<dtype, bf16_2> || std::is_same_v<dtype, float2> || std::is_same_v<dtype, half_2>,
std::is_same_v<dtype, bf16_2> || std::is_same_v<dtype, float2> || std::is_same_v<dtype, half_2> || std::is_same_v<dtype, fp8e4m3_4>,
"rt_base was provided an unsupported type."
);

static constexpr int tile_size_row = kittens::TILE_ROW_DIM<T>;
static constexpr int tile_size_col = kittens::TILE_COL_DIM<T>;
static constexpr int tile_size_row = kittens::TILE_ROW_DIM<T, layout>;
static constexpr int tile_size_col = kittens::TILE_COL_DIM<T, layout>;
static constexpr int rows = tile_size_row; ///< Number of rows.
static constexpr int cols = tile_size_col; ///< Number of cols.
static constexpr int num_elements = rows*cols;
Expand Down
Loading