7 | 7 | */
8 | 8 |
9 | 9 | #include "common.cuh"
| 10 | +#include "fbgemm_gpu/config/feature_gates.h" |
10 | 11 |
11 | 12 | using Tensor = at::Tensor;
12 | 13 |
@@ -157,6 +158,125 @@ __global__ inline void _compute_FP8_quantize_cuda_kernel(
157 | 158 | }
158 | 159 | }
159 | 160 |
| 161 | +template <typename scalar_t> |
| 162 | +struct VectorSizeTraits { |
| 163 | + // Default to 4 elements for most types (16 bytes for float) |
| 164 | + static constexpr int value = 4; |
| 165 | +}; |
| 166 | + |
| 167 | +// Specialization for half (float16) |
| 168 | +template <> |
| 169 | +struct VectorSizeTraits<c10::Half> { |
| 170 | + // 8 elements for half precision (16 bytes total) |
| 171 | + static constexpr int value = 8; |
| 172 | +}; |
| 173 | + |
| 174 | +// Specialization for __nv_bfloat16 |
| 175 | +template <> |
| 176 | +struct VectorSizeTraits<c10::BFloat16> { |
| 177 | + // 8 elements for bfloat16 precision (16 bytes total) |
| 178 | + static constexpr int value = 8; |
| 179 | +}; |
| 180 | + |
| 181 | +// aligned vector generates vectorized load/store on CUDA (copy-pasted from |
| 182 | +// MemoryAccess.cuh) |
| 183 | +template <typename scalar_t, int vec_size = VectorSizeTraits<scalar_t>::value> |
| 184 | +struct alignas(sizeof(scalar_t) * vec_size) aligned_vector { |
| 185 | + scalar_t val[vec_size]; |
| 186 | +}; |
| 187 | + |
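A quick sanity check of the packing above (a sketch, not part of the diff; it assumes C++17 and the c10 half/bfloat16 types already pulled in via common.cuh): each aligned_vector instantiation with its default vec_size occupies exactly one 16-byte, 16-byte-aligned packet, which is what lets the compiler emit a single 128-bit load or store.

// Sketch only: confirms the 16-byte packing assumption behind VectorSizeTraits.
static_assert(sizeof(aligned_vector<float>) == 16 && alignof(aligned_vector<float>) == 16);
static_assert(sizeof(aligned_vector<c10::Half>) == 16 && alignof(aligned_vector<c10::Half>) == 16);
static_assert(sizeof(aligned_vector<c10::BFloat16>) == 16 && alignof(aligned_vector<c10::BFloat16>) == 16);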
| 188 | +template <typename input_t> |
| 189 | +#ifndef USE_ROCM |
| 190 | +__global__ __attribute__((maxrregcount(32))) inline void |
| 191 | +#else |
| 192 | +__global__ inline void |
| 193 | +#endif |
| 194 | +_compute_FP8_quantize_cuda_vectorized_kernel( |
| 195 | + const pta::PackedTensorAccessor64<input_t, 1, at::RestrictPtrTraits> input, |
| 196 | + const int64_t nrows, |
| 197 | + const int64_t ncols, |
| 198 | + pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> output, |
| 199 | + const bool forward) { |
| 200 | + // Calculate the global row index and vector-chunk index from the 2D thread blocks |
| 201 | + const int64_t gx = blockIdx.x * blockDim.x + threadIdx.x; |
| 202 | + const int64_t thread_idx = blockIdx.y * blockDim.y + threadIdx.y; |
| 203 | + static constexpr int vec_size = VectorSizeTraits<input_t>::value; |
| 204 | + // Early return if the row or the vector chunk is out of bounds |
| 205 | + if (gx >= nrows || (thread_idx * vec_size) >= ncols) { |
| 206 | + return; |
| 207 | + } |
| 208 | + |
| 209 | + int ebit = forward ? 4 : 5; |
| 210 | + int bias = forward ? 15 : 31; |
| 211 | + float max_pos = forward ? 0.9375 : 0.875; |
| 212 | + |
| 213 | + // Calculate output width |
| 214 | + const auto ncols_aligned = (ncols + 4 - 1) / 4 * 4; |
| 215 | + const auto output_columns = ncols_aligned + 2 * sizeof(float); |
| 216 | + |
| 217 | + // Calculate base offsets for the current row |
| 218 | + const int64_t input_row_offset = gx * ncols; |
| 219 | + const int64_t output_row_offset = gx * output_columns; |
| 220 | + |
| 221 | + // Calculate the position where the scale values are stored |
| 222 | + const int64_t scale_offset = output_row_offset + ncols_aligned; |
| 223 | + const float scale_value = reinterpret_cast<float*>(&output[scale_offset])[0]; |
| 224 | + |
| 225 | + const int64_t vector_blocks = ncols / vec_size; |
| 226 | + |
| 227 | + using vec_t = aligned_vector<input_t, vec_size>; |
| 228 | + using vec_i = aligned_vector<uint8_t, vec_size>; |
| 229 | + |
| 230 | + const int64_t col_idx = thread_idx * vec_size; |
| 231 | + |
| 232 | + // The if/else below guarantees the kernel works for both aligned and |
| 233 | + // misaligned cases. When ncols is not a multiple of vec_size, we cannot |
| 234 | + // dereference the vector pointer, so we access the elements one by one; |
| 235 | + // this costs extra trips to global memory but still beats the original kernel. |
| 236 | + if ((col_idx + (vec_size - 1) < ncols) && ((ncols % vec_size) == 0)) { |
| 237 | + // Aligned case: load vec_size elements with a single vectorized load |
| 238 | + // and write them back with a single vectorized store |
| 239 | + const vec_t input_row = |
| 240 | + *reinterpret_cast<const vec_t*>(&input[input_row_offset + col_idx]); |
| 241 | + |
| 242 | + vec_i* output_row = |
| 243 | + reinterpret_cast<vec_i*>(&output[output_row_offset + col_idx]); |
| 244 | + |
| 245 | + // Create a temporary vector so the write below is a single vectorized store |
| 246 | + vec_i temp_output; |
| 247 | +#pragma unroll |
| 248 | + for (int i = 0; i < vec_size; ++i) { |
| 249 | + temp_output.val[i] = float_to_hfp8( |
| 250 | + to_float(input_row.val[i]) * scale_value, ebit, bias, max_pos); |
| 251 | + } |
| 252 | + *output_row = temp_output; |
| 253 | + } else if ((col_idx + (vec_size - 1) < ncols)) { |
| 254 | + // Misaligned case: go through the vector pointer one element at a time |
| 255 | + const vec_t* input_row = |
| 256 | + reinterpret_cast<const vec_t*>(&input[input_row_offset + col_idx]); |
| 257 | + |
| 258 | + vec_i* output_row = |
| 259 | + reinterpret_cast<vec_i*>(&output[output_row_offset + col_idx]); |
| 260 | +#pragma unroll |
| 261 | + for (int i = 0; i < vec_size; ++i) { |
| 262 | + output_row->val[i] = float_to_hfp8( |
| 263 | + to_float(input_row->val[i]) * scale_value, ebit, bias, max_pos); |
| 264 | + } |
| 265 | + } |
| 266 | + |
| 267 | + // Process any remaining tail elements (fewer than vec_size) with scalar |
| 268 | + // operations |
| 269 | + const int64_t remaining_start = vector_blocks * vec_size; |
| 270 | + for (int64_t col = remaining_start + threadIdx.y; col < ncols; |
| 271 | + col += blockDim.y) { |
| 272 | + output[output_row_offset + col] = float_to_hfp8( |
| 273 | + to_float(input[input_row_offset + col]) * scale_value, |
| 274 | + ebit, |
| 275 | + bias, |
| 276 | + max_pos); |
| 277 | + } |
| 278 | +} |
| 279 | + |
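For reference, both FP8 quantize kernels in this file index into the same per-row output layout: the quantized payload padded up to a multiple of 4 bytes, followed by two float metadata slots (the scale_value read above sits in the first of them). The helper below is hypothetical, a sketch that only restates the arithmetic already used in the kernel:

#include <cstdint>

// Hypothetical helper (illustration only): bytes occupied by one FP8 rowwise row.
// Mirrors ncols_aligned = (ncols + 4 - 1) / 4 * 4 plus the 2 * sizeof(float) tail.
constexpr std::int64_t fp8_rowwise_row_bytes(std::int64_t ncols) {
  const std::int64_t ncols_aligned = (ncols + 3) / 4 * 4;
  return ncols_aligned + 2 * static_cast<std::int64_t>(sizeof(float));
}
static_assert(fp8_rowwise_row_bytes(10) == 20);  // 12 payload bytes + 8 metadata bytes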
160 | 280 | template <typename output_t>
161 | 281 | __global__ inline void _FP8rowwise_to_float_cuda_kernel(
162 | 282 | pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> input,
@@ -247,13 +367,6 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
247 | 367 | forward);
248 | 368 | });
249 | 369 | } else {
250 | | - // range_tensor is used to store the range for each embedding row. |
251 | | - // We save max_pos/max_val(rowwise) as row scale to quantize |
252 | | - // unlike INT8, FP8 does not have zero shift |
253 | | - // This will guarantee the numerical match but bring some perf |
254 | | - // regression. |
255 | | - auto range_tensor = at::empty({nrows}, input.options().dtype(at::kFloat)); |
256 | | - |
257 | 370 | {
258 | 371 | // we need a blockDim.x that is a power of 2 no larger than the warp size
259 | 372 | // of 32
@@ -289,27 +402,63 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
289 | 402 | }
290 | 403 |
291 | 404 | {
292 | | - const int blockDim_x = |
293 | | - std::min(ncols, static_cast<int64_t>(threads_per_block)); |
294 | | - dim3 blockDim(blockDim_x, threads_per_block / blockDim_x); |
295 | | - const auto gridDim_x = cuda_calc_xblock_count(ncols, blockDim.x); |
296 | | - const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y); |
297 | | - dim3 gridDim(gridDim_x, gridDim_y); |
298 | | - |
299 | | - FBGEMM_DISPATCH_FLOATING_TYPES( |
300 | | - input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] { |
301 | | - FBGEMM_LAUNCH_KERNEL( |
302 | | - (_compute_FP8_quantize_cuda_kernel<scalar_t>), |
303 | | - gridDim, |
304 | | - blockDim, |
305 | | - 0, |
306 | | - at::cuda::getCurrentCUDAStream(), |
307 | | - PTA_B(input_1D, scalar_t, 1, 64), |
308 | | - nrows, |
309 | | - ncols, |
310 | | - PTA_B(output_1D, uint8_t, 1, 64), |
311 | | - forward); |
312 | | - }); |
| 405 | + const uintptr_t addr = reinterpret_cast<uintptr_t>(input.data_ptr()); |
| 406 | + |
| 407 | + const bool use_vectorization = |
| 408 | + ((addr % 16) == 0) && |
| 409 | + !config::is_feature_enabled( |
| 410 | + config::FeatureGateName::DISABLE_FP8_QUANT_VECTORIZATION); |
| 411 | + |
| 412 | + constexpr int vec_size = VectorSizeTraits<input_t>::value; |
| 413 | + if (use_vectorization) { |
| 414 | + const int block_y = 64; |
| 415 | + const int blockDim_y = ncols > vec_size ? block_y : 1; |
| 416 | + |
| 417 | + dim3 blockDim(threads_per_block / blockDim_y, blockDim_y); |
| 418 | + const auto gridDim_x = cuda_calc_xblock_count(nrows, blockDim.x); |
| 419 | + const auto gridDim_y = cuda_calc_block_count( |
| 420 | + (ncols + vec_size - 1) / vec_size, blockDim.y); |
| 421 | + dim3 gridDim(gridDim_x, gridDim_y); |
| 422 | + |
| 423 | + FBGEMM_DISPATCH_FLOATING_TYPES( |
| 424 | + input.scalar_type(), |
| 425 | + "_compute_FP8_quantize_cuda_vectorized_kernel", |
| 426 | + [&] { |
| 427 | + FBGEMM_LAUNCH_KERNEL( |
| 428 | + (_compute_FP8_quantize_cuda_vectorized_kernel<scalar_t>), |
| 429 | + gridDim, |
| 430 | + blockDim, |
| 431 | + 0, |
| 432 | + at::cuda::getCurrentCUDAStream(), |
| 433 | + PTA_B(input_1D, scalar_t, 1, 64), |
| 434 | + nrows, |
| 435 | + ncols, |
| 436 | + PTA_B(output_1D, uint8_t, 1, 64), |
| 437 | + forward); |
| 438 | + }); |
| 439 | + } else { |
| 440 | + const int blockDim_x = |
| 441 | + std::min(ncols, static_cast<int64_t>(threads_per_block)); |
| 442 | + dim3 blockDim(blockDim_x, threads_per_block / blockDim_x); |
| 443 | + const auto gridDim_x = cuda_calc_xblock_count(ncols, blockDim.x); |
| 444 | + const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y); |
| 445 | + dim3 gridDim(gridDim_x, gridDim_y); |
| 446 | + |
| 447 | + FBGEMM_DISPATCH_FLOATING_TYPES( |
| 448 | + input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] { |
| 449 | + FBGEMM_LAUNCH_KERNEL( |
| 450 | + (_compute_FP8_quantize_cuda_kernel<scalar_t>), |
| 451 | + gridDim, |
| 452 | + blockDim, |
| 453 | + 0, |
| 454 | + at::cuda::getCurrentCUDAStream(), |
| 455 | + PTA_B(input_1D, scalar_t, 1, 64), |
| 456 | + nrows, |
| 457 | + ncols, |
| 458 | + PTA_B(output_1D, uint8_t, 1, 64), |
| 459 | + forward); |
| 460 | + }); |
| 461 | + } |
313 | 462 | }
314 | 463 | }
315 | 464 |
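To make the thread mapping of the vectorized branch concrete, here is the launch arithmetic under an illustrative configuration (threads_per_block = 512 is assumed only for this example; the actual constant is defined elsewhere in this file): for a half-precision input with nrows = 1000 and ncols = 1024, vec_size is 8, so blockDim = (512 / 64, 64) = (8, 64), gridDim.x = ceil(1000 / 8) = 125, and gridDim.y = ceil((1024 / 8) / 64) = 2. Thread (x, y) then quantizes the eight columns starting at column y * 8 of row x, and the scalar tail loop does nothing because 1024 is a multiple of vec_size.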
@@ -358,8 +507,8 @@ Tensor _FP8rowwise_to_float_gpu_t(
358 | 507 | // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to
359 | 508 | // data residing in global memory compiles to a single global memory
360 | 509 | // instruction if and only if the size of the data type is 1, 2, 4, 8, or 16
361 | | - // bytes and the data is naturally aligned (i.e., its address is a multiple of |
362 | | - // that size). |
| 510 | + // bytes and the data is naturally aligned (i.e., its address is a multiple |
| 511 | + // of that size). |
363 | 512 | auto output_dims = input_sizes.vec();
364 | 513 | output_dims[last_dim] = output_columns;
365 | 514 | const auto output_sdtype = static_cast<SparseType>(output_dtype);