@@ -652,12 +652,20 @@ void invokeQuantizeMatrixColwise(
652
652
T_IN const * const input,
653
653
const int64_t numel,
654
654
const int64_t lda,
655
- const cudaStream_t stream) {
655
+ const c10::cuda::CUDAStream stream) {
656
656
constexpr dim3 grid (1024 );
657
657
const dim3 block (CTA_SIZE);
658
- scaleMatrixColwise<true >
659
- <<<grid, block, 0 , stream>>> (output, input_scale, input, numel, lda);
660
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
658
+ FBGEMM_LAUNCH_KERNEL (
659
+ (scaleMatrixColwise<true , T_OUT, T_S, T_IN>),
660
+ grid,
661
+ block,
662
+ 0 ,
663
+ stream,
664
+ output,
665
+ input_scale,
666
+ input,
667
+ numel,
668
+ lda);
661
669
}
662
670
663
671
template <typename T>
@@ -761,20 +769,24 @@ void invokeComputeScale(
761
769
const int64_t total_elements_per_slice,
762
770
const int64_t * bs,
763
771
const float * scale_ub,
764
- const cudaStream_t stream) {
772
+ const c10::cuda::CUDAStream stream) {
765
773
constexpr dim3 block (1024 );
766
774
constexpr dim3 grid (1024 );
767
775
int64_t numel_scale = numel;
768
776
C10_CUDA_CHECK (cudaMemsetAsync (quant_ptr, 0 , sizeof (T_S), stream));
769
- computeFP8QuantizeScale<<<grid, block, 0 , stream>>> (
777
+ FBGEMM_LAUNCH_KERNEL (
778
+ (computeFP8QuantizeScale<T_S, T_IN>),
779
+ grid,
780
+ block,
781
+ 0 ,
782
+ stream,
770
783
quant_ptr,
771
784
input,
772
785
numel_scale,
773
786
lda,
774
787
total_elements_per_slice,
775
788
bs,
776
789
scale_ub);
777
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
778
790
}
779
791
780
792
at::Tensor get_fp8_per_tensor_scale (
@@ -1145,30 +1157,49 @@ void invokeComputeScalesAndQuantizeMatrix(
1145
1157
rng_engine_inputs =
1146
1158
at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_cuda_state (4 );
1147
1159
1148
- dynamicQuantizeMatrixRowwiseStoc<SCALE>
1149
- <<<grid, block, shmem_size, stream>>> (
1150
- output,
1151
- quant_ptr,
1152
- input,
1153
- numel,
1154
- lda,
1155
- scale_ub,
1156
- rng_engine_inputs);
1157
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
1160
+ FBGEMM_LAUNCH_KERNEL (
1161
+ (dynamicQuantizeMatrixRowwiseStoc<SCALE, T_OUT, T_S, T_IN>),
1162
+ grid,
1163
+ block,
1164
+ shmem_size,
1165
+ stream,
1166
+ output,
1167
+ quant_ptr,
1168
+ input,
1169
+ numel,
1170
+ lda,
1171
+ scale_ub,
1172
+ rng_engine_inputs);
1158
1173
} else {
1159
- dynamicQuantizeMatrixRowwise<SCALE><<<grid, block, shmem_size, stream>>> (
1160
- output, quant_ptr, input, numel, lda, scale_ub);
1161
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
1174
+ FBGEMM_LAUNCH_KERNEL (
1175
+ (dynamicQuantizeMatrixRowwise<SCALE, T_OUT, T_S, T_IN>),
1176
+ grid,
1177
+ block,
1178
+ shmem_size,
1179
+ stream,
1180
+ output,
1181
+ quant_ptr,
1182
+ input,
1183
+ numel,
1184
+ lda,
1185
+ scale_ub);
1162
1186
}
1163
1187
} else {
1164
1188
dim3 block (CTA_SIZE);
1165
- computeFP8QuantizeScaleRowwise<SCALE>
1166
- <<<grid, block, 0 , stream>>> (quant_ptr, input, numel, lda, scale_ub);
1167
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
1189
+ FBGEMM_LAUNCH_KERNEL (
1190
+ (computeFP8QuantizeScaleRowwise<SCALE, T_S, T_IN>),
1191
+ grid,
1192
+ block,
1193
+ 0 ,
1194
+ stream,
1195
+ quant_ptr,
1196
+ input,
1197
+ numel,
1198
+ lda,
1199
+ scale_ub);
1168
1200
invokeQuantizeMatrixRowwise (
1169
1201
output, quant_ptr, input, numel, lda, stochastic_rounding, stream);
1170
1202
}
1171
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
1172
1203
}
1173
1204
1174
1205
template <typename T_OUT, typename T_S, typename T_IN>
@@ -1178,7 +1209,7 @@ void invokeComputeScalesAndQuantizeMatrixCol(
1178
1209
const T_IN* input,
1179
1210
const int64_t numel,
1180
1211
const int64_t lda,
1181
- cudaStream_t stream) {
1212
+ c10::cuda::CUDAStream stream) {
1182
1213
dim3 block (CTA_SIZE);
1183
1214
dim3 grid ((lda + CTA_SIZE - 1 ) / CTA_SIZE);
1184
1215
C10_CUDA_CHECK (cudaMemsetAsync (quant_ptr, 0 , lda * sizeof (T_S), stream));
@@ -1598,7 +1629,7 @@ void invokeFP4Quantization(
1598
1629
int32_t * SFOuput,
1599
1630
bool useUE8M0,
1600
1631
int multiProcessorCount,
1601
- cudaStream_t stream) {
1632
+ c10::cuda::CUDAStream stream) {
1602
1633
// Grid, Block size.
1603
1634
// Each thread converts 8 values.
1604
1635
dim3 block (std::min (int (n / ELTS_PER_THREAD), 512 ));
@@ -1636,7 +1667,7 @@ template void invokeFP4Quantization(
1636
1667
int32_t * SFOuput,
1637
1668
bool useUE8M0,
1638
1669
int multiProcessorCount,
1639
- cudaStream_t stream);
1670
+ c10::cuda::CUDAStream stream);
1640
1671
1641
1672
template void invokeFP4Quantization (
1642
1673
int m,
@@ -1647,7 +1678,7 @@ template void invokeFP4Quantization(
1647
1678
int32_t * SFOuput,
1648
1679
bool useUE8M0,
1649
1680
int multiProcessorCount,
1650
- cudaStream_t stream);
1681
+ c10::cuda::CUDAStream stream);
1651
1682
1652
1683
int64_t get_device_attribute (int64_t attribute, int64_t device_id) {
1653
1684
// Return the cached value on subsequent calls
@@ -1887,7 +1918,7 @@ void fp4_fused_amax_quantize(
1887
1918
__nv_bfloat16 const * const x,
1888
1919
const int64_t numel,
1889
1920
const int blocksize,
1890
- const cudaStream_t stream) {
1921
+ const c10::cuda::CUDAStream stream) {
1891
1922
const int blocks_per_cta = 4 ;
1892
1923
1893
1924
const dim3 block (blocksize, blocks_per_cta);
@@ -1938,7 +1969,7 @@ void invokeComputeFP4GlobalAmax(
1938
1969
const int64_t total_elements_per_slice,
1939
1970
const int64_t * bs,
1940
1971
const float * scale_ub,
1941
- const cudaStream_t stream) {
1972
+ const c10::cuda::CUDAStream stream) {
1942
1973
constexpr dim3 block (1024 );
1943
1974
constexpr dim3 grid (1024 );
1944
1975
int64_t numel_scale = numel;
0 commit comments