diff --git a/include/common/base_types.cuh b/include/common/base_types.cuh index d9a71dc32..b19ccd40c 100644 --- a/include/common/base_types.cuh +++ b/include/common/base_types.cuh @@ -75,7 +75,7 @@ concept T2 = std::is_same_v || std::is_same_v || std::is_s || std::is_same_v; template concept T1 = std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v - || std::is_same_v; + || std::is_same_v; } // namespace base_types } // namespace ducks @@ -262,9 +262,14 @@ template<> struct packing { using unpacked_type = fp4e2m1; using packed_type = fp4e2m1_4; }; +template<> struct packing { + static __host__ __device__ inline constexpr int num() { return 1; } + using unpacked_type = fp4e2m1_2; + using packed_type = fp4e2m1_4; +}; template<> struct packing { - static __host__ __device__ inline constexpr int num() { return 4; } - using unpacked_type = fp4e2m1; + static __host__ __device__ inline constexpr int num() { return 2; } + using unpacked_type = fp4e2m1_2; using packed_type = fp4e2m1_4; }; @@ -414,5 +419,15 @@ template<> struct convertor { return float4(u); } }; +template<> struct convertor { + static __host__ __device__ inline fp4e2m1_2 convert(const float2& u) { + return fp4e2m1_2(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline float2 convert(const fp4e2m1_2& u) { + return float2(u); + } +}; } } diff --git a/include/ops/warp/memory/tile/global_to_register.cuh b/include/ops/warp/memory/tile/global_to_register.cuh index dab4ecd86..616195d26 100644 --- a/include/ops/warp/memory/tile/global_to_register.cuh +++ b/include/ops/warp/memory/tile/global_to_register.cuh @@ -27,7 +27,8 @@ __device__ inline static void load(RT &dst, const GL &src, const COORD &idx) { using U = typename GL::dtype; using U2 = base_types::packing::packed_type; - static_assert(!std::is_same_v::unpacked_type, fp8e4m3>, "Unsupported type for load"); + using unpacked = typename kittens::base_types::packing::unpacked_type; + static_assert(!std::is_same_v && !std::is_same_v, "Unsupported type for load"); U *src_ptr = (U*)&src[(idx.template unit_coord())]; const int row_stride = src.template stride(); @@ -136,7 +137,7 @@ __device__ inline static void load(RT &dst, const GL &src, const COORD &idx) { using T2 = base_types::packing::packed_type; using U = typename GL::dtype; - static_assert(!std::is_same_v, "Unsupported type for load/store"); + static_assert(!std::is_same_v && !std::is_same_v, "Unsupported type for load/store"); constexpr int packing = base_types::packing::num(); @@ -301,7 +302,7 @@ __device__ inline static void store(const GL &dst, const RT &src, const COORD &i using U = typename GL::dtype; constexpr int packing = base_types::packing::num(); - static_assert(!std::is_same_v, "Unsupported type for load/store"); + static_assert(!std::is_same_v && !std::is_same_v, "Unsupported type for load/store"); U *dst_ptr = (U*)&dst[(idx.template unit_coord())]; const int row_stride = dst.template stride(); diff --git a/include/ops/warp/memory/tile/shared_to_register.cuh b/include/ops/warp/memory/tile/shared_to_register.cuh index 16c2a09de..27522fa65 100644 --- a/include/ops/warp/memory/tile/shared_to_register.cuh +++ b/include/ops/warp/memory/tile/shared_to_register.cuh @@ -38,6 +38,7 @@ __device__ inline static void load(RT &dst, const ST &src) { constexpr int packing = base_types::packing::num(); static_assert(std::is_same_v, "register and shared tile must have the same dtype"); + static_assert(!std::is_same_v, "fp4e2m1 scalar is not supported as a tile dtype; use fp4e2m1_2"); const int laneid = kittens::laneid(); @@ -105,6 +106,17 @@ __device__ inline static void load(RT &dst, const ST &src) { } else { static_assert(false, "Unsupported stride"); } + } else if constexpr (std::is_same_v) { + if constexpr (RT::base_tile_stride == 16) { + asm volatile( + "ds_read_b128 %0, %1 offset:%2\n" + : "=v"(*reinterpret_cast(&dst.tiles[register_row][register_col].data[idx])) + : "v"(addr), "i"(offset) + : "memory" + ); + } else { + static_assert(false, "Unsupported stride"); + } } else { static_assert(false, "Unsupported type"); } @@ -181,6 +193,13 @@ __device__ inline static void load(RT &dst, const ST &src) { : "v"(addr), "i"(offset) : "memory" ); + } else if constexpr (std::is_same_v && RT::base_tile_stride == 16) { + asm volatile( + "ds_read_b128 %0, %1 offset:%2\n" + : "=v"(*reinterpret_cast(&dst.tiles[i][j].data[idx])) + : "v"(addr), "i"(offset) + : "memory" + ); } else { static_assert(false, "Unsupported type"); } @@ -434,7 +453,7 @@ __device__ inline static void store(ST &dst, const RT &src) { using U2 = base_types::packing::packed_type; constexpr int packing = base_types::packing::num(); - static_assert(!std::is_same_v && !std::is_same_v, "Unsupported type for store"); + static_assert(!std::is_same_v && !std::is_same_v && !std::is_same_v && !std::is_same_v, "Unsupported type for store"); const int laneid = kittens::laneid(); @@ -584,7 +603,7 @@ __device__ inline static void store(ST &dst, const RT &src) { using U2 = base_types::packing::packed_type; constexpr int packing = base_types::packing::num(); - static_assert(!std::is_same_v && !std::is_same_v, "Unsupported type for store"); + static_assert(!std::is_same_v && !std::is_same_v && !std::is_same_v && !std::is_same_v, "Unsupported type for store"); const int laneid = kittens::laneid(); diff --git a/include/types/global/gl.cuh b/include/types/global/gl.cuh index 3776363d1..f983e5d34 100644 --- a/include/types/global/gl.cuh +++ b/include/types/global/gl.cuh @@ -34,6 +34,8 @@ template struct gl { using identifier = ducks::gl::identifier; + static_assert(!std::is_same_v<_T, fp4e2m1>, "For FP4 types, you must use a packed type (fp4e2m1_2 or fp4e2m1_4)."); + using T = base_types::packing<_T>::unpacked_type; using T2 = base_types::packing<_T>::packed_type; using dtype = T; diff --git a/include/types/register/rt.cuh b/include/types/register/rt.cuh index f6c7bfb23..c5fb962f7 100644 --- a/include/types/register/rt.cuh +++ b/include/types/register/rt.cuh @@ -139,5 +139,6 @@ template using rt_bf = rt; template using rt_hf = rt; template using rt_fp8e4m3 = rt; +template using rt_fp4e2m1_2 = rt; } // namespace kittens diff --git a/include/types/register/rt_base.cuh b/include/types/register/rt_base.cuh index 482ad22ff..148b51a3f 100644 --- a/include/types/register/rt_base.cuh +++ b/include/types/register/rt_base.cuh @@ -52,7 +52,7 @@ template || std::is_same_v || std::is_same_v || std::is_same_v, + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v, "rt_base was provided an unsupported type." ); diff --git a/include/types/shared/st.cuh b/include/types/shared/st.cuh index ba00a912a..577659116 100644 --- a/include/types/shared/st.cuh +++ b/include/types/shared/st.cuh @@ -49,6 +49,7 @@ struct st_subtile; template struct KITTENS_DEFAULT_ALIGN st { using identifier = ducks::st::identifier; ///< Type identifier for shared memory tile. + static_assert(!std::is_same_v<_T, fp4e2m1>, "For FP4 types, you must use a packed type (fp4e2m1_2 or fp4e2m1_4)."); using T = base_types::packing<_T>::unpacked_type; using T2 = base_types::packing<_T>::packed_type; using dtype = T; ///< Data type of the elements in the tile. @@ -76,7 +77,7 @@ struct KITTENS_DEFAULT_ALIGN st { static constexpr int subtiles_per_row = cols / underlying_subtile_cols; static constexpr int subtiles_per_col = rows / underlying_subtile_rows; - static_assert(base_types::packing::num() == 1); // must be a 1-packed type (e.g. float, bf16, etc) + static_assert(base_types::packing::num() == 1 || std::is_same_v); // must be a 1-packed type (e.g. float, bf16, etc) -- fp4e2m1_2 is allowed as the canonical sub-byte tile dtype dtype data[rows*cols]; ///< Raw data storage for the tile. diff --git a/include/types/shared/sv.cuh b/include/types/shared/sv.cuh index 439c5b9e9..cf5376c59 100644 --- a/include/types/shared/sv.cuh +++ b/include/types/shared/sv.cuh @@ -91,5 +91,6 @@ template using sv_bf = sv; template using sv_hf = sv; template using sv_fl = sv; template using sv_fp8e4m3 = sv; +template using sv_fp4e2m1_2 = sv; } // namespace kittens \ No newline at end of file diff --git a/tests/unit/Makefile b/tests/unit/Makefile index 0868d1446..1ed4a10fa 100644 --- a/tests/unit/Makefile +++ b/tests/unit/Makefile @@ -26,6 +26,7 @@ HIPFLAGS+= -DTEST_INTENSITY=2 # You can also specify subsections, e.g. -DTEST_WARP_MEMORY # Or individual tests, like -DTEST_WARP_MEMORY_VEC_DSMEM. Useful for debugging! HIPFLAGS+= -DTEST_WARP_MEMORY_TILE_SHARED_TO_REGISTER +HIPFLAGS+= -DTEST_WARP_MEMORY_TILE_FP4_LOAD ifeq ($(COMP_LEVEL),safe) HIPFLAGS+= -O0 diff --git a/tests/unit/testing_commons/testing_flags.cuh b/tests/unit/testing_commons/testing_flags.cuh index a08dea4fa..9a64b37d6 100644 --- a/tests/unit/testing_commons/testing_flags.cuh +++ b/tests/unit/testing_commons/testing_flags.cuh @@ -56,6 +56,7 @@ #define TEST_WARP_MEMORY_TILE_GLOBAL_TO_REGISTER #define TEST_WARP_MEMORY_TILE_GLOBAL_TO_SHARED #define TEST_WARP_MEMORY_TILE_SHARED_TO_REGISTER +#define TEST_WARP_MEMORY_TILE_FP4_LOAD #endif #ifdef TEST_ALL_WARP_MEMORY_VEC @@ -101,7 +102,7 @@ // Warp macros #if defined(TEST_WARP_MEMORY_TILE_GLOBAL_TO_REGISTER) || defined(TEST_WARP_MEMORY_TILE_GLOBAL_TO_SHARED) || \ - defined(TEST_WARP_MEMORY_TILE_SHARED_TO_REGISTER) + defined(TEST_WARP_MEMORY_TILE_SHARED_TO_REGISTER) || defined(TEST_WARP_MEMORY_TILE_FP4_LOAD) #define TEST_WARP_MEMORY_TILE #endif diff --git a/tests/unit/testing_commons/testing_utils.cuh b/tests/unit/testing_commons/testing_utils.cuh index c3a6f413c..6fe184b71 100644 --- a/tests/unit/testing_commons/testing_utils.cuh +++ b/tests/unit/testing_commons/testing_utils.cuh @@ -101,6 +101,30 @@ void initialize(T1 **d_i, T2 **d_o, std::vector &i_ref, std::vector) { + // FP4: i_ref is sized in pair units (matches device element count); each pair packs + // two FP4 values. We sample one float per pair and pack both halves to the same + // value, so i_ref[idx] holds the single quantized value seen by both halves. + std::vector i_t(input_size); + std::mt19937 gen(SEED); + std::uniform_real_distribution dis(-1.0, 1.0); + for (int idx = 0; idx < input_size; idx++) { + float f; + if constexpr (initializer == initializers::RANDOM) f = dis(gen); + else if constexpr (initializer == initializers::ARANGE) f = float(idx); + else f = i_ref[idx]; + i_t[idx] = fp4e2m1_2(float2{f, f}); + float2 dequant = float2(i_t[idx]); + i_ref[idx] = dequant.x; // both halves are identical + } + hipMalloc(d_i, input_size * sizeof(T1)); + hipMalloc(d_o, output_size * sizeof(T2)); + HipCheckError(); + hipMemcpy(*d_i, i_t.data(), input_size * sizeof(T1), hipMemcpyHostToDevice); + HipCheckError(); + return; + } + // Initialize matrices std::vector i_t(input_size); @@ -157,32 +181,52 @@ test_result validate(T *d_i, T *d_o, const std::vector &i_ref, std::vecto const int input_size = i_ref.size(); const int output_size = o_ref.size(); // copy back - T* o_t = new T[output_size]; float *o = new float[output_size]; hipDeviceSynchronize(); HipCheckError(); - hipMemcpy(o_t, d_o, output_size * sizeof(T), hipMemcpyDeviceToHost); - HipCheckError(); - for(int idx = 0; idx < output_size; idx++) { - if constexpr (std::is_same_v) { - o[idx] = __bfloat162float(o_t[idx]); - o_ref[idx] = __bfloat162float(__float2bfloat16(o_ref[idx])); - } - else if constexpr (std::is_same_v) { - o[idx] = __half2float(o_t[idx]); - o_ref[idx] = __half2float(__float2half(o_ref[idx])); - } - else if constexpr(std::is_same_v) { - o[idx] = o_t[idx]; - o_ref[idx] = o_ref[idx]; - } - else if constexpr (std::is_same_v) { - o[idx] = float(o_t[idx]); - o_ref[idx] = float(fp8e4m3(o_ref[idx])); + + if constexpr (std::is_same_v) { + // FP4: device buffer holds output_size pairs (one byte per pair). We compare the + // low half of each pair — matches initialize's (f,f) pack convention. + T* o_t = new T[output_size]; + hipMemcpy(o_t, d_o, output_size * sizeof(T), hipMemcpyDeviceToHost); + HipCheckError(); + for (int idx = 0; idx < output_size; idx++) { + float2 dequant = float2(o_t[idx]); + o[idx] = dequant.x; + float2 refquant = float2(fp4e2m1_2(float2{o_ref[idx], o_ref[idx]})); + o_ref[idx] = refquant.x; } - else { - assert(false && "Unsupported data type"); + delete[] o_t; + // FP4's representable grid has coarse spacing; use absolute tolerance sized to the grid. + atol = 0.5f; + rtol = 0.0f; + } else { + T* o_t = new T[output_size]; + hipMemcpy(o_t, d_o, output_size * sizeof(T), hipMemcpyDeviceToHost); + HipCheckError(); + for(int idx = 0; idx < output_size; idx++) { + if constexpr (std::is_same_v) { + o[idx] = __bfloat162float(o_t[idx]); + o_ref[idx] = __bfloat162float(__float2bfloat16(o_ref[idx])); + } + else if constexpr (std::is_same_v) { + o[idx] = __half2float(o_t[idx]); + o_ref[idx] = __half2float(__float2half(o_ref[idx])); + } + else if constexpr(std::is_same_v) { + o[idx] = o_t[idx]; + o_ref[idx] = o_ref[idx]; + } + else if constexpr (std::is_same_v) { + o[idx] = float(o_t[idx]); + o_ref[idx] = float(fp8e4m3(o_ref[idx])); + } + else { + assert(false && "Unsupported data type"); + } } + delete[] o_t; } // check std::cout << "test `" << test_name << "` "; @@ -249,7 +293,7 @@ test_result validate(T *d_i, T *d_o, const std::vector &i_ref, std::vecto } hipFree(d_i); hipFree(d_o); - delete[] o_t, o; + delete[] o; HipCheckError(); return good ? test_result::PASSED : test_result::FAILED; } \ No newline at end of file diff --git a/tests/unit/warp/memory/tile/fp4_load.cu b/tests/unit/warp/memory/tile/fp4_load.cu new file mode 100644 index 000000000..4d9c88757 --- /dev/null +++ b/tests/unit/warp/memory/tile/fp4_load.cu @@ -0,0 +1,127 @@ +#include "fp4_load.cuh" + +#ifdef TEST_WARP_MEMORY_TILE_FP4_LOAD + +#include +#include + +// Exercises the FP4 load paths (global -> shared -> register) that PR 1 adds. +// The existing sharedreg_load_store round-trip can't be used for FP4 because the +// register -> shared store path isn't in #47's scope. This test runs a hand-rolled +// kernel that loads into a register tile and then dequantizes each thread's packed +// elements into a flat float buffer. Host validation checks that every dequantized +// FP4 value appears in the expected multiset. + +using GL_fp4 = kittens::gl; +// Output buffer: 32 threads x 32 fp4e2m1_4 per thread x 4 FP4 per packed element = 4096 floats. +using GL_fl = kittens::gl; + +__global__ static void fp4_load_kernel(GL_fp4 input, GL_fl output) { + extern __shared__ kittens::alignment_dummy __shm[]; + kittens::shared_allocator<16> al((int*)&__shm[0]); + + using ST = kittens::st; + ST &shared_tile = al.allocate(); + + kittens::load<2, false, ST, GL_fp4, kittens::coord>(shared_tile, input, {0, 0, 0, 0}); + __builtin_amdgcn_s_waitcnt(0); + __builtin_amdgcn_s_barrier(); + + using RT = kittens::rt; + RT reg_tile; + kittens::load(reg_tile, shared_tile); + __builtin_amdgcn_s_waitcnt(0); + __builtin_amdgcn_s_barrier(); + + constexpr int floats_per_thread = RT::packed_per_thread * 4; + const int tid = threadIdx.x; + + #pragma unroll + for (int i = 0; i < RT::packed_per_thread; i++) { + kittens::fp4e2m1_4 packed = reg_tile.tiles[0][0].data[i]; + float4 vals = float4(packed); + const int base = tid * floats_per_thread + i * 4; + output.raw_ptr[base + 0] = vals.x; + output.raw_ptr[base + 1] = vals.y; + output.raw_ptr[base + 2] = vals.z; + output.raw_ptr[base + 3] = vals.w; + } +} + +void warp::memory::tile::fp4_load::tests(test_data &results) { + std::cout << "\n ----- Starting ops/warp/memory/tile/fp4_load tests! -----\n" << std::endl; + + constexpr int tile_pairs = 16 * 128; // device tile size in fp4e2m1_2 units + constexpr int tile_fp4 = tile_pairs * 2; // logical FP4 value count = 4096 + + std::vector i_ref(tile_pairs); + std::vector o_ref(tile_fp4); + + kittens::fp4e2m1_2 *d_i; + float *d_o; + initialize(&d_i, &d_o, i_ref, o_ref); + + GL_fp4 input_gl(d_i, nullptr, nullptr, nullptr, nullptr); + GL_fl output_gl(d_o, nullptr, nullptr, nullptr, nullptr); + + hipFuncSetAttribute( + reinterpret_cast(fp4_load_kernel), + hipFuncAttributeMaxDynamicSharedMemorySize, + kittens::MAX_SHARED_MEMORY / 2 + ); + fp4_load_kernel<<<1, kittens::WARP_THREADS, kittens::MAX_SHARED_MEMORY / 2>>>(input_gl, output_gl); + HipCheckError(); + + // Expected: each fp4e2m1_2 pair packs (f, f), so both halves dequantize to the same value. + // The kernel dumps 4096 floats total — each input pair contributes 2 identical values somewhere + // in the output. Host builds the expected multiset as 2x every i_ref entry, sorts both sides, + // and compares. + std::vector expected(tile_fp4); + for (int idx = 0; idx < tile_pairs; idx++) { + expected[2 * idx] = i_ref[idx]; + expected[2 * idx + 1] = i_ref[idx]; + } + + float *o_host = new float[tile_fp4]; + hipDeviceSynchronize(); + HipCheckError(); + hipMemcpy(o_host, d_o, tile_fp4 * sizeof(float), hipMemcpyDeviceToHost); + HipCheckError(); + + std::vector actual(o_host, o_host + tile_fp4); + std::sort(expected.begin(), expected.end()); + std::sort(actual.begin(), actual.end()); + + bool good = true; + float max_diff = 0; + int first_mismatch = -1; + for (int i = 0; i < tile_fp4; i++) { + float diff = std::abs(expected[i] - actual[i]); + if (diff > 0.5f) { + good = false; + if (first_mismatch < 0) first_mismatch = i; + if (diff > max_diff) max_diff = diff; + } + } + + std::cout << "test `fp4_load=fp4e2m1_2` "; + if (good) std::cout << " -- PASSED" << std::endl; + else { + std::cout << " ----- ALERT! FAILED (first mismatch at sorted idx " << first_mismatch + << ", max diff " << max_diff << ") -----" << std::endl; + } + + hipFree(d_i); + hipFree(d_o); + delete[] o_host; + HipCheckError(); + + test_info info; + info.label = "fp4_load=fp4e2m1_2"; + info.result = good ? test_result::PASSED : test_result::FAILED; + results.push_back(info); +} + +#endif diff --git a/tests/unit/warp/memory/tile/fp4_load.cuh b/tests/unit/warp/memory/tile/fp4_load.cuh new file mode 100644 index 000000000..3abbde3fa --- /dev/null +++ b/tests/unit/warp/memory/tile/fp4_load.cuh @@ -0,0 +1,19 @@ +#include "testing_flags.cuh" + +#ifdef TEST_WARP_MEMORY_TILE_FP4_LOAD + +#include "testing_commons.cuh" + +namespace warp { +namespace memory { +namespace tile { +namespace fp4_load { + +void tests(test_data &results); + +} +} +} +} + +#endif diff --git a/tests/unit/warp/memory/tile/global_to_shared.cu b/tests/unit/warp/memory/tile/global_to_shared.cu index b3868abb0..63688466b 100644 --- a/tests/unit/warp/memory/tile/global_to_shared.cu +++ b/tests/unit/warp/memory/tile/global_to_shared.cu @@ -13,6 +13,7 @@ struct st_load_store { static inline const std::string test_identifier = std::is_same_v ? "shared_loadstore_gmem=bf16" : std::is_same_v ? "shared_loadstore_gmem=half" : std::is_same_v ? "shared_loadstore_gmem=fp8e4m3" : + std::is_same_v ? "shared_loadstore_gmem=fp4e2m1_2" : "shared_loadstore_gmem=float"; template __host__ static void host_func(const std::vector &i_ref, std::vector &o_ref) { o_ref = i_ref; // overwrite the whole thing @@ -61,6 +62,10 @@ void test_generator(test_data &results) { g2s_sweep_size_2d_warp, RT_SHAPE, ST_SHAPE, SIZE, SIZE, I0_t>::run(results); g2s_sweep_size_2d_warp, RT_SHAPE, ST_SHAPE, SIZE, SIZE, I1_t>::run(results); g2s_sweep_size_2d_warp, RT_SHAPE, ST_SHAPE, SIZE, SIZE, I2_t>::run(results); + + g2s_sweep_size_2d_warp, RT_SHAPE, ST_SHAPE, SIZE, SIZE, I0_t>::run(results); + g2s_sweep_size_2d_warp, RT_SHAPE, ST_SHAPE, SIZE, SIZE, I1_t>::run(results); + g2s_sweep_size_2d_warp, RT_SHAPE, ST_SHAPE, SIZE, SIZE, I2_t>::run(results); } diff --git a/tests/unit/warp/memory/tile/tile.cu b/tests/unit/warp/memory/tile/tile.cu index 1691e4d0f..ecf9ec3f1 100644 --- a/tests/unit/warp/memory/tile/tile.cu +++ b/tests/unit/warp/memory/tile/tile.cu @@ -13,6 +13,9 @@ void warp::memory::tile::tests(test_data &results) { #ifdef TEST_WARP_MEMORY_TILE_SHARED_TO_REGISTER warp::memory::tile::shared_to_register::tests(results); #endif +#ifdef TEST_WARP_MEMORY_TILE_FP4_LOAD + warp::memory::tile::fp4_load::tests(results); +#endif } #endif \ No newline at end of file diff --git a/tests/unit/warp/memory/tile/tile.cuh b/tests/unit/warp/memory/tile/tile.cuh index 550a7c0a7..61fd820c4 100644 --- a/tests/unit/warp/memory/tile/tile.cuh +++ b/tests/unit/warp/memory/tile/tile.cuh @@ -7,6 +7,7 @@ #include "global_to_register.cuh" #include "global_to_shared.cuh" #include "shared_to_register.cuh" +#include "fp4_load.cuh" namespace warp { namespace memory {