Skip to content

Quantization support for GroupedTensor: FP8 per-tensor#3102

Open
int-smart wants to merge 8 commits into
NVIDIA:mainfrom
int-smart:feature/fp8_quant
Open

Quantization support for GroupedTensor: FP8 per-tensor#3102
int-smart wants to merge 8 commits into
NVIDIA:mainfrom
int-smart:feature/fp8_quant

Conversation

@int-smart

Copy link
Copy Markdown
Contributor

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes #2449

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

Kernels: Extended unary_kernel and unary_grad_kernel in vectorized_pointwise.h to dynamically support per-tensor scale, scale_inv, and amax for grouped tensors.
Alignment: Aligned the random padding in test_common.cu to a constant 64 elements to guarantee matching element offsets between input and output grouped tensors.
Verification: Corrected the FP8 cast validation loop in test_cast_fp8_grouped.cu to compare raw quantized values directly, resolving false test failures caused by rounding errors.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

int-smart added 2 commits June 6, 2026 11:55
Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
@int-smart int-smart requested a review from Oleg-Goncharov as a code owner June 7, 2026 05:05
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 7, 2026
@greptile-apps

greptile-apps Bot commented Jun 7, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds FP8 per-tensor (delayed) scaling support to GroupedTensor for both forward quantize and dequantize paths. The implementation extends unary_kernel and unary_grad_kernel to accept per-tensor offset/scale/amax/scale_inv arrays and launches a 2-D grid (dim3(num_blocks, num_tensors)) so each blockIdx.y handles exactly one sub-tensor without a kernel-per-tensor loop.

  • Kernel changes (vectorized_pointwise.h): Added per-tensor scale/amax/scale_inv indexing, correct sub-tensor-scoped VectorizedLoader/Storer, and atomicMaxFloat into per-tensor device memory; previously flagged fixed-size block_max[64] array and in-loop reduce_max race are gone from the new implementation.
  • Dispatch wiring (dispatch/quantize.cuh, dispatch/dequantize.cuh): NVTE_DELAYED_TENSOR_SCALING cases added to both forward and backward grouped helpers.
  • Test infrastructure (test_common.cu/h): Constant 64-element padding replaces element-size-derived padding so input and output grouped buffers always share identical element offsets; per-tensor scale and amax device buffers added; new grouped cast and dequantize tests added.

Confidence Score: 4/5

The grouped FP8 quantize/dequantize kernels are functionally correct for the tested shapes; the main non-trivial concern is a performance footgun for large MoE models.

The previously flagged critical issues (fixed-size block_max[64] array and in-loop reduce_max shared-memory race) are no longer present in the current implementation — amax is accumulated via indexed atomicMaxFloat into device memory and reduce_max is called once per block. The remaining findings are non-correctness issues: the 2-D grid launches num_tensors× more x-blocks than each sub-tensor needs (excess blocks immediately idle out), and the dead find_tensor_id function. The 64-element padding fix in the test infrastructure correctly ensures vectorization-alignment is preserved across all sub-tensor boundary offsets.

transformer_engine/common/util/vectorized_pointwise.h — contains the over-provisioned grid launch and the unused find_tensor_id device function.

Important Files Changed

Filename Overview
transformer_engine/common/util/vectorized_pointwise.h Core kernel changes: unary_kernel and unary_grad_kernel extended with gridDim.y-based per-tensor dispatch, per-tensor scale/amax/scale_inv indexing, and correct size-scoped loaders. Dead find_tensor_id helper and grid over-provisioning are the notable P2 concerns.
transformer_engine/common/cast/fp8/quantize_fp8.cuh New group_quantize overloads (fwd + bwd) dispatch to VectorizedUnaryKernelLauncher with per-tensor scale/amax/scale_inv numel args. Backward overload derives N and offsets from input rather than grad, flagged in a previous review comment.
transformer_engine/common/cast/fp8/dequantize_fp8.cuh New group_dequantize function passes fp8 input's scale_inv as the read-side multiplier, passing scale=nullptr and amax=nullptr. Logic is clean and mirrors the grouped quantize structure.
transformer_engine/common/cast/dispatch/quantize.cuh Added NVTE_DELAYED_TENSOR_SCALING cases in both fwd and bwd helpers; dispatch wiring looks correct.
transformer_engine/common/cast/dispatch/dequantize.cuh Added NVTE_DELAYED_TENSOR_SCALING case to group_dequantize_helper; minimal and correct.
tests/cpp/test_common.cu Padding alignment changed to constant 64 elements; scale and amax device buffers added for NVTE_DELAYED_TENSOR_SCALING; RNG moved out of static to reset per call so input and output groups produce identical offsets.
tests/cpp/test_common.h Added scale field to GroupedBuffers struct to hold per-tensor forward scales; straightforward struct extension.
transformer_engine/common/transformer_engine.cpp Validation relaxed from exact-match to lower-bound on data.numel vs logical_shape to accommodate padding bytes in the grouped buffer allocation.
tests/cpp/operator/test_cast_fp8_grouped.cu New test: copies per-tensor scale_inv and amax back from grouped buffers and compares raw quantized FP8 values by round-tripping through OutputType.
tests/cpp/operator/test_dequantize_fp8_grouped.cu New test: validates grouped FP8 dequantize by computing val * scale_inv reference on CPU and comparing to GPU output across E4M3/E5M2 input types.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant Dispatch as dispatch/quantize.cuh
    participant FP8Q as fp8/quantize_fp8.cuh
    participant Launcher as VectorizedUnaryKernelLauncher
    participant Kernel as unary_kernel

    Caller->>Dispatch: nvte_group_quantize(input, output)
    Dispatch->>FP8Q: DELAYED_TENSOR_SCALING case
    FP8Q->>FP8Q: N from input.data.shape
    FP8Q->>Launcher: offsets, first_dims, last_dims, num_tensors
    Launcher->>Launcher: "grid = dim3(num_blocks, num_tensors)"
    Launcher->>Kernel: launch 2D grid

    loop "blockIdx.y = tensor_id"
        Kernel->>Kernel: "start = offsets[tensor_id]"
        Kernel->>Kernel: "size = first_dims x last_dims"
        Kernel->>Kernel: scale by scale[tensor_id]
        Kernel->>Kernel: write scale_inv[tensor_id] at global_idx 0
        Kernel->>Kernel: atomicMaxFloat to amax[tensor_id]
    end
Loading

Reviews (2): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

}
const int warp_id = threadIdx.x / THREADS_PER_WARP;

float block_max[64] = {0.0f};

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Fixed-size block_max array overflows when num_tensors > 64

block_max is indexed by tensor_id which is bounded by num_tensors (a runtime kernel parameter), yet the array is fixed at size 64. When a grouped tensor has more than 64 sub-tensors — common in large MoE models — any thread whose tensor_id >= 64 writes past the end of the array, corrupting other stack/local variables (including warp_id, loop variables, max, etc.) and producing silently wrong quantized outputs or GPU faults. The same defect exists in the unary_grad_kernel at the corresponding location.

Comment on lines +282 to +290
if (offsets != nullptr || num_tensors > 1) {
for (size_t t = 0; t < num_tensors; ++t) {
float t_max = block_max[t];
t_max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(t_max, warp_id);
if (threadIdx.x == 0 && t_max > 0.0f) {
size_t amax_idx = (amax_numel == num_tensors) ? t : 0;
atomicMaxFloat(&amax[amax_idx], t_max);
}
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Shared-memory race in per-tensor amax loop

reduce_max uses a __shared__ float staging[num_warps] array and a single __syncthreads() that ensures visibility before warp 0 reads staging. However, reduce_max does NOT call __syncthreads() after warp 0's read before returning. Calling it in a loop means that warp 1 (and other non-zero warps) can reach the staging[warpid] = my_warp_max write for iteration t+1 before warp 0 finishes reading staging[1] for iteration t. Without an explicit barrier between iterations, the CUDA memory model does not guarantee ordering, so warp 0 can read a partially updated staging[1] and compute an incorrect per-tensor amax. A __syncthreads() is needed after each call to reduce_max in this loop (and the identical loop in unary_grad_kernel).

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
@int-smart int-smart force-pushed the feature/fp8_quant branch from 0432d95 to d5fb0bf Compare June 9, 2026 06:40
pre-commit-ci Bot and others added 4 commits June 9, 2026 06:40
Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Quantization support for GroupedTensor: FP8 per-tensor

1 participant