
Commit a1d116e

flaviotruzzi authored and facebook-github-bot committed
Reland D75563906
Summary: Reland of D75563906, which was backed out; this version includes fixes. The problem was that the grid was not sized large enough for the launch config. This change also ensures vectorization, which reaches roughly 1.4 TB/s.

Test:
```
Running correctness tests...
Testing correctness for dtype: torch.float32
✓ Passed: shape (32, 32) ✓ Passed: shape (64, 32) ✓ Passed: shape (256, 128) ✓ Passed: shape (512, 1024)
✓ Passed: shape (1024, 2048) ✓ Passed: shape (2048, 2048) ✓ Passed: shape (4096, 16384) ✓ Passed: shape (70000, 64)
✓ Passed: shape (131072, 512) ✓ Passed: shape (1000, 520) ✓ Passed: shape (4005, 4005) ✓ Passed: shape (10000, 1000)
✓ Passed: shape (1024, 10000) ✓ Passed: shape (8192, 4096) ✓ Passed: shape (10000, 10000) ✓ Passed: shape (3072, 10000)
✓ Passed: shape (6144, 10000) ✓ Passed: shape (1024, 20000) ✓ Passed: shape (1024, 20000) ✓ Passed: shape (512, 1536)
✓ Passed: shape (512, 6144) ✓ Passed: shape (512, 10240) ✓ Passed: shape (1000, 1000) ✓ Passed: shape (2000, 2000)
✓ Passed: shape (10240, 10240) ✓ Passed: shape (384, 128) ✓ Passed: shape (2048, 1024) ✓ Passed: shape (267, 513)
✓ Passed: shape (67, 123479) ✓ Passed: shape (1024, 123479) ✓ Passed: shape (1234154, 512) ✓ Passed: shape (2048, 66679)
✓ Passed: shape (200, 256) ✓ Passed: shape (1000, 256) ✓ Passed: shape (6000, 256) ✓ Passed: shape (6272, 256)
✓ Passed: shape (200, 512) ✓ Passed: shape (1000, 512) ✓ Passed: shape (6000, 512) ✓ Passed: shape (6272, 512)
✓ Passed: shape (200, 1024) ✓ Passed: shape (1000, 1024) ✓ Passed: shape (6000, 1024) ✓ Passed: shape (6272, 1024)
✓ Passed: shape (200, 2048) ✓ Passed: shape (1000, 2048) ✓ Passed: shape (6000, 2048) ✓ Passed: shape (6272, 2048)
✓ Passed: shape (200, 3072) ✓ Passed: shape (1000, 3072) ✓ Passed: shape (6000, 3072) ✓ Passed: shape (6272, 3072)
✓ Passed: shape (3000000, 512)
Testing correctness for dtype: torch.float16
✓ Passed: shape (32, 32) ✓ Passed: shape (64, 32) ✓ Passed: shape (256, 128) ✓ Passed: shape (512, 1024)
✓ Passed: shape (1024, 2048) ✓ Passed: shape (2048, 2048) ✓ Passed: shape (4096, 16384) ✓ Passed: shape (70000, 64)
✓ Passed: shape (131072, 512) ✓ Passed: shape (1000, 520) ✓ Passed: shape (4005, 4005) ✓ Passed: shape (10000, 1000)
✓ Passed: shape (1024, 10000) ✓ Passed: shape (8192, 4096) ✓ Passed: shape (10000, 10000) ✓ Passed: shape (3072, 10000)
✓ Passed: shape (6144, 10000) ✓ Passed: shape (1024, 20000) ✓ Passed: shape (1024, 20000) ✓ Passed: shape (512, 1536)
✓ Passed: shape (512, 6144) ✓ Passed: shape (512, 10240) ✓ Passed: shape (1000, 1000) ✓ Passed: shape (2000, 2000)
✓ Passed: shape (10240, 10240) ✓ Passed: shape (384, 128) ✓ Passed: shape (2048, 1024) ✓ Passed: shape (267, 513)
✓ Passed: shape (67, 123479) ✓ Passed: shape (1024, 123479) ✓ Passed: shape (1234154, 512) ✓ Passed: shape (2048, 66679)
✓ Passed: shape (200, 256) ✓ Passed: shape (1000, 256) ✓ Passed: shape (6000, 256) ✓ Passed: shape (6272, 256)
✓ Passed: shape (200, 512) ✓ Passed: shape (1000, 512) ✓ Passed: shape (6000, 512) ✓ Passed: shape (6272, 512)
✓ Passed: shape (200, 1024) ✓ Passed: shape (1000, 1024) ✓ Passed: shape (6000, 1024) ✓ Passed: shape (6272, 1024)
✓ Passed: shape (200, 2048) ✓ Passed: shape (1000, 2048) ✓ Passed: shape (6000, 2048) ✓ Passed: shape (6272, 2048)
✓ Passed: shape (200, 3072) ✓ Passed: shape (1000, 3072) ✓ Passed: shape (6000, 3072) ✓ Passed: shape (6272, 3072)
✓ Passed: shape (3000000, 512)
Testing correctness for dtype: torch.bfloat16
✓ Passed: shape (32, 32) ✓ Passed: shape (64, 32) ✓ Passed: shape (256, 128) ✓ Passed: shape (512, 1024)
✓ Passed: shape (1024, 2048) ✓ Passed: shape (2048, 2048) ✓ Passed: shape (4096, 16384) ✓ Passed: shape (70000, 64)
✓ Passed: shape (131072, 512) ✓ Passed: shape (1000, 520) ✓ Passed: shape (4005, 4005) ✓ Passed: shape (10000, 1000)
✓ Passed: shape (1024, 10000) ✓ Passed: shape (8192, 4096) ✓ Passed: shape (10000, 10000) ✓ Passed: shape (3072, 10000)
✓ Passed: shape (6144, 10000) ✓ Passed: shape (1024, 20000) ✓ Passed: shape (1024, 20000) ✓ Passed: shape (512, 1536)
✓ Passed: shape (512, 6144) ✓ Passed: shape (512, 10240) ✓ Passed: shape (1000, 1000) ✓ Passed: shape (2000, 2000)
✓ Passed: shape (10240, 10240) ✓ Passed: shape (384, 128) ✓ Passed: shape (2048, 1024) ✓ Passed: shape (267, 513)
✓ Passed: shape (67, 123479) ✓ Passed: shape (1024, 123479) ✓ Passed: shape (1234154, 512) ✓ Passed: shape (2048, 66679)
✓ Passed: shape (200, 256) ✓ Passed: shape (1000, 256) ✓ Passed: shape (6000, 256) ✓ Passed: shape (6272, 256)
✓ Passed: shape (200, 512) ✓ Passed: shape (1000, 512) ✓ Passed: shape (6000, 512) ✓ Passed: shape (6272, 512)
✓ Passed: shape (200, 1024) ✓ Passed: shape (1000, 1024) ✓ Passed: shape (6000, 1024) ✓ Passed: shape (6272, 1024)
✓ Passed: shape (200, 2048) ✓ Passed: shape (1000, 2048) ✓ Passed: shape (6000, 2048) ✓ Passed: shape (6272, 2048)
✓ Passed: shape (200, 3072) ✓ Passed: shape (1000, 3072) ✓ Passed: shape (6000, 3072) ✓ Passed: shape (6272, 3072)
✓ Passed: shape (3000000, 512)
All correctness tests passed!
100%|██████████| 53/53 [00:31<00:00, 1.68it/s]
100%|██████████| 53/53 [00:27<00:00, 1.90it/s]
100%|██████████| 53/53 [00:28<00:00, 1.86it/s]
100%|██████████| 3/3 [01:27<00:00, 29.29s/it]
100%|██████████| 53/53 [00:31<00:00, 1.69it/s]
100%|██████████| 53/53 [00:31<00:00, 1.70it/s]
100%|██████████| 53/53 [00:30<00:00, 1.73it/s]
100%|██████████| 3/3 [01:33<00:00, 31.02s/it]

[------------------------------------ Not vectorized ------------------------------------]
                  |  fwd, torch.float32  |  fwd, torch.float16  |  fwd, torch.bfloat16
1 threads: --------------------------------------------------------------------------------
  32, 32          |    29.2   |    20.6   |    20.3
  64, 32          |    29.3   |    20.5   |    21.2
  256, 128        |    23.5   |    20.4   |    20.0
  512, 1024       |    22.6   |    20.6   |    20.0
  1024, 2048      |    52.9   |    40.9   |    40.9
  2048, 2048      |    88.8   |    58.1   |    58.1
  4096, 16384     |  1327.5   |   918.2   |   921.5
  70000, 64       |    94.0   |    67.1   |    67.1
  131072, 512     |   908.2   |   710.9   |   716.4
  1000, 520       |    22.4   |    20.1   |    20.6
  4005, 4005      |   339.3   |   237.7   |   238.7
  10000, 1000     |   207.8   |   129.9   |   130.7
  1024, 10000     |   357.8   |   272.0   |   277.6
  8192, 4096      |   669.1   |   463.8   |   465.1
  10000, 10000    |  1934.2   |  1283.3   |  1287.1
  3072, 10000     |   676.7   |   484.2   |   484.9
  6144, 10000     |  1165.3   |   767.3   |   770.7
  1024, 20000     |   703.6   |   577.5   |   578.6
  512, 1536       |    30.6   |    25.8   |    25.6
  512, 6144       |    99.0   |    80.6   |    80.6
  512, 10240      |   205.4   |   128.6   |   128.5
  1000, 1000      |    30.2   |    24.0   |    23.8
  2000, 2000      |    81.0   |    56.4   |    56.7
  10240, 10240    |  1963.8   |  1323.1   |  1326.9
  384, 128        |    20.7   |    20.5   |    20.0
  2048, 1024      |    45.3   |    33.0   |    33.0
  267, 513        |    20.7   |    20.1   |    19.8
  67, 123479      |  2322.7   |  1230.8   |  1313.2
  1024, 123479    |  4244.2   |  3485.9   |  3491.3
  1234154, 512    |  6838.8   |  5890.0   |  5956.4
  2048, 66679     |  3304.3   |  2477.3   |  2483.4
  200, 256        |    20.5   |    19.7   |    19.7
  1000, 256       |    20.5   |    19.9   |    19.6
  6000, 256       |    32.9   |    23.7   |    23.8
  6272, 256       |    34.0   |    24.8   |    24.9
  200, 512        |    21.0   |    19.7   |    20.2
  1000, 512       |    20.3   |    20.2   |    19.8
  6000, 512       |    56.4   |    38.4   |    38.6
  6272, 512       |    59.0   |    40.3   |    40.9
  200, 1024       |    20.7   |    19.7   |    20.0
  1000, 1024      |    30.2   |    24.1   |    24.0
  6000, 1024      |   118.6   |    67.4   |    67.9
  6272, 1024      |   123.8   |    73.3   |    72.6
  200, 2048       |    28.9   |    26.3   |    26.2
  1000, 2048      |    52.3   |    40.5   |    40.2
  6000, 2048      |   238.6   |   159.3   |   159.2
  6272, 2048      |   246.8   |   165.6   |   165.7
  200, 3072       |    39.9   |    36.3   |    36.2
  1000, 3072      |    75.0   |    56.8   |    56.5
  6000, 3072      |   352.0   |   240.9   |   241.9
  6272, 3072      |   365.4   |   249.1   |   250.1
  3000000, 512    | 16525.9   | 14259.5   | 14365.3

Times are in microseconds (us).

[-------------------------------------- Vectorized --------------------------------------]
                  |  fwd, torch.float32  |  fwd, torch.float16  |  fwd, torch.bfloat16
1 threads: --------------------------------------------------------------------------------
  32, 32          |    19.4   |    19.6   |    19.6
  64, 32          |    19.6   |    19.5   |    20.1
  256, 128        |    19.4   |    19.6   |    20.2
  512, 1024       |    19.3   |    19.9   |    20.0
  1024, 2048      |    30.2   |    29.4   |    29.3
  2048, 2048      |    42.4   |    35.1   |    35.0
  4096, 16384     |   613.5   |   562.1   |   564.2
  70000, 64       |    50.2   |    50.1   |    50.2
  131072, 512     |   548.1   |   467.3   |   470.2
  1000, 520       |    19.7   |    20.3   |    19.8
  4005, 4005      |   206.0   |   240.6   |   240.6
  10000, 1000     |    98.3   |    75.4   |    75.1
  1024, 10000     |   250.9   |   200.9   |   208.1
  8192, 4096      |   317.8   |   285.0   |   286.0
  10000, 10000    |   857.6   |   742.6   |   743.9
  3072, 10000     |   347.1   |   320.1   |   320.0
  6144, 10000     |   501.9   |   433.7   |   435.7
  1024, 20000     |   486.9   |   471.7   |   471.3
  512, 1536       |    21.9   |    21.7   |    21.6
  512, 6144       |    64.8   |    62.7   |    62.9
  512, 10240      |   139.6   |   100.2   |   100.2
  1000, 1000      |    19.3   |    19.8   |    20.5
  2000, 2000      |    36.5   |    34.3   |    34.3
  10240, 10240    |   849.1   |   765.9   |   767.3
  384, 128        |    20.0   |    20.0   |    20.1
  2048, 1024      |    22.4   |    21.5   |    21.6
  267, 513        |    20.6   |    20.2   |    19.7
  67, 123479      |  2354.0   |   998.4   |   996.2
  1024, 123479    |  3134.1   |  3418.3   |  3427.5
  1234154, 512    |  5083.0   |  4251.9   |  4278.5
  2048, 66679     |  2052.4   |  2323.9   |  2325.0
  200, 256        |    19.7   |    19.7   |    19.9
  1000, 256       |    20.0   |    19.5   |    19.6
  6000, 256       |    20.2   |    19.4   |    19.7
  6272, 256       |    19.9   |    19.7   |    20.2
  200, 512        |    19.7   |    19.9   |    19.5
  1000, 512       |    20.1   |    19.8   |    19.6
  6000, 512       |    23.1   |    20.9   |    21.3
  6272, 512       |    24.0   |    22.4   |    22.9
  200, 1024       |    19.4   |    19.8   |    19.7
  1000, 1024      |    19.7   |    19.5   |    20.1
  6000, 1024      |    51.9   |    34.1   |    34.6
  6272, 1024      |    55.5   |    39.0   |    37.9
  200, 2048       |    24.5   |    24.4   |    24.4
  1000, 2048      |    30.3   |    29.1   |    28.9
  6000, 2048      |   109.2   |    96.5   |    96.9
  6272, 2048      |   111.4   |   100.1   |   100.3
  200, 3072       |    33.0   |    33.5   |    33.4
  1000, 3072      |    41.4   |    39.5   |    39.5
  6000, 3072      |   156.3   |   146.5   |   146.9
  6272, 3072      |   160.9   |   150.2   |   151.0
  3000000, 512    | 12385.8   | 10336.7   | 10366.0

Times are in microseconds (us).
```

Reviewed By: q10

Differential Revision: D77698275
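For context on the launch-configuration fix described in the summary: the vectorized kernel lays threads out as (rows, column-vectors), so the grid must cover ceil(ncols / vec_size) column-vectors per row. The sketch below illustrates that ceiling-division sizing only; it is not the FBGEMM code itself. `ceil_div` is a simplified stand-in for FBGEMM's `cuda_calc_xblock_count`/`cuda_calc_block_count` helpers, and the `threads_per_block` default of 1024 is an assumption for illustration.

```
#include <cstdint>

// Simplified stand-in for FBGEMM's block-count helpers: ceiling division.
constexpr int64_t ceil_div(int64_t n, int64_t d) {
  return (n + d - 1) / d;
}

struct LaunchConfig {
  int64_t block_x, block_y;  // blockDim
  int64_t grid_x, grid_y;    // gridDim
};

// Sketch of the launch configuration used by the vectorized FP8 quantize
// kernel: blockDim.x covers rows, blockDim.y covers column-vectors.
LaunchConfig fp8_vectorized_launch_config(
    int64_t nrows, int64_t ncols, int vec_size, int64_t threads_per_block = 1024) {
  const int64_t block_y = ncols > vec_size ? 64 : 1;    // column-vectors per block
  const int64_t block_x = threads_per_block / block_y;  // rows per block
  const int64_t grid_x = ceil_div(nrows, block_x);      // blocks over rows
  // The reland fix in a nutshell: round *up* over column-vectors so the
  // grid is large enough and no tail columns are dropped.
  const int64_t grid_y = ceil_div(ceil_div(ncols, vec_size), block_y);
  return {block_x, block_y, grid_x, grid_y};
}
```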
1 parent a0dd77b commit a1d116e

File tree: 4 files changed (+191, −37 lines)


fbgemm_gpu/FbgemmGpu.cmake

Lines changed: 1 addition & 0 deletions

@@ -184,5 +184,6 @@ gpu_cpp_library(
     fbgemm_gpu_tbe_cache
     fbgemm_gpu_tbe_optimizers
     fbgemm_gpu_tbe_utils
+    fbgemm_gpu_config
   DESTINATION
     fbgemm_gpu)

fbgemm_gpu/fbgemm_gpu/config/feature_list.py

Lines changed: 3 additions & 0 deletions

@@ -60,6 +60,9 @@ def foo():
     # Enable bounds_check_indices_v2
     BOUNDS_CHECK_INDICES_V2 = auto()
 
+    # disable fp8 quant vectorization
+    DISABLE_FP8_QUANT_VECTORIZATION = auto()
+
     # Enable TBE input parameters extraction
     TBE_REPORT_INPUT_PARAMS = auto()

fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h

Lines changed: 8 additions & 7 deletions

@@ -55,13 +55,14 @@ namespace fbgemm_gpu::config {
 /// UI.
 ///
 /// For OSS: The environment variable will be evaluated as f"FBGEMM_{ENUM}"
-#define ENUMERATE_ALL_FEATURE_FLAGS   \
-  X(TBE_V2)                           \
-  X(TBE_ENSEMBLE_ROWWISE_ADAGRAD)     \
-  X(TBE_ANNOTATE_KINETO_TRACE)        \
-  X(TBE_ROCM_INFERENCE_PACKED_BAGS)   \
-  X(TBE_ROCM_HIP_BACKWARD_KERNEL)     \
-  X(BOUNDS_CHECK_INDICES_V2)          \
+#define ENUMERATE_ALL_FEATURE_FLAGS     \
+  X(TBE_V2)                             \
+  X(TBE_ENSEMBLE_ROWWISE_ADAGRAD)       \
+  X(TBE_ANNOTATE_KINETO_TRACE)          \
+  X(TBE_ROCM_INFERENCE_PACKED_BAGS)     \
+  X(TBE_ROCM_HIP_BACKWARD_KERNEL)       \
+  X(BOUNDS_CHECK_INDICES_V2)            \
+  X(DISABLE_FP8_QUANT_VECTORIZATION)    \
   X(TBE_REPORT_INPUT_PARAMS)
 // X(EXAMPLE_FEATURE_FLAG)
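The flag list above is an X-macro: each entry both defines an enum value and drives any per-flag code generation, and the header comment notes that in OSS the gate maps to an environment variable named FBGEMM_<FLAG>. The sketch below is a rough, hypothetical illustration of that pattern, not FBGEMM's actual implementation; all names prefixed with `Example`/`example_` and the `EXAMPLE_FEATURE_FLAGS` macro are placeholders.

```
#include <cstdlib>
#include <string>

// Hypothetical two-entry flag list, mimicking ENUMERATE_ALL_FEATURE_FLAGS.
#define EXAMPLE_FEATURE_FLAGS         \
  X(BOUNDS_CHECK_INDICES_V2)          \
  X(DISABLE_FP8_QUANT_VECTORIZATION)

// Expand the list into an enum.
enum class ExampleFeatureGateName {
#define X(flag) flag,
  EXAMPLE_FEATURE_FLAGS
#undef X
};

// Expand the same list into a name lookup.
inline const char* example_flag_name(ExampleFeatureGateName f) {
  switch (f) {
#define X(flag)                      \
  case ExampleFeatureGateName::flag: \
    return #flag;
    EXAMPLE_FEATURE_FLAGS
#undef X
  }
  return "";
}

// OSS-style gate: e.g. FBGEMM_DISABLE_FP8_QUANT_VECTORIZATION=1 would
// disable the new vectorized kernel path.
inline bool example_is_feature_enabled(ExampleFeatureGateName f) {
  const std::string env_name = std::string("FBGEMM_") + example_flag_name(f);
  const char* v = std::getenv(env_name.c_str());
  return v != nullptr && std::string(v) == "1";
}
```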

fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu

Lines changed: 179 additions & 30 deletions

@@ -7,6 +7,7 @@
  */
 
 #include "common.cuh"
+#include "fbgemm_gpu/config/feature_gates.h"
 
 using Tensor = at::Tensor;
 
@@ -157,6 +158,125 @@ __global__ inline void _compute_FP8_quantize_cuda_kernel(
   }
 }
 
+template <typename scalar_t>
+struct VectorSizeTraits {
+  // Default to 4 elements for most types (16 bytes for float)
+  static constexpr int value = 4;
+};
+
+// Specialization for half (float16)
+template <>
+struct VectorSizeTraits<c10::Half> {
+  // 8 elements for half precision (16 bytes total)
+  static constexpr int value = 8;
+};
+
+// Specialization for __nv_bfloat16
+template <>
+struct VectorSizeTraits<c10::BFloat16> {
+  // 8 elements for bfloat16 precision (16 bytes total)
+  static constexpr int value = 8;
+};
+
+// aligned vector generates vectorized load/store on CUDA (copy-pasted from
+// MemoryAccess.cuh)
+template <typename scalar_t, int vec_size = VectorSizeTraits<scalar_t>::value>
+struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
+  scalar_t val[vec_size];
+};
+
+template <typename input_t>
+#ifndef USE_ROCM
+__global__ __attribute__((maxrregcount(32))) inline void
+#else
+__global__ inline void
+#endif
+_compute_FP8_quantize_cuda_vectorized_kernel(
+    const pta::PackedTensorAccessor64<input_t, 1, at::RestrictPtrTraits> input,
+    const int64_t nrows,
+    const int64_t ncols,
+    pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> output,
+    const bool forward) {
+  // Calculate global row index with 2D thread blocks
+  const int64_t gx = blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t thread_idx = blockIdx.y * blockDim.y + threadIdx.y;
+  static constexpr int vec_size = VectorSizeTraits<input_t>::value;
+  // Early return if row is out of bounds
+  if (gx >= nrows || (thread_idx * vec_size) >= ncols) {
+    return;
+  }
+
+  int ebit = forward ? 4 : 5;
+  int bias = forward ? 15 : 31;
+  float max_pos = forward ? 0.9375 : 0.875;
+
+  // Calculate output width
+  const auto ncols_aligned = (ncols + 4 - 1) / 4 * 4;
+  const auto output_columns = ncols_aligned + 2 * sizeof(float);
+
+  // Calculate base offsets for the current row
+  const int64_t input_row_offset = gx * ncols;
+  const int64_t output_row_offset = gx * output_columns;
+
+  // Calculate the position where the scale values are stored
+  const int64_t scale_offset = output_row_offset + ncols_aligned;
+  const float scale_value = reinterpret_cast<float*>(&output[scale_offset])[0];
+
+  const int64_t vector_blocks = ncols / vec_size;
+
+  using vec_t = aligned_vector<input_t, vec_size>;
+  using vec_i = aligned_vector<uint8_t, vec_size>;
+
+  const int64_t col_idx = thread_idx * vec_size;
+
+  // The if/else here guarantees the kernel works for aligned/misaligned
+  // cases. When ncols is not a multiple of vec_size, we can't dereference
+  // the pointer, so we access elements one by one; this triggers multiple
+  // trips to global memory but is still faster than the original kernel.
+  if ((col_idx + (vec_size - 1) < ncols) && ((ncols % vec_size) == 0)) {
+    // Load vec_size elements - handle both aligned and unaligned cases
+    // correctly
+    const vec_t input_row =
+        *reinterpret_cast<const vec_t*>(&input[input_row_offset + col_idx]);
+
+    vec_i* output_row =
+        reinterpret_cast<vec_i*>(&output[output_row_offset + col_idx]);
+
+    // Create temporary vector to enable vectorized store
+    vec_i temp_output;
+#pragma unroll
+    for (int i = 0; i < vec_size; ++i) {
+      temp_output.val[i] = float_to_hfp8(
+          to_float(input_row.val[i]) * scale_value, ebit, bias, max_pos);
+    }
+    *output_row = temp_output;
+  } else if ((col_idx + (vec_size - 1) < ncols)) {
+    const vec_t* input_row =
+        reinterpret_cast<const vec_t*>(&input[input_row_offset + col_idx]);
+
+    vec_i* output_row =
+        reinterpret_cast<vec_i*>(&output[output_row_offset + col_idx]);
+#pragma unroll
+    for (int i = 0; i < vec_size; ++i) {
+      output_row->val[i] = float_to_hfp8(
+          to_float(input_row->val[i]) * scale_value, ebit, bias, max_pos);
+    }
+  }
+
+  // 2. Process any remaining elements (less than vec_size) with scalar
+  // operations
+  const int64_t remaining_start = vector_blocks * vec_size;
+  for (int64_t col = remaining_start + threadIdx.y; col < ncols;
+       col += blockDim.y) {
+    output[output_row_offset + col] = float_to_hfp8(
+        to_float(input[input_row_offset + col]) * scale_value,
+        ebit,
+        bias,
+        max_pos);
+  }
+}
+
 template <typename output_t>
 __global__ inline void _FP8rowwise_to_float_cuda_kernel(
     pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> input,

@@ -247,13 +367,6 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
               forward);
         });
   } else {
-    // range_tensor is used to store the range for each embedding row.
-    // We save max_pos/max_val(rowwise) as row scale to quantize
-    // unlike INT8, FP8 does not have zero shift
-    // This will guarantee the numerical match but bring some perf
-    // regression.
-    auto range_tensor = at::empty({nrows}, input.options().dtype(at::kFloat));
-
     {
       // we need a blockDim.x that is a power of 2 no larger than the warp size
      // of 32

@@ -289,27 +402,63 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
     }
 
     {
-      const int blockDim_x =
-          std::min(ncols, static_cast<int64_t>(threads_per_block));
-      dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
-      const auto gridDim_x = cuda_calc_xblock_count(ncols, blockDim.x);
-      const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
-      dim3 gridDim(gridDim_x, gridDim_y);
-
-      FBGEMM_DISPATCH_FLOATING_TYPES(
-          input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] {
-            FBGEMM_LAUNCH_KERNEL(
-                (_compute_FP8_quantize_cuda_kernel<scalar_t>),
-                gridDim,
-                blockDim,
-                0,
-                at::cuda::getCurrentCUDAStream(),
-                PTA_B(input_1D, scalar_t, 1, 64),
-                nrows,
-                ncols,
-                PTA_B(output_1D, uint8_t, 1, 64),
-                forward);
-          });
+      const uintptr_t addr = reinterpret_cast<uintptr_t>(&input);
+
+      const static bool use_vectorization =
+          ((addr % 16) == 0) &&
+          !config::is_feature_enabled(
+              config::FeatureGateName::DISABLE_FP8_QUANT_VECTORIZATION);
+
+      const constexpr int vec_size = VectorSizeTraits<input_t>::value;
+      if (use_vectorization) {
+        const int block_y = 64;
+        const int blockDim_y = ncols > vec_size ? block_y : 1;
+
+        dim3 blockDim(threads_per_block / blockDim_y, blockDim_y);
+        const auto gridDim_x = cuda_calc_xblock_count(nrows, blockDim.x);
+        const auto gridDim_y = cuda_calc_block_count(
+            (ncols + vec_size - 1) / vec_size, blockDim.y);
+        dim3 gridDim(gridDim_x, gridDim_y);
+
+        FBGEMM_DISPATCH_FLOATING_TYPES(
+            input.scalar_type(),
+            "_compute_FP8_quantize_cuda_vectorized_kernel",
+            [&] {
+              FBGEMM_LAUNCH_KERNEL(
+                  (_compute_FP8_quantize_cuda_vectorized_kernel<scalar_t>),
+                  gridDim,
+                  blockDim,
+                  0,
+                  at::cuda::getCurrentCUDAStream(),
+                  PTA_B(input_1D, scalar_t, 1, 64),
+                  nrows,
+                  ncols,
+                  PTA_B(output_1D, uint8_t, 1, 64),
+                  forward);
+            });
+      } else {
+        const int blockDim_x =
+            std::min(ncols, static_cast<int64_t>(threads_per_block));
+        dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
+        const auto gridDim_x = cuda_calc_xblock_count(ncols, blockDim.x);
+        const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
+        dim3 gridDim(gridDim_x, gridDim_y);
+
+        FBGEMM_DISPATCH_FLOATING_TYPES(
+            input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] {
+              FBGEMM_LAUNCH_KERNEL(
+                  (_compute_FP8_quantize_cuda_kernel<scalar_t>),
+                  gridDim,
+                  blockDim,
+                  0,
+                  at::cuda::getCurrentCUDAStream(),
+                  PTA_B(input_1D, scalar_t, 1, 64),
+                  nrows,
+                  ncols,
+                  PTA_B(output_1D, uint8_t, 1, 64),
+                  forward);
+            });
+      }
     }
   }
 
@@ -358,8 +507,8 @@ Tensor _FP8rowwise_to_float_gpu_t(
   // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to
   // data residing in global memory compiles to a single global memory
   // instruction if and only if the size of the data type is 1, 2, 4, 8, or 16
-  // bytes and the data is naturally aligned (i.e., its address is a multiple of
-  // that size).
+  // bytes and the data is naturally aligned (i.e., its address is a multiple
+  // of that size).
   auto output_dims = input_sizes.vec();
   output_dims[last_dim] = output_columns;
   const auto output_sdtype = static_cast<SparseType>(output_dtype);
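As a side note on the natural-alignment comment retained above: the vectorized path works because a 16-byte, naturally aligned access compiles to a single global-memory instruction, which is exactly what the `aligned_vector` struct enables. The following is a minimal, self-contained sketch of that pattern under assumed names (`aligned_vector_demo`, `scale_vectorized_demo`); it is not part of this commit, and the caller is assumed to pass 16-byte-aligned pointers for the vector path, mirroring the `addr % 16 == 0` check in the dispatch code.

```
#include <cstdint>
#include <cuda_runtime.h>

// Because the struct is 16-byte aligned, a dereference can be emitted as a
// single 128-bit load/store (e.g. ld.global.v4.f32), provided the underlying
// address really is 16-byte aligned.
template <typename T, int N>
struct alignas(sizeof(T) * N) aligned_vector_demo {
  T val[N];
};

__global__ void scale_vectorized_demo(
    const float* in, float* out, float scale, int64_t n) {
  using vec4 = aligned_vector_demo<float, 4>;
  const int64_t i = (blockIdx.x * int64_t(blockDim.x) + threadIdx.x) * 4;
  if (i + 3 < n) {
    // One 128-bit load, four multiplies, one 128-bit store.
    vec4 v = *reinterpret_cast<const vec4*>(in + i);
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      v.val[k] *= scale;
    }
    *reinterpret_cast<vec4*>(out + i) = v;
  } else {
    // Scalar tail for the last (n % 4) elements, as in the kernel above.
    for (int64_t j = i; j < n; ++j) {
      out[j] = in[j] * scale;
    }
  }
}
```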
