@@ -59,7 +59,28 @@ extern "C" __global__ void kernel(
5959}}
6060)~" ;
6161
62- constexpr static const char *SCALAR = R"~(
62+ constexpr static const char *SCALAR_A = R"~(
63+ __device__ __forceinline__ static {0:} fn({0:} a, {0:} b) {{
64+ return {1:};
65+ }}
66+
67+ extern "C" __global__ void kernel(
68+ {0:} *__restrict__ y,
69+ {0:} const *__restrict__ s,
70+ {0:} const *__restrict__ v,
71+ size_t n
72+ ) {{
73+ auto num = *s;
74+ for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
75+ step = blockDim.x * gridDim.x;
76+ tid < n;
77+ tid += step) {{
78+ y[tid] = fn(num, v[tid]);
79+ }}
80+ }}
81+ )~" ;
82+
83+ constexpr static const char *SCALAR_B = R"~(
6384__device__ __forceinline__ static {0:} fn({0:} a, {0:} b) {{
6485 return {1:};
6586}}
@@ -209,19 +230,17 @@ extern "C" __global__ void kernel(
209230
210231 } else if (auto rank = broadcaster.strides .size () / (broadcaster.inputsCount + 1 ); rank == 1 ) {
211232 static const std::vector<dim_t > S0{0 , 1 , 1 }, S1{1 , 0 , 1 };
212- auto name = fmt::format (" binaryScalar{}" , postfix);
213- auto code = fmt::format (SCALAR, dt_, op_);
214- return [params, h = nvrtc::Handler::compile (name.c_str (), code.c_str (), " kernel" ),
215- // clang-format off
216- scalar = broadcaster.strides == S0 ? 0
217- : broadcaster.strides == S1 ? 1
218- : UNREACHABLEX (int , " Unreachable" )]// clang-format on
233+ auto scalar_a = broadcaster.strides == S0;
234+ auto name = fmt::format (" binaryScalar{}{}" , postfix, scalar_a ? " A" : " B" );
235+ auto code = scalar_a ? fmt::format (SCALAR_A, dt_, op_)
236+ : fmt::format (SCALAR_B, dt_, op_);
237+ return [params, h = nvrtc::Handler::compile (name.c_str (), code.c_str (), " kernel" )]//
219238 (Resources &, void *, void const *const *inputs, void *const *outputs) {
220239 auto c = outputs[0 ];
221- auto s = inputs[scalar ],
222- v = inputs[1 - scalar ];
240+ auto a = inputs[0 ],
241+ b = inputs[1 ];
223242 auto n = params.n ;
224- void *args[]{&c, &v , &s , &n};
243+ void *args[]{&c, &a , &b , &n};
225244 h->launch (params.gridSize , 1 , 1 ,
226245 params.blockSize , 1 , 1 ,
227246 0 , args);
0 commit comments