16
16
17
17
// ================================================================================
18
18
// this file has been auto-generated, do not modify its contents!
19
- // date: 2025-08-12 09:36:07.217735
20
- // git hash: 15a92ee9e96aef3147fdcfc3dcb3bd4ce501d063
19
+ // date: 2025-08-12 13:55:51.042675
20
+ // git hash: 714ca6b5fd63ef3497d80ef018cb9a9460c91391
21
21
// ================================================================================
22
22
23
23
#ifndef KERNEL_FLOAT_MACROS_H
41
41
#endif // __CUDA_ARCH__
42
42
#elif defined(__HIPCC__)
43
43
#define KERNEL_FLOAT_IS_HIP (1 )
44
- #define KERNEL_FLOAT_DEVICE __attribute__ ((always_inline)) __device__
45
- #define KERNEL_FLOAT_INLINE __attribute__ ((always_inline)) __host__ __device__
44
+ #define KERNEL_FLOAT_DEVICE __attribute__ ((always_inline)) inline __device__
45
+ #define KERNEL_FLOAT_INLINE __attribute__ ((always_inline)) inline __host__ __device__
46
46
47
47
#ifdef __HIP_DEVICE_COMPILE__
48
48
#define KERNEL_FLOAT_IS_DEVICE (1 )
@@ -781,6 +781,37 @@ broadcast_like(const V& input, const R& other) {
781
781
return broadcast (input, vector_extent_type<R> {});
782
782
}
783
783
784
+ namespace detail {
785
+
786
+ template <typename F, typename ... Args>
787
+ struct invoke_impl {
788
+ KERNEL_FLOAT_INLINE static decltype (auto ) call(F fun, Args... args) {
789
+ return std::forward<F>(fun)(std::forward<Args>(args)...);
790
+ }
791
+ };
792
+
793
+ } // namespace detail
794
+
795
+ template <typename F, typename ... Args>
796
+ using result_t = decltype (detail::invoke_impl<decay_t <F>, decay_t <Args>...>::call(
797
+ detail::declval<F>(),
798
+ detail::declval<Args>()...));
799
+
800
+ /* *
801
+ * Invoke the given function `fun` with the arguments `args...`.
802
+ *
803
+ * The main difference between directly calling `fun(args...)`, is that the behavior can be overridden by
804
+ * specializing on `detail::invoke_impl<F, Args...>`.
805
+ *
806
+ * @return The result of `fun(args...)`.
807
+ */
808
+ template <typename F, typename ... Args>
809
+ KERNEL_FLOAT_INLINE result_t <F, Args...> invoke (F fun, const Args&... args) {
810
+ return detail::invoke_impl<decay_t <F>, decay_t <Args>...>::call (
811
+ std::forward<F>(fun),
812
+ std::forward<Args>(args)...);
813
+ }
814
+
784
815
/* *
785
816
* The accurate_policy is designed for computations where maximum accuracy is essential. This policy ensures that all
786
817
* operations are performed without any approximations or optimizations that could potentially alter the precise
@@ -833,13 +864,6 @@ using default_policy = accurate_policy;
833
864
834
865
namespace detail {
835
866
836
- template <typename F, typename ... Args>
837
- struct invoke_impl {
838
- KERNEL_FLOAT_INLINE static auto call (F fun, Args... args) {
839
- return fun (args...);
840
- }
841
- };
842
-
843
867
template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
844
868
struct apply_impl ;
845
869
@@ -855,7 +879,7 @@ struct apply_impl<accurate_policy, F, N, Output, Args...> {
855
879
KERNEL_FLOAT_INLINE static void call (F fun, Output* output, const Args*... args) {
856
880
#pragma unroll
857
881
for (size_t i = 0 ; i < N; i++) {
858
- output[i] = invoke_impl<F, Args...>::call (fun, args[i]...);
882
+ output[i] = detail:: invoke_impl<F, Args...>::call (fun, args[i]...);
859
883
}
860
884
}
861
885
};
@@ -890,10 +914,6 @@ using default_map_impl = map_impl<default_policy, F, N, Output, Args...>;
890
914
891
915
} // namespace detail
892
916
893
- template <typename F, typename ... Args>
894
- using result_t = decltype (
895
- detail::invoke_impl<F, Args...>::call(detail::declval<F>(), detail::declval<Args>()...));
896
-
897
917
template <typename F, typename ... Args>
898
918
using map_type =
899
919
vector<result_t <F, vector_value_type<Args>...>, broadcast_vector_extent_type<Args...>>;
@@ -4407,14 +4427,12 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)
4407
4427
#elif KERNEL_FLOAT_IS_HIP
4408
4428
KERNEL_FLOAT_INLINE __hip_bfloat16 hip_habs (const __hip_bfloat16 a) {
4409
4429
unsigned short int res = __bfloat16_as_ushort (a);
4410
- res &= 0x7FFF ;
4411
- return __ushort_as_bfloat16 ();
4430
+ return __ushort_as_bfloat16 (res & 0x7FFF );
4412
4431
}
4413
4432
4414
4433
KERNEL_FLOAT_INLINE __hip_bfloat16 hip_hneg (const __hip_bfloat16 a) {
4415
4434
unsigned short int res = __bfloat16_as_ushort (a);
4416
- res ^= 0x8000 ;
4417
- return __ushort_as_bfloat16 (res);
4435
+ return __ushort_as_bfloat16 (res ^ 0x8000 );
4418
4436
}
4419
4437
4420
4438
KERNEL_FLOAT_INLINE __hip_bfloat162 hip_habs2 (const __hip_bfloat162 a) {
0 commit comments