Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions include/common/base_types.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ concept T2 = std::is_same_v<T, float2> || std::is_same_v<T, bf16_2> || std::is_s
|| std::is_same_v<T, fp4e2m1_4>;
// Concept over the single-element ("1-packed") base types, expressed as a chain of
// std::is_same_v alternatives.
// NOTE(review): this listing is a rendered diff without +/- markers, so the last two
// lines below both terminate the expression. Presumably the `fp4e2m1` alternative is the
// pre-change line and the `fp4e2m1_2` alternative is its replacement -- confirm against
// the real header, which can only contain one terminating alternative.
template<typename T>
concept T1 = std::is_same_v<T, float> || std::is_same_v<T, bf16 > || std::is_same_v<T, half> || std::is_same_v<T, fp8e4m3>
|| std::is_same_v<T, fp4e2m1>;
|| std::is_same_v<T, fp4e2m1_2>;

} // namespace base_types
} // namespace ducks
Expand Down Expand Up @@ -262,9 +262,14 @@ template<> struct packing<fp4e2m1> {
using unpacked_type = fp4e2m1;
using packed_type = fp4e2m1_4;
};
/**
 * @brief Packing traits specialization for fp4e2m1_2.
 *
 * fp4e2m1_2 is its own unpacked type here (num() reports 1), and its packed
 * counterpart is fp4e2m1_4.
 */
template<> struct packing<fp4e2m1_2> {
    using unpacked_type = fp4e2m1_2;
    using packed_type   = fp4e2m1_4;

    /// Number of unpacked_type elements represented by this type.
    static __host__ __device__ inline constexpr int num() {
        return 1;
    }
};
// Packing traits specialization for fp4e2m1_4 (its packed_type is itself).
// NOTE(review): this listing is a rendered diff without +/- markers, so the body below
// shows num() and unpacked_type twice. Presumably the first pair (num() == 4,
// unpacked_type = fp4e2m1) is the pre-change text and the second pair (num() == 2,
// unpacked_type = fp4e2m1_2) is the post-change text -- confirm against the real header,
// which can only define each member once.
template<> struct packing<fp4e2m1_4> {
static __host__ __device__ inline constexpr int num() { return 4; }
using unpacked_type = fp4e2m1;
static __host__ __device__ inline constexpr int num() { return 2; }
using unpacked_type = fp4e2m1_2;
using packed_type = fp4e2m1_4;
};

Expand Down Expand Up @@ -414,5 +419,15 @@ template<> struct convertor<float4, fp4e2m1_4> {
return float4(u);
}
};
/**
 * @brief Conversion from float2 to fp4e2m1_2.
 */
template<> struct convertor<fp4e2m1_2, float2> {
    /**
     * @param u Source float2 value.
     * @return The fp4e2m1_2 obtained by explicitly converting u.
     */
    static __host__ __device__ inline fp4e2m1_2 convert(const float2& u) {
        return static_cast<fp4e2m1_2>(u);
    }
};
/**
 * @brief Conversion from fp4e2m1_2 to float2.
 */
template<> struct convertor<float2, fp4e2m1_2> {
    /**
     * @param u Source fp4e2m1_2 value.
     * @return The float2 obtained by explicitly converting u.
     */
    static __host__ __device__ inline float2 convert(const fp4e2m1_2& u) {
        return static_cast<float2>(u);
    }
};
}
}
7 changes: 4 additions & 3 deletions include/ops/warp/memory/tile/global_to_register.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ __device__ inline static void load(RT &dst, const GL &src, const COORD &idx) {
using U = typename GL::dtype;
using U2 = base_types::packing<U>::packed_type;

static_assert(!std::is_same_v<typename kittens::base_types::packing<typename RT::dtype>::unpacked_type, fp8e4m3>, "Unsupported type for load");
using unpacked = typename kittens::base_types::packing<typename RT::dtype>::unpacked_type;
static_assert(!std::is_same_v<unpacked, fp8e4m3> && !std::is_same_v<unpacked, fp4e2m1_2>, "Unsupported type for load");

U *src_ptr = (U*)&src[(idx.template unit_coord<axis, 3>())];
const int row_stride = src.template stride<axis>();
Expand Down Expand Up @@ -136,7 +137,7 @@ __device__ inline static void load(RT &dst, const GL &src, const COORD &idx) {
using T2 = base_types::packing<typename RT::dtype>::packed_type;
using U = typename GL::dtype;

static_assert(!std::is_same_v<T, fp8e4m3>, "Unsupported type for load/store");
static_assert(!std::is_same_v<T, fp8e4m3> && !std::is_same_v<T, fp4e2m1_2>, "Unsupported type for load/store");

constexpr int packing = base_types::packing<typename RT::dtype>::num();

Expand Down Expand Up @@ -301,7 +302,7 @@ __device__ inline static void store(const GL &dst, const RT &src, const COORD &i
using U = typename GL::dtype;
constexpr int packing = base_types::packing<typename RT::dtype>::num();

static_assert(!std::is_same_v<T, fp8e4m3>, "Unsupported type for load/store");
static_assert(!std::is_same_v<T, fp8e4m3> && !std::is_same_v<T, fp4e2m1_2>, "Unsupported type for load/store");

U *dst_ptr = (U*)&dst[(idx.template unit_coord<axis, 3>())];
const int row_stride = dst.template stride<axis>();
Expand Down
23 changes: 21 additions & 2 deletions include/ops/warp/memory/tile/shared_to_register.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ __device__ inline static void load(RT &dst, const ST &src) {
constexpr int packing = base_types::packing<typename RT::dtype>::num();

static_assert(std::is_same_v<T, U>, "register and shared tile must have the same dtype");
static_assert(!std::is_same_v<T, fp4e2m1>, "fp4e2m1 scalar is not supported as a tile dtype; use fp4e2m1_2");

const int laneid = kittens::laneid();

Expand Down Expand Up @@ -105,6 +106,17 @@ __device__ inline static void load(RT &dst, const ST &src) {
} else {
static_assert(false, "Unsupported stride");
}
} else if constexpr (std::is_same_v<U2, fp4e2m1_4>) {
if constexpr (RT::base_tile_stride == 16) {
asm volatile(
"ds_read_b128 %0, %1 offset:%2\n"
: "=v"(*reinterpret_cast<float4*>(&dst.tiles[register_row][register_col].data[idx]))
: "v"(addr), "i"(offset)
: "memory"
);
} else {
static_assert(false, "Unsupported stride");
}
} else {
static_assert(false, "Unsupported type");
}
Expand Down Expand Up @@ -181,6 +193,13 @@ __device__ inline static void load(RT &dst, const ST &src) {
: "v"(addr), "i"(offset)
: "memory"
);
} else if constexpr (std::is_same_v<U2, fp4e2m1_4> && RT::base_tile_stride == 16) {
asm volatile(
"ds_read_b128 %0, %1 offset:%2\n"
: "=v"(*reinterpret_cast<float4*>(&dst.tiles[i][j].data[idx]))
: "v"(addr), "i"(offset)
: "memory"
);
} else {
static_assert(false, "Unsupported type");
}
Expand Down Expand Up @@ -434,7 +453,7 @@ __device__ inline static void store(ST &dst, const RT &src) {
using U2 = base_types::packing<U >::packed_type;
constexpr int packing = base_types::packing<typename RT::dtype>::num();

static_assert(!std::is_same_v<T, fp8e4m3> && !std::is_same_v<U, fp8e4m3>, "Unsupported type for store");
static_assert(!std::is_same_v<T, fp8e4m3> && !std::is_same_v<U, fp8e4m3> && !std::is_same_v<T, fp4e2m1_2> && !std::is_same_v<U, fp4e2m1_2>, "Unsupported type for store");

const int laneid = kittens::laneid();

Expand Down Expand Up @@ -584,7 +603,7 @@ __device__ inline static void store(ST &dst, const RT &src) {
using U2 = base_types::packing<U >::packed_type;
constexpr int packing = base_types::packing<typename RT::dtype>::num();

static_assert(!std::is_same_v<T, fp8e4m3> && !std::is_same_v<U, fp8e4m3>, "Unsupported type for store");
static_assert(!std::is_same_v<T, fp8e4m3> && !std::is_same_v<U, fp8e4m3> && !std::is_same_v<T, fp4e2m1_2> && !std::is_same_v<U, fp4e2m1_2>, "Unsupported type for store");

const int laneid = kittens::laneid();

Expand Down
2 changes: 2 additions & 0 deletions include/types/global/gl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ template<typename _T, int b, int d, int r, int c, typename... TMA_Types>
struct gl {
using identifier = ducks::gl::identifier;

static_assert(!std::is_same_v<_T, fp4e2m1>, "For FP4 types, you must use a packed type (fp4e2m1_2 or fp4e2m1_4).");

using T = base_types::packing<_T>::unpacked_type;
using T2 = base_types::packing<_T>::packed_type;
using dtype = T;
Expand Down
1 change: 1 addition & 0 deletions include/types/register/rt.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -139,5 +139,6 @@ template<int _r, int _c, ducks::rt_layout::all layout=ducks::rt_layout::row, duc
/// rt alias with bf16 as the element type; defaults to row layout and the rt_16x16 shape.
template<int _r, int _c,
         ducks::rt_layout::all layout = ducks::rt_layout::row,
         ducks::rt_shape::all shape = ducks::rt_shape::rt_16x16>
using rt_bf = rt<bf16, _r, _c, layout, shape>;

/// rt alias with half as the element type; defaults to row layout and the rt_16x16 shape.
template<int _r, int _c,
         ducks::rt_layout::all layout = ducks::rt_layout::row,
         ducks::rt_shape::all shape = ducks::rt_shape::rt_16x16>
using rt_hf = rt<half, _r, _c, layout, shape>;

/// rt alias with fp8e4m3 as the element type; defaults to row layout and the rt_16x128 shape.
template<int _r, int _c,
         ducks::rt_layout::all layout = ducks::rt_layout::row,
         ducks::rt_shape::all shape = ducks::rt_shape::rt_16x128>
using rt_fp8e4m3 = rt<fp8e4m3, _r, _c, layout, shape>;

/// rt alias with fp4e2m1_2 as the element type; defaults to row layout and the rt_16x128 shape.
template<int _r, int _c,
         ducks::rt_layout::all layout = ducks::rt_layout::row,
         ducks::rt_shape::all shape = ducks::rt_shape::rt_16x128>
using rt_fp4e2m1_2 = rt<fp4e2m1_2, _r, _c, layout, shape>;

} // namespace kittens
2 changes: 1 addition & 1 deletion include/types/register/rt_base.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ template<typename _T, ducks::rt_layout::all _layout, ducks::rt_shape::all _shape
using dtype = T2; ///< Data type of the matrix elements

static_assert(
std::is_same_v<dtype, bf16_2> || std::is_same_v<dtype, float2> || std::is_same_v<dtype, half_2> || std::is_same_v<dtype, fp8e4m3_4>,
std::is_same_v<dtype, bf16_2> || std::is_same_v<dtype, float2> || std::is_same_v<dtype, half_2> || std::is_same_v<dtype, fp8e4m3_4> || std::is_same_v<dtype, fp4e2m1_4>,
"rt_base was provided an unsupported type."
);

Expand Down
3 changes: 2 additions & 1 deletion include/types/shared/st.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ struct st_subtile;
template<typename _T, int _rows, int _cols, ducks::st_shape::all _shape>
struct KITTENS_DEFAULT_ALIGN st {
using identifier = ducks::st::identifier; ///< Type identifier for shared memory tile.
static_assert(!std::is_same_v<_T, fp4e2m1>, "For FP4 types, you must use a packed type (fp4e2m1_2 or fp4e2m1_4).");
using T = base_types::packing<_T>::unpacked_type;
using T2 = base_types::packing<_T>::packed_type;
using dtype = T; ///< Data type of the elements in the tile.
Expand Down Expand Up @@ -76,7 +77,7 @@ struct KITTENS_DEFAULT_ALIGN st {
static constexpr int subtiles_per_row = cols / underlying_subtile_cols;
static constexpr int subtiles_per_col = rows / underlying_subtile_rows;

static_assert(base_types::packing<dtype>::num() == 1); // must be a 1-packed type (e.g. float, bf16, etc)
static_assert(base_types::packing<dtype>::num() == 1 || std::is_same_v<dtype, fp4e2m1_2>); // must be a 1-packed type (e.g. float, bf16, etc) -- fp4e2m1_2 is allowed as the canonical sub-byte tile dtype

dtype data[rows*cols]; ///< Raw data storage for the tile.

Expand Down
1 change: 1 addition & 0 deletions include/types/shared/sv.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,6 @@ template<size_t _length> using sv_bf = sv<bf16, _length>;
/// sv alias with half as the element type.
template<size_t _length>
using sv_hf = sv<half, _length>;

/// sv alias with float as the element type.
template<size_t _length>
using sv_fl = sv<float, _length>;

/// sv alias with fp8e4m3 as the element type.
template<size_t _length>
using sv_fp8e4m3 = sv<fp8e4m3, _length>;

/// sv alias with fp4e2m1_2 as the element type.
template<size_t _length>
using sv_fp4e2m1_2 = sv<fp4e2m1_2, _length>;

} // namespace kittens
1 change: 1 addition & 0 deletions tests/unit/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ HIPFLAGS+= -DTEST_INTENSITY=2
# You can also specify subsections, e.g. -DTEST_WARP_MEMORY
# Or individual tests, like -DTEST_WARP_MEMORY_VEC_DSMEM. Useful for debugging!
HIPFLAGS+= -DTEST_WARP_MEMORY_TILE_SHARED_TO_REGISTER
HIPFLAGS+= -DTEST_WARP_MEMORY_TILE_FP4_LOAD

ifeq ($(COMP_LEVEL),safe)
HIPFLAGS+= -O0
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/testing_commons/testing_flags.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#define TEST_WARP_MEMORY_TILE_GLOBAL_TO_REGISTER
#define TEST_WARP_MEMORY_TILE_GLOBAL_TO_SHARED
#define TEST_WARP_MEMORY_TILE_SHARED_TO_REGISTER
#define TEST_WARP_MEMORY_TILE_FP4_LOAD
#endif

#ifdef TEST_ALL_WARP_MEMORY_VEC
Expand Down Expand Up @@ -101,7 +102,7 @@
// Warp macros

#if defined(TEST_WARP_MEMORY_TILE_GLOBAL_TO_REGISTER) || defined(TEST_WARP_MEMORY_TILE_GLOBAL_TO_SHARED) || \
defined(TEST_WARP_MEMORY_TILE_SHARED_TO_REGISTER)
defined(TEST_WARP_MEMORY_TILE_SHARED_TO_REGISTER) || defined(TEST_WARP_MEMORY_TILE_FP4_LOAD)
#define TEST_WARP_MEMORY_TILE
#endif

Expand Down
88 changes: 66 additions & 22 deletions tests/unit/testing_commons/testing_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,30 @@ void initialize(T1 **d_i, T2 **d_o, std::vector<float> &i_ref, std::vector<float
const int input_size = i_ref.size();
const int output_size = o_ref.size();

if constexpr (std::is_same_v<T1, fp4e2m1_2>) {
// FP4: i_ref is sized in pair units (matches device element count); each pair packs
// two FP4 values. We sample one float per pair and pack both halves to the same
// value, so i_ref[idx] holds the single quantized value seen by both halves.
std::vector<T1> i_t(input_size);
std::mt19937 gen(SEED);
std::uniform_real_distribution<float> dis(-1.0, 1.0);
for (int idx = 0; idx < input_size; idx++) {
float f;
if constexpr (initializer == initializers::RANDOM) f = dis(gen);
else if constexpr (initializer == initializers::ARANGE) f = float(idx);
else f = i_ref[idx];
i_t[idx] = fp4e2m1_2(float2{f, f});
float2 dequant = float2(i_t[idx]);
i_ref[idx] = dequant.x; // both halves are identical
}
hipMalloc(d_i, input_size * sizeof(T1));
hipMalloc(d_o, output_size * sizeof(T2));
HipCheckError();
hipMemcpy(*d_i, i_t.data(), input_size * sizeof(T1), hipMemcpyHostToDevice);
HipCheckError();
return;
}

// Initialize matrices
std::vector<T1> i_t(input_size);

Expand Down Expand Up @@ -157,32 +181,52 @@ test_result validate(T *d_i, T *d_o, const std::vector<float> &i_ref, std::vecto
const int input_size = i_ref.size();
const int output_size = o_ref.size();
// copy back
T* o_t = new T[output_size];
float *o = new float[output_size];
hipDeviceSynchronize();
HipCheckError();
hipMemcpy(o_t, d_o, output_size * sizeof(T), hipMemcpyDeviceToHost);
HipCheckError();
for(int idx = 0; idx < output_size; idx++) {
if constexpr (std::is_same_v<T, bf16>) {
o[idx] = __bfloat162float(o_t[idx]);
o_ref[idx] = __bfloat162float(__float2bfloat16(o_ref[idx]));
}
else if constexpr (std::is_same_v<T, half>) {
o[idx] = __half2float(o_t[idx]);
o_ref[idx] = __half2float(__float2half(o_ref[idx]));
}
else if constexpr(std::is_same_v<T, float>) {
o[idx] = o_t[idx];
o_ref[idx] = o_ref[idx];
}
else if constexpr (std::is_same_v<T, fp8e4m3>) {
o[idx] = float(o_t[idx]);
o_ref[idx] = float(fp8e4m3(o_ref[idx]));

if constexpr (std::is_same_v<T, fp4e2m1_2>) {
// FP4: device buffer holds output_size pairs (one byte per pair). We compare the
// low half of each pair — matches initialize's (f,f) pack convention.
T* o_t = new T[output_size];
hipMemcpy(o_t, d_o, output_size * sizeof(T), hipMemcpyDeviceToHost);
HipCheckError();
for (int idx = 0; idx < output_size; idx++) {
float2 dequant = float2(o_t[idx]);
o[idx] = dequant.x;
float2 refquant = float2(fp4e2m1_2(float2{o_ref[idx], o_ref[idx]}));
o_ref[idx] = refquant.x;
}
else {
assert(false && "Unsupported data type");
delete[] o_t;
// FP4's representable grid has coarse spacing; use absolute tolerance sized to the grid.
atol = 0.5f;
rtol = 0.0f;
} else {
T* o_t = new T[output_size];
hipMemcpy(o_t, d_o, output_size * sizeof(T), hipMemcpyDeviceToHost);
HipCheckError();
for(int idx = 0; idx < output_size; idx++) {
if constexpr (std::is_same_v<T, bf16>) {
o[idx] = __bfloat162float(o_t[idx]);
o_ref[idx] = __bfloat162float(__float2bfloat16(o_ref[idx]));
}
else if constexpr (std::is_same_v<T, half>) {
o[idx] = __half2float(o_t[idx]);
o_ref[idx] = __half2float(__float2half(o_ref[idx]));
}
else if constexpr(std::is_same_v<T, float>) {
o[idx] = o_t[idx];
o_ref[idx] = o_ref[idx];
}
else if constexpr (std::is_same_v<T, fp8e4m3>) {
o[idx] = float(o_t[idx]);
o_ref[idx] = float(fp8e4m3(o_ref[idx]));
}
else {
assert(false && "Unsupported data type");
}
}
delete[] o_t;
}
// check
std::cout << "test `" << test_name << "` ";
Expand Down Expand Up @@ -249,7 +293,7 @@ test_result validate(T *d_i, T *d_o, const std::vector<float> &i_ref, std::vecto
}
hipFree(d_i);
hipFree(d_o);
delete[] o_t, o;
delete[] o;
HipCheckError();
return good ? test_result::PASSED : test_result::FAILED;
}
Loading