Skip to content

Commit 1cadb2a

Browse files
committed
issue/342: delete to
1 parent 0ecbe1d commit 1cadb2a

File tree

2 files changed

+9
-25
lines changed

2 files changed

+9
-25
lines changed

src/infiniop/devices/kunlun/kunlun_kernel_common.h

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -43,22 +43,6 @@ __device__ inline void loadsm(__shared_ptr__ const T *p, T *v, int len) {
4343
__builtin_memcpy(v, p, len * sizeof(T));
4444
}
4545

46-
/**
47-
* @brief Convert data type. All data is in local memory
48-
* @param v: input value
49-
* @return output value
50-
*/
51-
template <typename Tout, typename Tin>
52-
__device__ inline Tout to(Tin v) {
53-
if constexpr (std::is_same<Tin, half>::value) {
54-
return __half2float(v);
55-
} else if constexpr (std::is_same<Tin, bfloat16_t>::value) {
56-
return __bfloat162float(v);
57-
} else {
58-
return static_cast<Tout>(v);
59-
}
60-
}
61-
6246
/**
6347
* @brief atomicAdd for kunlun xpu
6448
* @param ptr: pointer to shared memory

src/infiniop/ops/random_sample/kunlun/kernel.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
270270

271271
__shared__ Tcompute sum_;
272272
if (core_id() == 0) {
273-
sum_ = to<Tcompute>(0.f);
273+
sum_ = Tcompute(0.f);
274274
}
275275
sync_cluster();
276276

@@ -286,9 +286,9 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
286286

287287
for (int index = core_id(); index < read_len; index += BLOCK_SIZE) {
288288
if constexpr (std::is_same_v<Tval, half>) {
289-
y_sm[index] = __float2half(exp((__half2float(x_sm[index]) - to<float>(max_value)) / temperature));
289+
y_sm[index] = __float2half(exp((__half2float(x_sm[index]) - float(max_value)) / temperature));
290290
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
291-
y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - to<float>(max_value)) / temperature));
291+
y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - float(max_value)) / temperature));
292292
} else if constexpr (std::is_same_v<Tval, float>) {
293293
y_sm[index] = exp((x_sm[index] - max_value) / temperature);
294294
}
@@ -351,11 +351,11 @@ __device__ void sample(__global_ptr__ Tidx *result,
351351
GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval));
352352
for (int index = 0; index < read_len; index++) {
353353
if constexpr (std::is_same_v<Tval, float>) {
354-
cumsum += exp((values_local[index] - max_value) / temperature) / to<float>(all_sum);
354+
cumsum += exp((values_local[index] - max_value) / temperature) / float(all_sum);
355355
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
356-
cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum);
356+
cumsum += exp((float(values_local[index]) - float(max_value)) / temperature) / float(all_sum);
357357
} else if constexpr (std::is_same_v<Tval, half>) {
358-
cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum);
358+
cumsum += exp((float(values_local[index]) - float(max_value)) / temperature) / float(all_sum);
359359
}
360360
if (cumsum >= topp) {
361361
end = r * buf_size + index + 1;
@@ -370,11 +370,11 @@ __device__ void sample(__global_ptr__ Tidx *result,
370370
GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval));
371371
for (int index = 0; index < read_len; index++) {
372372
if constexpr (std::is_same_v<Tval, float>) {
373-
cumsum += exp((values_local[index] - max_value) / temperature) / to<float>(all_sum);
373+
cumsum += exp((values_local[index] - max_value) / temperature) / float(all_sum);
374374
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
375-
cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum);
375+
cumsum += exp((float(values_local[index]) - float(max_value)) / temperature) / float(all_sum);
376376
} else if constexpr (std::is_same_v<Tval, half>) {
377-
cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum);
377+
cumsum += exp((float(values_local[index]) - float(max_value)) / temperature) / float(all_sum);
378378
}
379379
if (random_val < cumsum) {
380380
result[0] = indices_global[r * buf_size + index];

0 commit comments

Comments
 (0)