Skip to content

Commit a9c7d75

Browse files
committed
Add vectorized conversion half <-> fp8
1 parent fd4eadf commit a9c7d75

File tree

2 files changed

+68
-5
lines changed

2 files changed

+68
-5
lines changed

include/kernel_float/fp8.h

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
2828
static constexpr bool value = true;
2929
};
3030
} // namespace detail
31+
} // namespace kernel_float
3132

3233
#define KERNEL_FLOAT_FP8_CAST(T) \
3334
namespace ops { \
@@ -60,6 +61,29 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
6061
}; \
6162
}
6263

64+
// Specializes detail::apply_impl for two-element casts between T and FP8_TY,
// one specialization per direction. Each call() copies the two input elements
// into packed raw storage with memcpy, converts both lanes with a single
// __nv_cvt_* intrinsic, and copies the packed result into `result`.
//   T          - non-FP8 element type (this file passes __half and __nv_bfloat16)
//   FP8_TY     - FP8 element type (this file passes __nv_fp8_e4m3 and __nv_fp8_e5m2)
//   FP8_INTERP - interpretation forwarded to the intrinsics (__NV_E4M3 / __NV_E5M2)
// NOTE(review): the packed 16-bit side is always __half2_raw and both intrinsics
// are the halfraw2 variants, so the result looks numerically meaningful only
// when T has the bit layout of __half -- confirm before using it with other T.
#define KERNEL_FLOAT_FP8_CAST2(T, FP8_TY, FP8_INTERP) \
    namespace detail { \
    template<> \
    struct apply_impl<ops::cast<T, FP8_TY>, 2, FP8_TY, T> { \
        KERNEL_FLOAT_INLINE static void \
        call(ops::cast<T, FP8_TY>, FP8_TY* result, const T* v) { \
            __half2_raw half_pair; \
            memcpy(&half_pair, v, 2 * sizeof(T)); \
            __nv_fp8x2_storage_t fp8_pair = \
                __nv_cvt_halfraw2_to_fp8x2(half_pair, __NV_NOSAT, FP8_INTERP); \
            memcpy(result, &fp8_pair, 2 * sizeof(FP8_TY)); \
        } \
    }; \
    template<> \
    struct apply_impl<ops::cast<FP8_TY, T>, 2, T, FP8_TY> { \
        KERNEL_FLOAT_INLINE static void \
        call(ops::cast<FP8_TY, T>, T* result, const FP8_TY* v) { \
            __nv_fp8x2_storage_t fp8_pair; \
            memcpy(&fp8_pair, v, 2 * sizeof(FP8_TY)); \
            __half2_raw half_pair = __nv_cvt_fp8x2_to_halfraw2(fp8_pair, FP8_INTERP); \
            memcpy(result, &half_pair, 2 * sizeof(T)); \
        } \
    }; \
    }
85+
86+
// kernel_float is closed before the macro definitions above and re-opened here,
// so the macro expansion lands inside the namespace.
// Expands KERNEL_FLOAT_FP8_CAST for double. Only part of that macro's body is
// visible in this diff; presumably it provides the casts between double and the
// FP8 types -- confirm against the full header.
namespace kernel_float {
6387
KERNEL_FLOAT_FP8_CAST(double)
6488
} // namespace kernel_float
6589

@@ -69,7 +93,11 @@ KERNEL_FLOAT_FP8_CAST(double)
6993
// FP8 support for __half (the region is closed by the
// `#endif // KERNEL_FLOAT_FP16_AVAILABLE` that follows this namespace).
namespace kernel_float {
7094
// KERNEL_FLOAT_DEFINE_PROMOTED_TYPE is defined elsewhere; presumably it makes
// __half the promoted type when combined with each FP8 format -- confirm.
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e4m3)
7195
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e5m2)
96+
7297
KERNEL_FLOAT_FP8_CAST(__half)
98+
// Two-element apply_impl specializations for __half <-> FP8, one pair per FP8
// format; KERNEL_FLOAT_FP8_CAST2 converts both lanes with the halfraw2 intrinsics.
KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e4m3, __NV_E4M3)
99+
KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e5m2, __NV_E5M2)
100+
73101
} // namespace kernel_float
74102
#endif // KERNEL_FLOAT_FP16_AVAILABLE
75103

@@ -79,7 +107,10 @@ KERNEL_FLOAT_FP8_CAST(__half)
79107
namespace kernel_float {
80108
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e4m3)
81109
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e5m2)
110+
82111
KERNEL_FLOAT_FP8_CAST(__nv_bfloat16)
112+
KERNEL_FLOAT_FP8_CAST2(__nv_bfloat16, __nv_fp8_e4m3, __NV_E4M3)
113+
KERNEL_FLOAT_FP8_CAST2(__nv_bfloat16, __nv_fp8_e5m2, __NV_E5M2)
83114
} // namespace kernel_float
84115
#endif // KERNEL_FLOAT_BF16_AVAILABLE
85116

single_include/kernel_float.h

Lines changed: 37 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,8 @@
1616

1717
//================================================================================
1818
// this file has been auto-generated, do not modify its contents!
19-
// date: 2024-03-18 16:06:55.100306
20-
// git hash: 06e08f55399e148d96070afd0ac36dd414045f04
19+
// date: 2024-04-22 13:28:09.684538
20+
// git hash: fd4eadfbb0c8597276a6c12f972038cd1baff985
2121
//================================================================================
2222

2323
#ifndef KERNEL_FLOAT_MACROS_H
@@ -2705,7 +2705,7 @@ struct vector_ref<T, N, const U, Align> {
27052705

27062706
#define KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(OP, OP_ASSIGN) \
27072707
template<typename T, size_t N, typename U, size_t Align, typename V> \
2708-
KERNEL_FLOAT_INLINE vector_ref<T, N> operator OP_ASSIGN( \
2708+
KERNEL_FLOAT_INLINE vector_ref<T, N, U, Align> operator OP_ASSIGN( \
27092709
vector_ref<T, N, U, Align> ptr, \
27102710
const V& value) { \
27112711
ptr.write(ptr.read() OP value); \
@@ -3379,6 +3379,7 @@ namespace kernel_float {
33793379
*/
33803380
template<typename T, typename E, class S>
33813381
struct vector: public S {
3382+
using self_type = vector<T, E, S>;
33823383
using value_type = T;
33833384
using extent_type = E;
33843385
using storage_type = S;
@@ -3577,8 +3578,8 @@ struct vector: public S {
35773578
* vec<float, 4> vec2 = select(input, indices); // [0, 40, 40, 20]
35783579
* ```
35793580
*/
3580-
template<typename V, typename... Is>
3581-
KERNEL_FLOAT_INLINE select_type<V, Is...> select(const Is&... indices) {
3581+
template<typename... Is>
3582+
KERNEL_FLOAT_INLINE select_type<self_type, Is...> select(const Is&... indices) {
35823583
return kernel_float::select(*this, indices...);
35833584
}
35843585

@@ -4255,6 +4256,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
42554256
static constexpr bool value = true;
42564257
};
42574258
} // namespace detail
4259+
} // namespace kernel_float
42584260

42594261
#define KERNEL_FLOAT_FP8_CAST(T) \
42604262
namespace ops { \
@@ -4287,6 +4289,29 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
42874289
}; \
42884290
}
42894291

4292+
// Specializes detail::apply_impl for two-element casts between T and FP8_TY,
// one specialization per direction. Each call() copies the two input elements
// into packed raw storage with memcpy, converts both lanes with a single
// __nv_cvt_* intrinsic, and copies the packed result into `result`.
//   T          - non-FP8 element type (this file passes __half and __nv_bfloat16)
//   FP8_TY     - FP8 element type (this file passes __nv_fp8_e4m3 and __nv_fp8_e5m2)
//   FP8_INTERP - interpretation forwarded to the intrinsics (__NV_E4M3 / __NV_E5M2)
// NOTE(review): the packed 16-bit side is always __half2_raw and both intrinsics
// are the halfraw2 variants, so the result looks numerically meaningful only
// when T has the bit layout of __half -- confirm before using it with other T.
#define KERNEL_FLOAT_FP8_CAST2(T, FP8_TY, FP8_INTERP) \
    namespace detail { \
    template<> \
    struct apply_impl<ops::cast<T, FP8_TY>, 2, FP8_TY, T> { \
        KERNEL_FLOAT_INLINE static void \
        call(ops::cast<T, FP8_TY>, FP8_TY* result, const T* v) { \
            __half2_raw half_pair; \
            memcpy(&half_pair, v, 2 * sizeof(T)); \
            __nv_fp8x2_storage_t fp8_pair = \
                __nv_cvt_halfraw2_to_fp8x2(half_pair, __NV_NOSAT, FP8_INTERP); \
            memcpy(result, &fp8_pair, 2 * sizeof(FP8_TY)); \
        } \
    }; \
    template<> \
    struct apply_impl<ops::cast<FP8_TY, T>, 2, T, FP8_TY> { \
        KERNEL_FLOAT_INLINE static void \
        call(ops::cast<FP8_TY, T>, T* result, const FP8_TY* v) { \
            __nv_fp8x2_storage_t fp8_pair; \
            memcpy(&fp8_pair, v, 2 * sizeof(FP8_TY)); \
            __half2_raw half_pair = __nv_cvt_fp8x2_to_halfraw2(fp8_pair, FP8_INTERP); \
            memcpy(result, &half_pair, 2 * sizeof(T)); \
        } \
    }; \
    }
4313+
4314+
// kernel_float is closed before the macro definitions above and re-opened here,
// so the macro expansion lands inside the namespace.
// Expands KERNEL_FLOAT_FP8_CAST for double. Only part of that macro's body is
// visible in this diff; presumably it provides the casts between double and the
// FP8 types -- confirm against the full header.
namespace kernel_float {
42904315
KERNEL_FLOAT_FP8_CAST(double)
42914316
} // namespace kernel_float
42924317

@@ -4297,6 +4322,10 @@ namespace kernel_float {
42974322
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e4m3)
42984323
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e5m2)
42994324
// Cast support between __half and the FP8 types. KERNEL_FLOAT_FP8_CAST is only
// partly visible in this diff; the KERNEL_FLOAT_FP8_CAST2 lines add two-element
// apply_impl specializations that convert both lanes with the halfraw2 intrinsics.
KERNEL_FLOAT_FP8_CAST(__half)
4325+
4326+
KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e4m3, __NV_E4M3)
4327+
KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e5m2, __NV_E5M2)
4328+
43004329
} // namespace kernel_float
43014330
#endif // KERNEL_FLOAT_FP16_AVAILABLE
43024331

@@ -4307,6 +4336,9 @@ namespace kernel_float {
43074336
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e4m3)
43084337
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e5m2)
43094338
KERNEL_FLOAT_FP8_CAST(__nv_bfloat16)
4339+
4340+
KERNEL_FLOAT_FP8_CAST2(__nv_bfloat16, __nv_fp8_e4m3, __NV_E4M3)
4341+
KERNEL_FLOAT_FP8_CAST2(__nv_bfloat16, __nv_fp8_e5m2, __NV_E5M2)
43104342
} // namespace kernel_float
43114343
#endif // KERNEL_FLOAT_BF16_AVAILABLE
43124344

0 commit comments

Comments
 (0)