Skip to content

Commit e952eee

Browse files
authored
[Bugfix] Fix __syncwarp on ROCM (#25996)
1 parent 66bca9b commit e952eee

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

csrc/cache_kernels.cu

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,9 @@ __global__ void indexer_k_quant_and_cache_kernel(
536536
for (int i = 0; i < VEC_SIZE; i++) {
537537
amax = fmaxf(amax, fabsf(float(k_val_ptr[i])));
538538
}
539+
#ifndef USE_ROCM
539540
__syncwarp();
541+
#endif
540542

541543
// Reduced amax
542544
for (int mask = 16; mask > 0; mask /= 2) {
@@ -546,7 +548,9 @@ __global__ void indexer_k_quant_and_cache_kernel(
546548
amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask));
547549
#endif
548550
}
551+
#ifndef USE_ROCM
549552
__syncwarp();
553+
#endif
550554
float scale = fmaxf(amax, 1e-4) / 448.0f;
551555
if (use_ue8m0) {
552556
scale = exp2f(ceilf(log2f(scale)));
@@ -1167,4 +1171,4 @@ void indexer_k_quant_and_cache(
11671171

11681172
DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3",
11691173
CALL_INDEXER_K_QUANT_AND_CACHE);
1170-
}
1174+
}

0 commit comments

Comments
 (0)