@@ -270,7 +270,7 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
270270
271271 __shared__ Tcompute sum_;
272272 if (core_id () == 0 ) {
273- sum_ = to< Tcompute> (0 .f );
273+ sum_ = Tcompute (0 .f );
274274 }
275275 sync_cluster ();
276276
@@ -286,9 +286,9 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
286286
287287 for (int index = core_id (); index < read_len; index += BLOCK_SIZE) {
288288 if constexpr (std::is_same_v<Tval, half>) {
289- y_sm[index] = __float2half (exp ((__half2float (x_sm[index]) - to< float > (max_value)) / temperature));
289+ y_sm[index] = __float2half (exp ((__half2float (x_sm[index]) - float (max_value)) / temperature));
290290 } else if constexpr (std::is_same_v<Tval, bfloat16_t >) {
291- y_sm[index] = __float2bfloat16 (exp ((__bfloat162float (x_sm[index]) - to< float > (max_value)) / temperature));
291+ y_sm[index] = __float2bfloat16 (exp ((__bfloat162float (x_sm[index]) - float (max_value)) / temperature));
292292 } else if constexpr (std::is_same_v<Tval, float >) {
293293 y_sm[index] = exp ((x_sm[index] - max_value) / temperature);
294294 }
@@ -351,11 +351,11 @@ __device__ void sample(__global_ptr__ Tidx *result,
351351 GM2LM (values_global + r * buf_size, values_local, read_len * sizeof (Tval));
352352 for (int index = 0 ; index < read_len; index++) {
353353 if constexpr (std::is_same_v<Tval, float >) {
354- cumsum += exp ((values_local[index] - max_value) / temperature) / to< float > (all_sum);
354+ cumsum += exp ((values_local[index] - max_value) / temperature) / float (all_sum);
355355 } else if constexpr (std::is_same_v<Tval, bfloat16_t >) {
356- cumsum += exp ((to< float > (values_local[index]) - to< float > (max_value)) / temperature) / to< float > (all_sum);
356+ cumsum += exp ((float (values_local[index]) - float (max_value)) / temperature) / float (all_sum);
357357 } else if constexpr (std::is_same_v<Tval, half>) {
358- cumsum += exp ((to< float > (values_local[index]) - to< float > (max_value)) / temperature) / to< float > (all_sum);
358+ cumsum += exp ((float (values_local[index]) - float (max_value)) / temperature) / float (all_sum);
359359 }
360360 if (cumsum >= topp) {
361361 end = r * buf_size + index + 1 ;
@@ -370,11 +370,11 @@ __device__ void sample(__global_ptr__ Tidx *result,
370370 GM2LM (values_global + r * buf_size, values_local, read_len * sizeof (Tval));
371371 for (int index = 0 ; index < read_len; index++) {
372372 if constexpr (std::is_same_v<Tval, float >) {
373- cumsum += exp ((values_local[index] - max_value) / temperature) / to< float > (all_sum);
373+ cumsum += exp ((values_local[index] - max_value) / temperature) / float (all_sum);
374374 } else if constexpr (std::is_same_v<Tval, bfloat16_t >) {
375- cumsum += exp ((to< float > (values_local[index]) - to< float > (max_value)) / temperature) / to< float > (all_sum);
375+ cumsum += exp ((float (values_local[index]) - float (max_value)) / temperature) / float (all_sum);
376376 } else if constexpr (std::is_same_v<Tval, half>) {
377- cumsum += exp ((to< float > (values_local[index]) - to< float > (max_value)) / temperature) / to< float > (all_sum);
377+ cumsum += exp ((float (values_local[index]) - float (max_value)) / temperature) / float (all_sum);
378378 }
379379 if (random_val < cumsum) {
380380 result[0 ] = indices_global[r * buf_size + index];
0 commit comments