1 change: 1 addition & 0 deletions fbgemm_gpu/FbgemmGpu.cmake
@@ -184,5 +184,6 @@ gpu_cpp_library(
fbgemm_gpu_tbe_cache
fbgemm_gpu_tbe_optimizers
fbgemm_gpu_tbe_utils
fbgemm_gpu_config
DESTINATION
fbgemm_gpu)
3 changes: 3 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/config/feature_list.py
@@ -60,6 +60,9 @@ def foo():
# Enable bounds_check_indices_v2
BOUNDS_CHECK_INDICES_V2 = auto()

# Disable FP8 quant vectorization
DISABLE_FP8_QUANT_VECTORIZATION = auto()

# Enable TBE input parameters extraction
TBE_REPORT_INPUT_PARAMS = auto()

15 changes: 8 additions & 7 deletions fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h
@@ -55,13 +55,14 @@ namespace fbgemm_gpu::config {
/// UI.
///
/// For OSS: The environment variable will be evaluated as f"FBGEMM_{ENUM}"
#define ENUMERATE_ALL_FEATURE_FLAGS \
X(TBE_V2) \
X(TBE_ENSEMBLE_ROWWISE_ADAGRAD) \
X(TBE_ANNOTATE_KINETO_TRACE) \
X(TBE_ROCM_INFERENCE_PACKED_BAGS) \
X(TBE_ROCM_HIP_BACKWARD_KERNEL) \
X(BOUNDS_CHECK_INDICES_V2) \
#define ENUMERATE_ALL_FEATURE_FLAGS \
X(TBE_V2) \
X(TBE_ENSEMBLE_ROWWISE_ADAGRAD) \
X(TBE_ANNOTATE_KINETO_TRACE) \
X(TBE_ROCM_INFERENCE_PACKED_BAGS) \
X(TBE_ROCM_HIP_BACKWARD_KERNEL) \
X(BOUNDS_CHECK_INDICES_V2) \
X(DISABLE_FP8_QUANT_VECTORIZATION) \
X(TBE_REPORT_INPUT_PARAMS)
// X(EXAMPLE_FEATURE_FLAG)
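
In OSS builds, each flag above maps to an environment variable named FBGEMM_{ENUM}, so the new gate corresponds to FBGEMM_DISABLE_FP8_QUANT_VECTORIZATION. A minimal sketch of how a caller can query the gate through the existing feature-gate API (the helper name here is illustrative; the real call site added by this PR is in quantize_fp8_rowwise.cu below):

#include "fbgemm_gpu/config/feature_gates.h"

// Returns true when the vectorized FP8 quantization path may be used, i.e.
// the DISABLE_FP8_QUANT_VECTORIZATION kill switch is not enabled.
inline bool fp8_quant_vectorization_allowed() {
  return !fbgemm_gpu::config::is_feature_enabled(
      fbgemm_gpu::config::FeatureGateName::DISABLE_FP8_QUANT_VECTORIZATION);
}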

209 changes: 179 additions & 30 deletions fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu
@@ -7,6 +7,7 @@
*/

#include "common.cuh"
#include "fbgemm_gpu/config/feature_gates.h"

using Tensor = at::Tensor;

@@ -157,6 +158,125 @@ __global__ inline void _compute_FP8_quantize_cuda_kernel(
}
}

template <typename scalar_t>
struct VectorSizeTraits {
// Default to 4 elements for most types (16 bytes for float)
static constexpr int value = 4;
};

// Specialization for half (float16)
template <>
struct VectorSizeTraits<c10::Half> {
// 8 elements for half precision (16 bytes total)
static constexpr int value = 8;
};

// Specialization for __nv_bfloat16
template <>
struct VectorSizeTraits<c10::BFloat16> {
// 8 elements for bfloat16 precision (16 bytes total)
static constexpr int value = 8;
};

// aligned vector generates vectorized load/store on CUDA (copy-pasted from
// MemoryAccess.cuh)
template <typename scalar_t, int vec_size = VectorSizeTraits<scalar_t>::value>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
scalar_t val[vec_size];
};
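
// Illustrative compile-time check (a sketch, not required by the kernel):
// each specialization above is sized so that one aligned_vector of the input
// type is exactly 16 bytes, which is what enables the single vectorized load
// in the kernel below.
static_assert(sizeof(aligned_vector<float>) == 16, "4 x 4-byte float");
static_assert(sizeof(aligned_vector<c10::Half>) == 16, "8 x 2-byte half");
static_assert(
sizeof(aligned_vector<c10::BFloat16>) == 16, "8 x 2-byte bfloat16");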

template <typename input_t>
#ifndef USE_ROCM
__global__ __attribute__((maxrregcount(32))) inline void
#else
__global__ inline void
#endif
_compute_FP8_quantize_cuda_vectorized_kernel(
const pta::PackedTensorAccessor64<input_t, 1, at::RestrictPtrTraits> input,
const int64_t nrows,
const int64_t ncols,
pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> output,
const bool forward) {
// Compute the global row index (x) and the per-row vector index (y) from
// the 2D launch configuration
const int64_t gx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t thread_idx = blockIdx.y * blockDim.y + threadIdx.y;
static constexpr int vec_size = VectorSizeTraits<input_t>::value;
// Early return if the row or the vector's starting column is out of bounds
if (gx >= nrows || (thread_idx * vec_size) >= ncols) {
return;
}

int ebit = forward ? 4 : 5;
int bias = forward ? 15 : 31;
float max_pos = forward ? 0.9375 : 0.875;

// Calculate output width
const auto ncols_aligned = (ncols + 4 - 1) / 4 * 4;
const auto output_columns = ncols_aligned + 2 * sizeof(float);

// Calculate base offsets for the current row
const int64_t input_row_offset = gx * ncols;
const int64_t output_row_offset = gx * output_columns;

// Calculate the position where the scale values are stored
const int64_t scale_offset = output_row_offset + ncols_aligned;
const float scale_value = reinterpret_cast<float*>(&output[scale_offset])[0];

const int64_t vector_blocks = ncols / vec_size;

using vec_t = aligned_vector<input_t, vec_size>;
using vec_i = aligned_vector<uint8_t, vec_size>;

const int64_t col_idx = thread_idx * vec_size;

// The if/else below guarantees the kernel works for both the aligned and
// the misaligned case. When ncols is not a multiple of vec_size we cannot
// dereference the vector pointer, so we access elements one by one; this
// triggers multiple trips to global memory but is still faster than the
// original kernel.
if ((col_idx + (vec_size - 1) < ncols) && ((ncols % vec_size) == 0)) {
// Fully aligned case: load and store vec_size elements with a single
// vectorized access each
const vec_t input_row =
*reinterpret_cast<const vec_t*>(&input[input_row_offset + col_idx]);

vec_i* output_row =
reinterpret_cast<vec_i*>(&output[output_row_offset + col_idx]);

// Create a temporary vector so the store below compiles to a single
// vectorized write
vec_i temp_output;
#pragma unroll
for (int i = 0; i < vec_size; ++i) {
temp_output.val[i] = float_to_hfp8(
to_float(input_row.val[i]) * scale_value, ebit, bias, max_pos);
}
*output_row = temp_output;
} else if ((col_idx + (vec_size - 1) < ncols)) {
// ncols is not a multiple of vec_size, so the row base addresses may not be
// vector-aligned; go through the vector pointers but access the elements
// individually to avoid misaligned vector loads/stores
const vec_t* input_row =
reinterpret_cast<const vec_t*>(&input[input_row_offset + col_idx]);

vec_i* output_row =
reinterpret_cast<vec_i*>(&output[output_row_offset + col_idx]);
#pragma unroll
for (int i = 0; i < vec_size; ++i) {
output_row->val[i] = float_to_hfp8(
to_float(input_row->val[i]) * scale_value, ebit, bias, max_pos);
}
}

// Process any remaining tail elements (fewer than vec_size) with scalar
// operations
const int64_t remaining_start = vector_blocks * vec_size;
for (int64_t col = remaining_start + threadIdx.y; col < ncols;
col += blockDim.y) {
output[output_row_offset + col] = float_to_hfp8(
to_float(input[input_row_offset + col]) * scale_value,
ebit,
bias,
max_pos);
}
}
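
// Worked example of the three paths above (illustrative sizes): with float
// input (vec_size = 4) and ncols = 250, vector_blocks = 62, so threads whose
// col_idx falls in [0, 248) take the second branch (250 % 4 != 0) and access
// elements individually through the vector pointers, while the scalar tail
// loop handles columns 248 and 249. With ncols = 256 every such thread takes
// the first branch and issues one 16-byte load plus one vectorized store,
// and the tail loop does nothing.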

template <typename output_t>
__global__ inline void _FP8rowwise_to_float_cuda_kernel(
pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> input,
@@ -247,13 +367,6 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
forward);
});
} else {
// range_tensor is used to store the range for each embedding row.
// We save max_pos/max_val(rowwise) as row scale to quantize
// unlike INT8, FP8 does not have zero shift
// This will guarantee the numerical match but bring some perf
// regression.
auto range_tensor = at::empty({nrows}, input.options().dtype(at::kFloat));

{
// we need a blockDim.x that is a power of 2 no larger than the warp size
// of 32
@@ -289,27 +402,63 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
}

{
const int blockDim_x =
std::min(ncols, static_cast<int64_t>(threads_per_block));
dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
const auto gridDim_x = cuda_calc_xblock_count(ncols, blockDim.x);
const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
dim3 gridDim(gridDim_x, gridDim_y);

FBGEMM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] {
FBGEMM_LAUNCH_KERNEL(
(_compute_FP8_quantize_cuda_kernel<scalar_t>),
gridDim,
blockDim,
0,
at::cuda::getCurrentCUDAStream(),
PTA_B(input_1D, scalar_t, 1, 64),
nrows,
ncols,
PTA_B(output_1D, uint8_t, 1, 64),
forward);
});
const uintptr_t addr = reinterpret_cast<uintptr_t>(input.data_ptr());

// Use the vectorized kernel only when the input data is 16-byte aligned and
// the DISABLE_FP8_QUANT_VECTORIZATION kill switch is not enabled.
const bool use_vectorization =
((addr % 16) == 0) &&
!config::is_feature_enabled(
config::FeatureGateName::DISABLE_FP8_QUANT_VECTORIZATION);

constexpr int vec_size = VectorSizeTraits<input_t>::value;
if (use_vectorization) {
const int block_y = 64;
const int blockDim_y = ncols > vec_size ? block_y : 1;

dim3 blockDim(threads_per_block / blockDim_y, blockDim_y);
const auto gridDim_x = cuda_calc_xblock_count(nrows, blockDim.x);
const auto gridDim_y = cuda_calc_block_count(
(ncols + vec_size - 1) / vec_size, blockDim.y);
dim3 gridDim(gridDim_x, gridDim_y);

FBGEMM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"_compute_FP8_quantize_cuda_vectorized_kernel",
[&] {
FBGEMM_LAUNCH_KERNEL(
(_compute_FP8_quantize_cuda_vectorized_kernel<scalar_t>),
gridDim,
blockDim,
0,
at::cuda::getCurrentCUDAStream(),
PTA_B(input_1D, scalar_t, 1, 64),
nrows,
ncols,
PTA_B(output_1D, uint8_t, 1, 64),
forward);
});
} else {
const int blockDim_x =
std::min(ncols, static_cast<int64_t>(threads_per_block));
dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
const auto gridDim_x = cuda_calc_xblock_count(ncols, blockDim.x);
const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
dim3 gridDim(gridDim_x, gridDim_y);

FBGEMM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] {
FBGEMM_LAUNCH_KERNEL(
(_compute_FP8_quantize_cuda_kernel<scalar_t>),
gridDim,
blockDim,
0,
at::cuda::getCurrentCUDAStream(),
PTA_B(input_1D, scalar_t, 1, 64),
nrows,
ncols,
PTA_B(output_1D, uint8_t, 1, 64),
forward);
});
}
}
}

@@ -358,8 +507,8 @@ Tensor _FP8rowwise_to_float_gpu_t(
// to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to
// data residing in global memory compiles to a single global memory
// instruction if and only if the size of the data type is 1, 2, 4, 8, or 16
// bytes and the data is naturally aligned (i.e., its address is a multiple of
// that size).
// bytes and the data is naturally aligned (i.e., its address is a multiple
// of that size).
auto output_dims = input_sizes.vec();
output_dims[last_dim] = output_columns;
const auto output_sdtype = static_cast<SparseType>(output_dtype);