Skip to content

Commit ff3f808

Browse files
committed
Add invoke function
1 parent 714ca6b commit ff3f808

File tree

4 files changed

+73
-35
lines changed

4 files changed

+73
-35
lines changed

.github/workflows/cmake.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@ jobs:
2020
with:
2121
rocm-version: "6.3.0"
2222

23-
build-cuda-13-0:
23+
build-cuda-12-7:
2424
if: github.ref_name == 'main'
2525
needs: build-cuda
2626
uses: ./.github/workflows/cmake-run-cuda.yml
2727
with:
28-
cuda-version: "13.0.0"
28+
cuda-version: "12.7.0"
2929

3030
build-cuda-12-6:
3131
if: github.ref_name == 'main'

include/kernel_float/apply.h

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,37 @@ broadcast_like(const V& input, const R& other) {
116116
return broadcast(input, vector_extent_type<R> {});
117117
}
118118

119+
namespace detail {
120+
121+
template<typename F, typename... Args>
122+
struct invoke_impl {
123+
KERNEL_FLOAT_INLINE static decltype(auto) call(F fun, Args... args) {
124+
return std::forward<F>(fun)(std::forward<Args>(args)...);
125+
}
126+
};
127+
128+
} // namespace detail
129+
130+
template<typename F, typename... Args>
131+
using result_t = decltype(detail::invoke_impl<decay_t<F>, decay_t<Args>...>::call(
132+
detail::declval<F>(),
133+
detail::declval<Args>()...));
134+
135+
/**
136+
* Invoke the given function `fun` with the arguments `args...`.
137+
*
138+
* The main difference between directly calling `fun(args...)`, is that the behavior can be overridden by
139+
* specializing on `detail::invoke_impl<F, Args...>`.
140+
*
141+
* @return The result of `fun(args...)`.
142+
*/
143+
template<typename F, typename... Args>
144+
KERNEL_FLOAT_INLINE result_t<F, Args...> invoke(F fun, const Args&... args) {
145+
return detail::invoke_impl<decay_t<F>, decay_t<Args>...>::call(
146+
std::forward<F>(fun),
147+
std::forward<Args>(args)...);
148+
}
149+
119150
/**
120151
* The accurate_policy is designed for computations where maximum accuracy is essential. This policy ensures that all
121152
* operations are performed without any approximations or optimizations that could potentially alter the precise
@@ -168,13 +199,6 @@ using default_policy = accurate_policy;
168199

169200
namespace detail {
170201

171-
template<typename F, typename... Args>
172-
struct invoke_impl {
173-
KERNEL_FLOAT_INLINE static auto call(F fun, Args... args) {
174-
return fun(args...);
175-
}
176-
};
177-
178202
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
179203
struct apply_impl;
180204

@@ -190,7 +214,7 @@ struct apply_impl<accurate_policy, F, N, Output, Args...> {
190214
KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
191215
#pragma unroll
192216
for (size_t i = 0; i < N; i++) {
193-
output[i] = invoke_impl<F, Args...>::call(fun, args[i]...);
217+
output[i] = detail::invoke_impl<F, Args...>::call(fun, args[i]...);
194218
}
195219
}
196220
};
@@ -225,10 +249,6 @@ using default_map_impl = map_impl<default_policy, F, N, Output, Args...>;
225249

226250
} // namespace detail
227251

228-
template<typename F, typename... Args>
229-
using result_t = decltype(
230-
detail::invoke_impl<F, Args...>::call(detail::declval<F>(), detail::declval<Args>()...));
231-
232252
template<typename F, typename... Args>
233253
using map_type =
234254
vector<result_t<F, vector_value_type<Args>...>, broadcast_vector_extent_type<Args...>>;

single_include/kernel_float.h

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
//================================================================================
1818
// this file has been auto-generated, do not modify its contents!
19-
// date: 2025-08-12 09:36:07.217735
20-
// git hash: 15a92ee9e96aef3147fdcfc3dcb3bd4ce501d063
19+
// date: 2025-08-12 13:55:51.042675
20+
// git hash: 714ca6b5fd63ef3497d80ef018cb9a9460c91391
2121
//================================================================================
2222

2323
#ifndef KERNEL_FLOAT_MACROS_H
@@ -41,8 +41,8 @@
4141
#endif // __CUDA_ARCH__
4242
#elif defined(__HIPCC__)
4343
#define KERNEL_FLOAT_IS_HIP (1)
44-
#define KERNEL_FLOAT_DEVICE __attribute__((always_inline)) __device__
45-
#define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__
44+
#define KERNEL_FLOAT_DEVICE __attribute__((always_inline)) inline __device__
45+
#define KERNEL_FLOAT_INLINE __attribute__((always_inline)) inline __host__ __device__
4646

4747
#ifdef __HIP_DEVICE_COMPILE__
4848
#define KERNEL_FLOAT_IS_DEVICE (1)
@@ -781,6 +781,37 @@ broadcast_like(const V& input, const R& other) {
781781
return broadcast(input, vector_extent_type<R> {});
782782
}
783783

784+
namespace detail {
785+
786+
template<typename F, typename... Args>
787+
struct invoke_impl {
788+
KERNEL_FLOAT_INLINE static decltype(auto) call(F fun, Args... args) {
789+
return std::forward<F>(fun)(std::forward<Args>(args)...);
790+
}
791+
};
792+
793+
} // namespace detail
794+
795+
template<typename F, typename... Args>
796+
using result_t = decltype(detail::invoke_impl<decay_t<F>, decay_t<Args>...>::call(
797+
detail::declval<F>(),
798+
detail::declval<Args>()...));
799+
800+
/**
801+
* Invoke the given function `fun` with the arguments `args...`.
802+
*
803+
* The main difference between directly calling `fun(args...)`, is that the behavior can be overridden by
804+
* specializing on `detail::invoke_impl<F, Args...>`.
805+
*
806+
* @return The result of `fun(args...)`.
807+
*/
808+
template<typename F, typename... Args>
809+
KERNEL_FLOAT_INLINE result_t<F, Args...> invoke(F fun, const Args&... args) {
810+
return detail::invoke_impl<decay_t<F>, decay_t<Args>...>::call(
811+
std::forward<F>(fun),
812+
std::forward<Args>(args)...);
813+
}
814+
784815
/**
785816
* The accurate_policy is designed for computations where maximum accuracy is essential. This policy ensures that all
786817
* operations are performed without any approximations or optimizations that could potentially alter the precise
@@ -833,13 +864,6 @@ using default_policy = accurate_policy;
833864

834865
namespace detail {
835866

836-
template<typename F, typename... Args>
837-
struct invoke_impl {
838-
KERNEL_FLOAT_INLINE static auto call(F fun, Args... args) {
839-
return fun(args...);
840-
}
841-
};
842-
843867
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
844868
struct apply_impl;
845869

@@ -855,7 +879,7 @@ struct apply_impl<accurate_policy, F, N, Output, Args...> {
855879
KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
856880
#pragma unroll
857881
for (size_t i = 0; i < N; i++) {
858-
output[i] = invoke_impl<F, Args...>::call(fun, args[i]...);
882+
output[i] = detail::invoke_impl<F, Args...>::call(fun, args[i]...);
859883
}
860884
}
861885
};
@@ -890,10 +914,6 @@ using default_map_impl = map_impl<default_policy, F, N, Output, Args...>;
890914

891915
} // namespace detail
892916

893-
template<typename F, typename... Args>
894-
using result_t = decltype(
895-
detail::invoke_impl<F, Args...>::call(detail::declval<F>(), detail::declval<Args>()...));
896-
897917
template<typename F, typename... Args>
898918
using map_type =
899919
vector<result_t<F, vector_value_type<Args>...>, broadcast_vector_extent_type<Args...>>;
@@ -4407,14 +4427,12 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)
44074427
#elif KERNEL_FLOAT_IS_HIP
44084428
KERNEL_FLOAT_INLINE __hip_bfloat16 hip_habs(const __hip_bfloat16 a) {
44094429
unsigned short int res = __bfloat16_as_ushort(a);
4410-
res &= 0x7FFF;
4411-
return __ushort_as_bfloat16();
4430+
return __ushort_as_bfloat16(res & 0x7FFF);
44124431
}
44134432

44144433
KERNEL_FLOAT_INLINE __hip_bfloat16 hip_hneg(const __hip_bfloat16 a) {
44154434
unsigned short int res = __bfloat16_as_ushort(a);
4416-
res ^= 0x8000;
4417-
return __ushort_as_bfloat16(res);
4435+
return __ushort_as_bfloat16(res ^ 0x8000);
44184436
}
44194437

44204438
KERNEL_FLOAT_INLINE __hip_bfloat162 hip_habs2(const __hip_bfloat162 a) {

tests/common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#pragma once
22

3-
#include <stdint.h>
43
#include <math.h>
4+
#include <stdint.h>
55
#include <tgmath.h>
66

77
#include "catch2/catch_all.hpp"

0 commit comments

Comments
 (0)