
Commit a0dd77b

q10 authored and facebook-github-bot committed
Migrate GenAI quantize kernels to FBGEMM_LAUNCH_KERNEL, pt 3 (#4858)
Summary:
Pull Request resolved: #4858

- Migrate GenAI quantize kernels to `FBGEMM_LAUNCH_KERNEL`, pt 3

Reviewed By: cthi

Differential Revision: D81540250

fbshipit-source-id: c8798a5647fb36ab3b2b92222c6b8771ec44cbf7
1 parent dc0ab6d commit a0dd77b
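For context, every launch-site change in this diff follows the same shape: a raw triple-chevron launch plus a trailing C10_CUDA_KERNEL_LAUNCH_CHECK() becomes a single FBGEMM_LAUNCH_KERNEL call that takes the kernel (parenthesized, with its template arguments spelled out), the grid/block/shared-memory configuration, a c10::cuda::CUDAStream, and then the kernel arguments. Below is a minimal sketch of that pattern; the kernel scaleBy and the wrapper launchScaleBy are placeholders rather than code from quantize.cu, and the include path for the launcher macro is an assumption.

// Hypothetical illustration of the launch-site migration; scaleBy and
// launchScaleBy are placeholders, not kernels from quantize.cu.
#include <c10/cuda/CUDAStream.h>
// #include "fbgemm_gpu/utils/kernel_launcher.cuh"  // assumed location of FBGEMM_LAUNCH_KERNEL

template <typename T>
__global__ void scaleBy(T* out, const T* in, T factor, int64_t numel) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < numel) {
    out[i] = in[i] * factor;
  }
}

template <typename T>
void launchScaleBy(
    T* out,
    const T* in,
    T factor,
    const int64_t numel,
    const c10::cuda::CUDAStream stream) { // was cudaStream_t before the migration
  const dim3 block(256);
  const dim3 grid((numel + 255) / 256);

  // Before:
  //   scaleBy<T><<<grid, block, 0, stream>>>(out, in, factor, numel);
  //   C10_CUDA_KERNEL_LAUNCH_CHECK();
  //
  // After: one macro call. The C10_CUDA_KERNEL_LAUNCH_CHECK() calls dropped in
  // this diff suggest the macro performs the launch-error check itself.
  FBGEMM_LAUNCH_KERNEL(
      (scaleBy<T>), // parenthesized so commas in template args survive the macro
      grid,
      block,
      0, // dynamic shared memory bytes
      stream, // c10::cuda::CUDAStream, not a raw cudaStream_t
      out,
      in,
      factor,
      numel);
}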

File tree

1 file changed (+61, -30 lines)


fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu

Lines changed: 61 additions & 30 deletions
@@ -652,12 +652,20 @@ void invokeQuantizeMatrixColwise(
     T_IN const* const input,
     const int64_t numel,
     const int64_t lda,
-    const cudaStream_t stream) {
+    const c10::cuda::CUDAStream stream) {
   constexpr dim3 grid(1024);
   const dim3 block(CTA_SIZE);
-  scaleMatrixColwise<true>
-      <<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  FBGEMM_LAUNCH_KERNEL(
+      (scaleMatrixColwise<true, T_OUT, T_S, T_IN>),
+      grid,
+      block,
+      0,
+      stream,
+      output,
+      input_scale,
+      input,
+      numel,
+      lda);
 }
 
 template <typename T>
@@ -761,20 +769,24 @@ void invokeComputeScale(
     const int64_t total_elements_per_slice,
     const int64_t* bs,
     const float* scale_ub,
-    const cudaStream_t stream) {
+    const c10::cuda::CUDAStream stream) {
   constexpr dim3 block(1024);
   constexpr dim3 grid(1024);
   int64_t numel_scale = numel;
   C10_CUDA_CHECK(cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream));
-  computeFP8QuantizeScale<<<grid, block, 0, stream>>>(
+  FBGEMM_LAUNCH_KERNEL(
+      (computeFP8QuantizeScale<T_S, T_IN>),
+      grid,
+      block,
+      0,
+      stream,
       quant_ptr,
       input,
       numel_scale,
       lda,
       total_elements_per_slice,
       bs,
       scale_ub);
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
 at::Tensor get_fp8_per_tensor_scale(
@@ -1145,30 +1157,49 @@ void invokeComputeScalesAndQuantizeMatrix(
       rng_engine_inputs =
           at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_cuda_state(4);
 
-      dynamicQuantizeMatrixRowwiseStoc<SCALE>
-          <<<grid, block, shmem_size, stream>>>(
-              output,
-              quant_ptr,
-              input,
-              numel,
-              lda,
-              scale_ub,
-              rng_engine_inputs);
-      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      FBGEMM_LAUNCH_KERNEL(
+          (dynamicQuantizeMatrixRowwiseStoc<SCALE, T_OUT, T_S, T_IN>),
+          grid,
+          block,
+          shmem_size,
+          stream,
+          output,
+          quant_ptr,
+          input,
+          numel,
+          lda,
+          scale_ub,
+          rng_engine_inputs);
     } else {
-      dynamicQuantizeMatrixRowwise<SCALE><<<grid, block, shmem_size, stream>>>(
-          output, quant_ptr, input, numel, lda, scale_ub);
-      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      FBGEMM_LAUNCH_KERNEL(
+          (dynamicQuantizeMatrixRowwise<SCALE, T_OUT, T_S, T_IN>),
+          grid,
+          block,
+          shmem_size,
+          stream,
+          output,
+          quant_ptr,
+          input,
+          numel,
+          lda,
+          scale_ub);
     }
   } else {
     dim3 block(CTA_SIZE);
-    computeFP8QuantizeScaleRowwise<SCALE>
-        <<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda, scale_ub);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
+    FBGEMM_LAUNCH_KERNEL(
+        (computeFP8QuantizeScaleRowwise<SCALE, T_S, T_IN>),
+        grid,
+        block,
+        0,
+        stream,
+        quant_ptr,
+        input,
+        numel,
+        lda,
+        scale_ub);
     invokeQuantizeMatrixRowwise(
         output, quant_ptr, input, numel, lda, stochastic_rounding, stream);
   }
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
 template <typename T_OUT, typename T_S, typename T_IN>
@@ -1178,7 +1209,7 @@ void invokeComputeScalesAndQuantizeMatrixCol(
     const T_IN* input,
     const int64_t numel,
     const int64_t lda,
-    cudaStream_t stream) {
+    c10::cuda::CUDAStream stream) {
   dim3 block(CTA_SIZE);
   dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE);
   C10_CUDA_CHECK(cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream));
@@ -1598,7 +1629,7 @@ void invokeFP4Quantization(
     int32_t* SFOuput,
     bool useUE8M0,
     int multiProcessorCount,
-    cudaStream_t stream) {
+    c10::cuda::CUDAStream stream) {
   // Grid, Block size.
   // Each thread converts 8 values.
   dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
@@ -1636,7 +1667,7 @@ template void invokeFP4Quantization(
     int32_t* SFOuput,
     bool useUE8M0,
     int multiProcessorCount,
-    cudaStream_t stream);
+    c10::cuda::CUDAStream stream);
 
 template void invokeFP4Quantization(
     int m,
@@ -1647,7 +1678,7 @@ template void invokeFP4Quantization(
     int32_t* SFOuput,
     bool useUE8M0,
     int multiProcessorCount,
-    cudaStream_t stream);
+    c10::cuda::CUDAStream stream);
 
 int64_t get_device_attribute(int64_t attribute, int64_t device_id) {
   // Return the cached value on subsequent calls
@@ -1887,7 +1918,7 @@ void fp4_fused_amax_quantize(
     __nv_bfloat16 const* const x,
     const int64_t numel,
     const int blocksize,
-    const cudaStream_t stream) {
+    const c10::cuda::CUDAStream stream) {
   const int blocks_per_cta = 4;
 
   const dim3 block(blocksize, blocks_per_cta);
@@ -1938,7 +1969,7 @@ void invokeComputeFP4GlobalAmax(
     const int64_t total_elements_per_slice,
     const int64_t* bs,
     const float* scale_ub,
-    const cudaStream_t stream) {
+    const c10::cuda::CUDAStream stream) {
   constexpr dim3 block(1024);
   constexpr dim3 grid(1024);
   int64_t numel_scale = numel;
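A recurring change across the hunks above, separate from the launch macro, is that every stream parameter goes from cudaStream_t to c10::cuda::CUDAStream. A small sketch of what that implies for callers follows; caller_example is a hypothetical function, not code from this commit. Callers typically obtain the stream from PyTorch's stream pool, and the wrapper still converts implicitly to a raw cudaStream_t, which is presumably why calls like cudaMemsetAsync in this diff compile unchanged.

#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

// Hypothetical caller illustrating the new stream-typed signatures.
void caller_example() {
  // CUDAStream for the current device, as host-side callers usually pass it.
  const c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();

  // c10::cuda::CUDAStream converts implicitly to cudaStream_t, so raw CUDA
  // runtime calls (e.g. the cudaMemsetAsync calls left untouched in this
  // diff) keep working with the new parameter type.
  cudaStream_t raw = stream;
  (void)raw;

  // invokeComputeScale<T_S, T_IN>(/* ... */, stream);  // matches the new signature above
}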

0 commit comments
