Skip to content

Commit 20788a3

Browse files
committed
fix(kernel): 双目运算不能交换
Signed-off-by: YdrMaster <[email protected]>
1 parent ec39fd7 commit 20788a3

File tree

2 files changed

+38
-12
lines changed

2 files changed

+38
-12
lines changed

src/04kernel/src/kernels/simple_binary/cuda_kernel.cc

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,28 @@ extern "C" __global__ void kernel(
5959
}}
6060
)~";
6161

62-
constexpr static const char *SCALAR = R"~(
62+
constexpr static const char *SCALAR_A = R"~(
63+
__device__ __forceinline__ static {0:} fn({0:} a, {0:} b) {{
64+
return {1:};
65+
}}
66+
67+
extern "C" __global__ void kernel(
68+
{0:} *__restrict__ y,
69+
{0:} const *__restrict__ s,
70+
{0:} const *__restrict__ v,
71+
size_t n
72+
) {{
73+
auto num = *s;
74+
for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
75+
step = blockDim.x * gridDim.x;
76+
tid < n;
77+
tid += step) {{
78+
y[tid] = fn(num, v[tid]);
79+
}}
80+
}}
81+
)~";
82+
83+
constexpr static const char *SCALAR_B = R"~(
6384
__device__ __forceinline__ static {0:} fn({0:} a, {0:} b) {{
6485
return {1:};
6586
}}
@@ -209,19 +230,17 @@ extern "C" __global__ void kernel(
209230

210231
} else if (auto rank = broadcaster.strides.size() / (broadcaster.inputsCount + 1); rank == 1) {
211232
static const std::vector<dim_t> S0{0, 1, 1}, S1{1, 0, 1};
212-
auto name = fmt::format("binaryScalar{}", postfix);
213-
auto code = fmt::format(SCALAR, dt_, op_);
214-
return [params, h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"),
215-
// clang-format off
216-
scalar = broadcaster.strides == S0 ? 0
217-
: broadcaster.strides == S1 ? 1
218-
: UNREACHABLEX(int, "Unreachable")]// clang-format on
233+
auto scalar_a = broadcaster.strides == S0;
234+
auto name = fmt::format("binaryScalar{}{}", postfix, scalar_a ? "A" : "B");
235+
auto code = scalar_a ? fmt::format(SCALAR_A, dt_, op_)
236+
: fmt::format(SCALAR_B, dt_, op_);
237+
return [params, h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel")]//
219238
(Resources &, void *, void const *const *inputs, void *const *outputs) {
220239
auto c = outputs[0];
221-
auto s = inputs[scalar],
222-
v = inputs[1 - scalar];
240+
auto a = inputs[0],
241+
b = inputs[1];
223242
auto n = params.n;
224-
void *args[]{&c, &v, &s, &n};
243+
void *args[]{&c, &a, &b, &n};
225244
h->launch(params.gridSize, 1, 1,
226245
params.blockSize, 1, 1,
227246
0, args);

src/04kernel/test/kernels/simple_binary/test_binary_cuda.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,14 @@ TEST(kernel, BinaryCudaFmodF32) {
103103
}
104104

105105
TEST(kernel, BinaryCudaBroadcast) {
106-
testBinaryCuda<DataType::I8>(SimpleBinaryType::Add, Shape{1, 2, 3, 4, 5, 6}, Shape{}, Shape{1, 2, 3, 4, 5, 6});
106+
testBinaryCuda<DataType::F32>(SimpleBinaryType::Sub,
107+
Shape{1, 2, 3, 4, 5, 6},
108+
Shape{},
109+
Shape{1, 2, 3, 4, 5, 6});
110+
testBinaryCuda<DataType::F32>(SimpleBinaryType::Div,
111+
Shape{},
112+
Shape{1, 2, 3, 4, 5, 6},
113+
Shape{1, 2, 3, 4, 5, 6});
107114
}
108115

109116
#endif

0 commit comments

Comments
 (0)