16
16
17
17
// ================================================================================
18
18
// this file has been auto-generated, do not modify its contents!
19
- // date: 2024-03-18 16:06:55.100306
20
- // git hash: 06e08f55399e148d96070afd0ac36dd414045f04
19
+ // date: 2024-04-22 13:28:09.684538
20
+ // git hash: fd4eadfbb0c8597276a6c12f972038cd1baff985
21
21
// ================================================================================
22
22
23
23
#ifndef KERNEL_FLOAT_MACROS_H
@@ -2705,7 +2705,7 @@ struct vector_ref<T, N, const U, Align> {
2705
2705
2706
2706
#define KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP (OP, OP_ASSIGN ) \
2707
2707
template <typename T, size_t N, typename U, size_t Align, typename V> \
2708
- KERNEL_FLOAT_INLINE vector_ref<T, N> operator OP_ASSIGN ( \
2708
+ KERNEL_FLOAT_INLINE vector_ref<T, N, U, Align > operator OP_ASSIGN ( \
2709
2709
vector_ref<T, N, U, Align> ptr, \
2710
2710
const V& value) { \
2711
2711
ptr.write (ptr.read () OP value); \
@@ -3379,6 +3379,7 @@ namespace kernel_float {
3379
3379
*/
3380
3380
template <typename T, typename E, class S >
3381
3381
struct vector : public S {
3382
+ using self_type = vector<T, E, S>;
3382
3383
using value_type = T;
3383
3384
using extent_type = E;
3384
3385
using storage_type = S;
@@ -3577,8 +3578,8 @@ struct vector: public S {
3577
3578
* vec<float, 4> vec2 = select(input, indices); // [0, 40, 40, 20]
3578
3579
* ```
3579
3580
*/
3580
- template <typename V, typename ... Is>
3581
- KERNEL_FLOAT_INLINE select_type<V , Is...> select (const Is&... indices) {
3581
+ template <typename ... Is>
3582
+ KERNEL_FLOAT_INLINE select_type<self_type , Is...> select (const Is&... indices) {
3582
3583
return kernel_float::select (*this , indices...);
3583
3584
}
3584
3585
@@ -4255,6 +4256,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
4255
4256
static constexpr bool value = true ;
4256
4257
};
4257
4258
} // namespace detail
4259
+ } // namespace kernel_float
4258
4260
4259
4261
#define KERNEL_FLOAT_FP8_CAST (T ) \
4260
4262
namespace ops { \
@@ -4287,6 +4289,29 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
4287
4289
}; \
4288
4290
}
4289
4291
4292
/**
 * Defines two explicit specializations of `detail::apply_impl` that cast exactly two
 * elements at once between the 16-bit type `T` and the FP8 type `FP8_TY`:
 *
 *  - `T -> FP8_TY`: the two inputs are copied into a `__half2_raw`, converted with
 *    `__nv_cvt_halfraw2_to_fp8x2` using `__NV_NOSAT`, and the packed result is copied to
 *    `result`.
 *  - `FP8_TY -> T`: the two inputs are copied into a `__nv_fp8x2_storage_t`, converted with
 *    `__nv_cvt_fp8x2_to_halfraw2`, and the packed result is copied to `result`.
 *
 * `FP8_INTERP` is forwarded unchanged to the conversion intrinsics as the FP8
 * interpretation argument (the visible invocations pass `__NV_E4M3` or `__NV_E5M2`).
 *
 * The macro opens `namespace detail` and refers to `ops::cast`, so it has to be expanded
 * in a scope where both resolve; the invocations in this file expand it inside
 * `namespace kernel_float`.
 *
 * NOTE(review): the `T` side is always staged through `__half2_raw` and the half-precision
 * intrinsics, so this presumably only yields correct values when `T` has the bit layout of
 * `__half` -- confirm before instantiating it for any other 16-bit type.
 */
#define KERNEL_FLOAT_FP8_CAST2(T, FP8_TY, FP8_INTERP)                                  \
    namespace detail {                                                                 \
    template<>                                                                         \
    struct apply_impl<ops::cast<T, FP8_TY>, 2, FP8_TY, T> {                            \
        KERNEL_FLOAT_INLINE static void                                                \
        call(ops::cast<T, FP8_TY>, FP8_TY* result, const T* v) {                       \
            __half2_raw half_pair;                                                     \
            memcpy(&half_pair, v, 2 * sizeof(T));                                      \
            const __nv_fp8x2_storage_t fp8_pair =                                      \
                __nv_cvt_halfraw2_to_fp8x2(half_pair, __NV_NOSAT, FP8_INTERP);         \
            memcpy(result, &fp8_pair, 2 * sizeof(FP8_TY));                             \
        }                                                                              \
    };                                                                                 \
    template<>                                                                         \
    struct apply_impl<ops::cast<FP8_TY, T>, 2, T, FP8_TY> {                            \
        KERNEL_FLOAT_INLINE static void                                                \
        call(ops::cast<FP8_TY, T>, T* result, const FP8_TY* v) {                       \
            __nv_fp8x2_storage_t fp8_pair;                                             \
            memcpy(&fp8_pair, v, 2 * sizeof(FP8_TY));                                  \
            const __half2_raw half_pair =                                              \
                __nv_cvt_fp8x2_to_halfraw2(fp8_pair, FP8_INTERP);                      \
            memcpy(result, &half_pair, 2 * sizeof(T));                                 \
        }                                                                              \
    };                                                                                 \
    }
4313
+
4314
+ namespace kernel_float {
4290
4315
KERNEL_FLOAT_FP8_CAST (double )
4291
4316
} // namespace kernel_float
4292
4317
@@ -4297,6 +4322,10 @@ namespace kernel_float {
4297
4322
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__half, __nv_fp8_e4m3)
4298
4323
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__half, __nv_fp8_e5m2)
4299
4324
KERNEL_FLOAT_FP8_CAST (__half)
4325
+
4326
+ KERNEL_FLOAT_FP8_CAST2 (__half, __nv_fp8_e4m3, __NV_E4M3)
4327
+ KERNEL_FLOAT_FP8_CAST2 (__half, __nv_fp8_e5m2, __NV_E5M2)
4328
+
4300
4329
} // namespace kernel_float
4301
4330
#endif // KERNEL_FLOAT_FP16_AVAILABLE
4302
4331
@@ -4307,6 +4336,9 @@ namespace kernel_float {
4307
4336
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__nv_bfloat16, __nv_fp8_e4m3)
4308
4337
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__nv_bfloat16, __nv_fp8_e5m2)
4309
4338
KERNEL_FLOAT_FP8_CAST (__nv_bfloat16)
4339
+
4340
// NOTE: KERNEL_FLOAT_FP8_CAST2 is deliberately NOT instantiated for `__nv_bfloat16`.
// That macro copies the two `T` values into a `__half2_raw` and converts them with the
// half-precision intrinsics (`__nv_cvt_halfraw2_to_fp8x2` / `__nv_cvt_fp8x2_to_halfraw2`).
// bfloat16 and IEEE half use different exponent/mantissa layouts, so instantiating it here
// would reinterpret bfloat16 bit patterns as half values and silently produce wrong
// results in both directions. Pairs of bfloat16 values therefore rely on the casts
// provided by KERNEL_FLOAT_FP8_CAST(__nv_bfloat16) above instead of a packed fast path.
4310
4342
} // namespace kernel_float
4311
4343
#endif // KERNEL_FLOAT_BF16_AVAILABLE
4312
4344
0 commit comments