diff --git a/docs/envvars.rst b/docs/envvars.rst index bd62ccac46..044a7f6a0d 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -303,13 +303,13 @@ Kernel Configuration :Type: ``str`` (``MAE`` or ``MSE``) :Default: ``MAE`` - :Description: Select the input-domain error metric used by NVFP4 4over6 map-to-4 versus map-to-6 candidate selection in the ``NVFP4BlockScaling`` recipe. + :Description: Select the error metric used by NVFP4 4over6 map-to-4 versus map-to-6 candidate selection in the ``NVFP4BlockScaling`` recipe. .. envvar:: NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH :Type: ``int`` (0 or 1) :Default: ``0`` - :Description: Allow the NVFP4 4over6 candidate error computation to use faster non-strict floating-point expressions. By default, 4over6 error comparison uses strict expressions; ``NVTE_USE_FAST_MATH`` does not control this error-comparison path. + :Description: Use the faster NVFP4 4over6 candidate error path that compares candidates in the E4M3-scaled domain after the E2M1 x E4M3 product is rounded to FP16. Error differences and accumulation remain FP32. By default, 4over6 error comparison uses the original input-domain path; ``NVTE_USE_FAST_MATH`` does not control this error-comparison path. Torch Compilation and Fusion ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 5bb92f70dc..ea60bd3837 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -2,6 +2,10 @@ # # See LICENSE for license information. 
# Environment variable toggled around the native quantizer call by
# ``nvfp4_4over6_err_fast_math``.
_NVFP4_4OVER6_ERR_FAST_MATH_ENV = "NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"


@dataclass(frozen=True)
class NVFP44Over6TestConfig:
    """One NVFP4 quantization scenario used to parametrize the exactness tests."""

    # Pytest ID shown for this scenario (read by ``ids=lambda config: config.id``).
    id: str
    # Forwarded as ``use_4over6`` / ``nvfp4_use_4over6`` to the quantizers under test.
    use_4over6: bool = True
    # Forwarded as ``nvfp4_e4m3_max``.
    e4m3_max: int = 448
    # Forwarded as ``nvfp4_4over6_err_mode``.
    err_mode: str = "MAE"
    # Selects the env-var value for the native path and is forwarded as
    # ``nvfp4_4over6_err_use_fast_math`` to the reference quantizer.
    err_use_fast_math: bool = False


# The plain NVFP4 baseline followed by the full
# (error mode) x (E4M3 max) x (exact | err-fast) 4over6 matrix.  Generating the
# matrix keeps every ID in sync with its fields and guarantees no combination
# is dropped when a new axis value is added.
NVFP4_4OVER6_CONFIGS = [
    NVFP44Over6TestConfig(id="nvfp4", use_4over6=False),
    *(
        NVFP44Over6TestConfig(
            id=(
                f"4over6-{err_mode.lower()}-e4m3-{e4m3_max}-"
                f"{'err-fast' if err_use_fast_math else 'exact'}"
            ),
            e4m3_max=e4m3_max,
            err_mode=err_mode,
            err_use_fast_math=err_use_fast_math,
        )
        for err_mode in ("MAE", "MSE")
        for e4m3_max in (448, 256)
        for err_use_fast_math in (False, True)
    ),
]


@contextmanager
def nvfp4_4over6_err_fast_math(enabled: bool):
    """Temporarily set ``NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH`` to "1" or "0".

    The previous value is restored on exit (or the variable is removed again if
    it was unset before), including when the managed body raises.

    Args:
        enabled: ``True`` writes "1", ``False`` writes "0".
    """
    previous = os.environ.get(_NVFP4_4OVER6_ERR_FAST_MATH_ENV)
    os.environ[_NVFP4_4OVER6_ERR_FAST_MATH_ENV] = "1" if enabled else "0"
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(_NVFP4_4OVER6_ERR_FAST_MATH_ENV, None)
        else:
            os.environ[_NVFP4_4OVER6_ERR_FAST_MATH_ENV] = previous
nvfp4_4over6_err_use_fast_math: bool = False, ) -> None: if nvfp4_e4m3_max != 448 and not use_4over6: pytest.skip("E4M3 max 256 is only meaningful for 4over6") @@ -87,13 +145,24 @@ def check_quantization_nvfp4_versus_reference( nvfp4_e4m3_max=nvfp4_e4m3_max, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) - if use_cpp_allocator: - x_nvfp4_sut = nvfp4_quantizer(x) + + if use_4over6: + with nvfp4_4over6_err_fast_math(nvfp4_4over6_err_use_fast_math): + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) else: - x_nvfp4_sut = nvfp4_quantizer.make_empty( - (M, N), dtype=x_dtype, device=device, requires_grad=False - ) - x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) # Extract data from NVFP4Tensor assert x_nvfp4_sut._rowwise_data is not None @@ -122,6 +191,7 @@ def check_quantization_nvfp4_versus_reference( nvfp4_use_4over6=use_4over6, nvfp4_e4m3_max=nvfp4_e4m3_max, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + nvfp4_4over6_err_use_fast_math=nvfp4_4over6_err_use_fast_math, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -197,9 +267,11 @@ def check_quantization_nvfp4_versus_reference( "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) -@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) -@pytest.mark.parametrize("nvfp4_e4m3_max", [448, 256], ids=["e4m3_448", "e4m3_256"]) -@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) 
+@pytest.mark.parametrize( + "nvfp4_4over6_config", + NVFP4_4OVER6_CONFIGS, + ids=lambda config: config.id, +) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -209,9 +281,7 @@ def test_quantization_block_tiling_versus_reference( use_cpp_allocator: bool, with_2d_quantization: bool, row_scaled_nvfp4: bool, - use_4over6: bool, - nvfp4_e4m3_max: int, - nvfp4_4over6_err_mode: str, + nvfp4_4over6_config: NVFP44Over6TestConfig, ) -> None: check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, @@ -222,9 +292,10 @@ def test_quantization_block_tiling_versus_reference( use_cpp_allocator=use_cpp_allocator, with_2d_quantization=with_2d_quantization, row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, - nvfp4_e4m3_max=nvfp4_e4m3_max, - nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + use_4over6=nvfp4_4over6_config.use_4over6, + nvfp4_e4m3_max=nvfp4_4over6_config.e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_config.err_mode, + nvfp4_4over6_err_use_fast_math=nvfp4_4over6_config.err_use_fast_math, ) @@ -242,8 +313,11 @@ def test_quantization_block_tiling_versus_reference( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) -@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) -@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) +@pytest.mark.parametrize( + "nvfp4_4over6_config", + NVFP4_4OVER6_CONFIGS, + ids=lambda config: config.id, +) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, M: int, @@ -252,11 +326,10 @@ def test_nvfp4_quantization_extrema_versus_reference( return_transpose: bool, use_cpp_allocator: bool, row_scaled_nvfp4: bool, - use_4over6: bool, - nvfp4_4over6_err_mode: str, + nvfp4_4over6_config: NVFP44Over6TestConfig, ): maybe_skip_row_scaled_unsupported_quantization( - row_scaled_nvfp4, 
return_transpose, use_4over6=use_4over6 + row_scaled_nvfp4, return_transpose, use_4over6=nvfp4_4over6_config.use_4over6 ) te_dtype = tex.DType.kFloat4E2M1 @@ -280,17 +353,28 @@ def test_nvfp4_quantization_extrema_versus_reference( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, - nvfp4_use_4over6=use_4over6, - nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + nvfp4_use_4over6=nvfp4_4over6_config.use_4over6, + nvfp4_e4m3_max=nvfp4_4over6_config.e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_config.err_mode, ) - if use_cpp_allocator: - x_nvfp4_sut = nvfp4_quantizer(x) + if nvfp4_4over6_config.use_4over6: + with nvfp4_4over6_err_fast_math(nvfp4_4over6_config.err_use_fast_math): + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) else: - x_nvfp4_sut = nvfp4_quantizer.make_empty( - (M, N), dtype=x_dtype, device=device, requires_grad=False - ) - x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) assert x_nvfp4_sut._rowwise_data is not None qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) @@ -313,8 +397,10 @@ def test_nvfp4_quantization_extrema_versus_reference( eps=0.0, quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, - nvfp4_use_4over6=use_4over6, - nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + nvfp4_use_4over6=nvfp4_4over6_config.use_4over6, + nvfp4_e4m3_max=nvfp4_4over6_config.e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_config.err_mode, + nvfp4_4over6_err_use_fast_math=nvfp4_4over6_config.err_use_fast_math, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -359,8 +445,11 @@ def 
test_nvfp4_quantization_extrema_versus_reference( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) -@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) -@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) +@pytest.mark.parametrize( + "nvfp4_4over6_config", + NVFP4_4OVER6_CONFIGS, + ids=lambda config: config.id, +) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, M: int, @@ -368,8 +457,7 @@ def test_nvfp4_quantization_boundary_values( return_transpose: bool, use_cpp_allocator: bool, row_scaled_nvfp4: bool, - use_4over6: bool, - nvfp4_4over6_err_mode: str, + nvfp4_4over6_config: NVFP44Over6TestConfig, ): """ Stress rounding/threshold behavior by placing values just below/above @@ -377,7 +465,7 @@ def test_nvfp4_quantization_boundary_values( Validates native vs reference byte-for-byte and scale parity. 
""" maybe_skip_row_scaled_unsupported_quantization( - row_scaled_nvfp4, return_transpose, use_4over6=use_4over6 + row_scaled_nvfp4, return_transpose, use_4over6=nvfp4_4over6_config.use_4over6 ) te_dtype = tex.DType.kFloat4E2M1 @@ -410,17 +498,28 @@ def test_nvfp4_quantization_boundary_values( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, - nvfp4_use_4over6=use_4over6, - nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + nvfp4_use_4over6=nvfp4_4over6_config.use_4over6, + nvfp4_e4m3_max=nvfp4_4over6_config.e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_config.err_mode, ) - if use_cpp_allocator: - x_nvfp4_sut = nvfp4_quantizer(x) + if nvfp4_4over6_config.use_4over6: + with nvfp4_4over6_err_fast_math(nvfp4_4over6_config.err_use_fast_math): + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) else: - x_nvfp4_sut = nvfp4_quantizer.make_empty( - (M, N), dtype=x_dtype, device=device, requires_grad=False - ) - x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) assert x_nvfp4_sut._rowwise_data is not None qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) @@ -443,8 +542,10 @@ def test_nvfp4_quantization_boundary_values( eps=0.0, quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, - nvfp4_use_4over6=use_4over6, - nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + nvfp4_use_4over6=nvfp4_4over6_config.use_4over6, + nvfp4_e4m3_max=nvfp4_4over6_config.e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_config.err_mode, + nvfp4_4over6_err_use_fast_math=nvfp4_4over6_config.err_use_fast_math, ) x_nvfp4_ref = 
ref_quantizer.quantize(x) @@ -489,8 +590,11 @@ def test_nvfp4_quantization_boundary_values( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) -@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) -@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) +@pytest.mark.parametrize( + "nvfp4_4over6_config", + NVFP4_4OVER6_CONFIGS, + ids=lambda config: config.id, +) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, M: int, @@ -498,11 +602,10 @@ def test_nvfp4_quantization_noncontiguous_inputs( return_transpose: bool, use_cpp_allocator: bool, row_scaled_nvfp4: bool, - use_4over6: bool, - nvfp4_4over6_err_mode: str, + nvfp4_4over6_config: NVFP44Over6TestConfig, ): maybe_skip_row_scaled_unsupported_quantization( - row_scaled_nvfp4, return_transpose, use_4over6=use_4over6 + row_scaled_nvfp4, return_transpose, use_4over6=nvfp4_4over6_config.use_4over6 ) te_dtype = tex.DType.kFloat4E2M1 @@ -526,17 +629,28 @@ def test_nvfp4_quantization_noncontiguous_inputs( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, - nvfp4_use_4over6=use_4over6, - nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + nvfp4_use_4over6=nvfp4_4over6_config.use_4over6, + nvfp4_e4m3_max=nvfp4_4over6_config.e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_config.err_mode, ) - if use_cpp_allocator: - x_nvfp4_sut = nvfp4_quantizer(x_nc) + if nvfp4_4over6_config.use_4over6: + with nvfp4_4over6_err_fast_math(nvfp4_4over6_config.err_use_fast_math): + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x_nc) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + x_nc.shape, dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x_nc, x_nvfp4_sut) else: - x_nvfp4_sut = nvfp4_quantizer.make_empty( - x_nc.shape, dtype=x_dtype, device=device, 
requires_grad=False - ) - x_nvfp4_sut = nvfp4_quantizer.update_quantized(x_nc, x_nvfp4_sut) + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x_nc) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + x_nc.shape, dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x_nc, x_nvfp4_sut) assert x_nvfp4_sut._rowwise_data is not None qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) @@ -559,8 +673,10 @@ def test_nvfp4_quantization_noncontiguous_inputs( eps=0.0, quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, - nvfp4_use_4over6=use_4over6, - nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + nvfp4_use_4over6=nvfp4_4over6_config.use_4over6, + nvfp4_e4m3_max=nvfp4_4over6_config.e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_config.err_mode, + nvfp4_4over6_err_use_fast_math=nvfp4_4over6_config.err_use_fast_math, ) x_nvfp4_ref = ref_quantizer.quantize(x_nc) diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index b6057370dc..50776a3ed6 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -102,6 +102,12 @@ struct ScalePair { nvfp4_scale_t map6; float inv_map4; float inv_map6; + float global_encode_scale; +}; + +struct FP16ErrorScalePair { + uint32_t map4; + uint32_t map6; }; template @@ -116,18 +122,6 @@ __device__ __forceinline__ float compute_error_rn(const float diff) { } } -template -__device__ __forceinline__ float compute_error(const float diff) { - if constexpr (kMode == kNVTENVFP44Over6MinMSE) { - return diff * diff; - } else if constexpr (kMode == kNVTENVFP44Over6MinMAE) { - return fabsf(diff); - } else { - NVTE_DEVICE_ERROR("Unsupported NVFP4 4over6 mode."); - return fabsf(diff); - } -} - template __device__ __forceinline__ ScalePair compute_scale_pair(const float block_amax, const float global_amax) { @@ -147,6 
+141,7 @@ __device__ __forceinline__ ScalePair compute_scale_pair(const float block_amax, fminf(1.0f / (static_cast(scales.map4) * S_dec), detail::TypeExtrema::max); scales.inv_map6 = fminf(1.0f / (static_cast(scales.map6) * S_dec), detail::TypeExtrema::max); + scales.global_encode_scale = S_enc; return scales; } @@ -200,26 +195,73 @@ __device__ __forceinline__ void accumulate_dequant_error(const uint32_t dequant_ constexpr float fp8_max = static_cast(E4M3_MAX); constexpr float err_denom = fp4_max * fp8_max; const uint16_t half_bits = (dequant_bits >> SHIFT) & 0xFFFF; + const float dequant = __half2float(__ushort_as_half(half_bits)); + const float val = __fdiv_rn(__fmul_rn(__fmul_rn(dequant, sf), global_amax), err_denom); + const float diff = __fsub_rn(val, x); + *err = __fadd_rn(*err, compute_error_rn(diff)); +} - if constexpr (Cfg::err_use_fast_math) { - const float dequant = __half2float(__ushort_as_half(half_bits)); - const float val = dequant * sf * global_amax / err_denom; - const float diff = val - x; - *err += compute_error(diff); - } else { - const float dequant = __half2float(__ushort_as_half(half_bits)); - const float val = __fdiv_rn(__fmul_rn(__fmul_rn(dequant, sf), global_amax), err_denom); - const float diff = __fsub_rn(val, x); - *err = __fadd_rn(*err, compute_error_rn(diff)); - } +__device__ __forceinline__ uint8_t fp8_bits(const nvfp4_scale_t sf) { + return *reinterpret_cast(&sf); } -template -__device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error(const float (&x)[8], - const float block_scale_inverse, - const nvfp4_scale_t sf, - const float global_amax, +__device__ __forceinline__ FP16ErrorScalePair compute_fp16_error_scales(const ScalePair &scales) { + FP16ErrorScalePair result; + const uint32_t packed_scales = static_cast(fp8_bits(scales.map4)) | + (static_cast(fp8_bits(scales.map6)) << 8); + asm volatile( + "{\n" + ".reg .b16 fp8_pair;\n" + ".reg .b16 map4_h, map6_h;\n" + ".reg .b32 scale_h2;\n" + "cvt.u16.u32 fp8_pair, %2;\n" + 
"cvt.rn.f16x2.e4m3x2 scale_h2, fp8_pair;\n" + "mov.b32 {map4_h, map6_h}, scale_h2;\n" + "mov.b32 %0, {map4_h, map4_h};\n" + "mov.b32 %1, {map6_h, map6_h};\n" + "}" + : "=r"(result.map4), "=r"(result.map6) + : "r"(packed_scales)); + return result; +} + +__device__ __forceinline__ float2 f16x2_scaled_to_float2(const uint32_t q_h2, + const uint32_t scale_h2) { + float2 result; + asm volatile( + "{\n" + ".reg .b16 lo, hi;\n" + ".reg .b32 prod_h2;\n" + "mul.rn.f16x2 prod_h2, %2, %3;\n" + "mov.b32 {lo, hi}, prod_h2;\n" + "cvt.f32.f16 %0, lo;\n" + "cvt.f32.f16 %1, hi;\n" + "}" + : "=f"(result.x), "=f"(result.y) + : "r"(q_h2), "r"(scale_h2)); + return result; +} + +template +__device__ __forceinline__ void accumulate_fp16_scaled_error_pair(const uint32_t q_h2, + const float x0, const float x1, + const uint32_t scale_h2, + const float global_encode_scale, float *err) { + const float2 candidate = f16x2_scaled_to_float2(q_h2, scale_h2); + const float original0 = __fmul_rn(x0, global_encode_scale); + const float original1 = __fmul_rn(x1, global_encode_scale); + const float diff0 = __fsub_rn(candidate.x, original0); + const float diff1 = __fsub_rn(candidate.y, original1); + *err = __fadd_rn(*err, compute_error_rn(diff0)); + *err = __fadd_rn(*err, compute_error_rn(diff1)); +} + +template +__device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error( + const float (&x)[8], const float block_scale_inverse, const nvfp4_scale_t sf, + const uint32_t fp16_error_scale, const float global_amax, const float global_encode_scale, + float *err) { uint32_t out = 0; uint32_t out_dequant_1 = 0; uint32_t out_dequant_2 = 0; @@ -253,15 +295,26 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error(const float (& "Try recompiling with sm_XXXa instead of sm_XXX."); } - const float sf_float = static_cast(sf); - accumulate_dequant_error(out_dequant_1, x[0], sf_float, global_amax, err); - accumulate_dequant_error(out_dequant_1, x[1], sf_float, global_amax, err); - 
accumulate_dequant_error(out_dequant_2, x[2], sf_float, global_amax, err); - accumulate_dequant_error(out_dequant_2, x[3], sf_float, global_amax, err); - accumulate_dequant_error(out_dequant_3, x[4], sf_float, global_amax, err); - accumulate_dequant_error(out_dequant_3, x[5], sf_float, global_amax, err); - accumulate_dequant_error(out_dequant_4, x[6], sf_float, global_amax, err); - accumulate_dequant_error(out_dequant_4, x[7], sf_float, global_amax, err); + if constexpr (Cfg::err_use_fast_math) { + accumulate_fp16_scaled_error_pair(out_dequant_1, x[0], x[1], fp16_error_scale, + global_encode_scale, err); + accumulate_fp16_scaled_error_pair(out_dequant_2, x[2], x[3], fp16_error_scale, + global_encode_scale, err); + accumulate_fp16_scaled_error_pair(out_dequant_3, x[4], x[5], fp16_error_scale, + global_encode_scale, err); + accumulate_fp16_scaled_error_pair(out_dequant_4, x[6], x[7], fp16_error_scale, + global_encode_scale, err); + } else { + const float sf_float = static_cast(sf); + accumulate_dequant_error(out_dequant_1, x[0], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_1, x[1], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_2, x[2], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_2, x[3], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_3, x[4], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_3, x[5], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_4, x[6], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_4, x[7], sf_float, global_amax, err); + } return out; } @@ -272,14 +325,22 @@ __device__ __forceinline__ CandidatePair make_candidates(const float (&x0)[8], c CandidatePair candidates; candidates.map4.err = 0.0f; candidates.map6.err = 0.0f; + FP16ErrorScalePair fp16_error_scales{}; + if constexpr (Cfg::err_use_fast_math) { + fp16_error_scales = compute_fp16_error_scales(scales); + } candidates.map4.packed[0] = 
cvt_fp32_to_fp4_8x_with_error( - x0, scales.inv_map4, scales.map4, global_amax, &candidates.map4.err); + x0, scales.inv_map4, scales.map4, fp16_error_scales.map4, global_amax, + scales.global_encode_scale, &candidates.map4.err); candidates.map6.packed[0] = cvt_fp32_to_fp4_8x_with_error( - x0, scales.inv_map6, scales.map6, global_amax, &candidates.map6.err); + x0, scales.inv_map6, scales.map6, fp16_error_scales.map6, global_amax, + scales.global_encode_scale, &candidates.map6.err); candidates.map4.packed[1] = cvt_fp32_to_fp4_8x_with_error( - x1, scales.inv_map4, scales.map4, global_amax, &candidates.map4.err); + x1, scales.inv_map4, scales.map4, fp16_error_scales.map4, global_amax, + scales.global_encode_scale, &candidates.map4.err); candidates.map6.packed[1] = cvt_fp32_to_fp4_8x_with_error( - x1, scales.inv_map6, scales.map6, global_amax, &candidates.map6.err); + x1, scales.inv_map6, scales.map6, fp16_error_scales.map6, global_amax, + scales.global_encode_scale, &candidates.map6.err); return candidates; } diff --git a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index 5c23c76703..d09b95ace3 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -355,12 +355,11 @@ def __init__( nvfp4_use_4over6: bool = False, nvfp4_e4m3_max: int = 448, nvfp4_4over6_err_mode: str = "MAE", + nvfp4_4over6_err_use_fast_math: bool = False, with_rht: bool = False, with_random_sign_mask: bool = True, ): nvfp4_4over6_err_mode = nvfp4_4over6_err_mode.upper() - if nvfp4_4over6_err_mode not in ("MAE", "MSE"): - raise ValueError("nvfp4_4over6_err_mode must be 'MAE' or 'MSE'.") if row_scaled_nvfp4: if not rowwise: raise ValueError("Row-scaled NVFP4 reference quantization requires rowwise usage.") @@ -369,6 +368,8 @@ def __init__( "Row-scaled NVFP4 reference quantization does not support columnwise 
usage." ) if nvfp4_use_4over6: + if nvfp4_4over6_err_mode not in ("MAE", "MSE"): + raise ValueError(f"Unsupported NVFP4 4over6 error mode: {nvfp4_4over6_err_mode}.") if pow_2_scales: raise ValueError("4over6 is only supported for NVFP4 (non-pow2) mode.") if quant_tile_shape not in ((1, 16), (16, 16)): @@ -386,6 +387,7 @@ def __init__( if self.nvfp4_e4m3_max not in (448, 256): raise ValueError("nvfp4_e4m3_max must be 448 or 256.") self.nvfp4_4over6_err_mode = nvfp4_4over6_err_mode + self.nvfp4_4over6_err_use_fast_math = nvfp4_4over6_err_use_fast_math self.with_rht = with_rht self.with_random_sign_mask = with_random_sign_mask @@ -464,6 +466,64 @@ def _recover_swizzled_scales( result = torch.reshape(tmp, (rounded_m, rounded_n)) return result[:m, :scale_n] + @staticmethod + def _ref_nvfp4_4over6_fp16_candidate(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """Decode E2M1 x E4M3 with the kernel's FP16 product semantics.""" + q_float = q.to(torch.float32) + q_sign = (q_float < 0).to(torch.int32) + q_sig = (torch.abs(q_float) * 2).to(torch.int32) + + scale_code = scale.contiguous().view(torch.uint8).to(torch.int32) + scale_sign = scale_code >> 7 + scale_exp_field = (scale_code >> 3) & 0xF + scale_mantissa = scale_code & 0x7 + scale_sig = torch.where(scale_exp_field == 0, scale_mantissa, scale_mantissa + 8) + scale_exp2 = torch.where(scale_exp_field == 0, scale_exp_field - 9, scale_exp_field - 10) + + product_sign = q_sign ^ scale_sign + product_sig = q_sig * scale_sig + product_exp2 = scale_exp2 - 1 + + log2_sig = torch.zeros_like(product_sig) + for threshold in (2, 4, 8, 16, 32, 64, 128, 256): + log2_sig = log2_sig + (product_sig >= threshold).to(torch.int32) + + floor_exp = log2_sig + product_exp2 + normal_bits = ((floor_exp + 15) << 10) | ( + torch.bitwise_left_shift(product_sig, 10 - log2_sig) - 1024 + ) + subnormal_bits = torch.bitwise_left_shift(product_sig, product_exp2 + 24) + magnitude_bits = torch.where(floor_exp < -14, subnormal_bits, normal_bits) + 
prod_bits = (product_sign << 15) | magnitude_bits + prod_bits = torch.where(product_sig == 0, product_sign << 15, prod_bits) + prod_bits = torch.where( + (scale_code & 0x7F) == 0x7F, + torch.full_like(prod_bits, 0x7E00), + prod_bits, + ) + + sign_f32 = torch.where( + (prod_bits & 0x8000) != 0, + torch.tensor(-1.0, device=prod_bits.device, dtype=torch.float32), + torch.tensor(1.0, device=prod_bits.device, dtype=torch.float32), + ) + fp16_exp = (prod_bits >> 10) & 0x1F + fp16_frac = prod_bits & 0x3FF + normal_f32 = torch.ldexp((fp16_frac + 1024).to(torch.float32), fp16_exp - 25) + subnormal_f32 = torch.ldexp(fp16_frac.to(torch.float32), fp16_exp - 24) + return sign_f32 * torch.where(fp16_exp == 0, subnormal_f32, normal_f32) + + @staticmethod + def _sum_4over6_2d_error(err: torch.Tensor, tile_len_y: int) -> torch.Tensor: + """Reduce 16 row errors in the same tree order as the CUDA warp reduction.""" + assert tile_len_y == 16, "NVFP4 4over6 2D error reduction expects 16 rows." + rows = err.view(err.shape[0] // tile_len_y, tile_len_y, err.shape[1], 1) + rows = rows.squeeze(-1) + rows = rows[:, 0:8, :] + rows[:, 8:16, :] + rows = rows[:, 0:4, :] + rows[:, 4:8, :] + rows = rows[:, 0:2, :] + rows[:, 2:4, :] + return (rows[:, 0, :] + rows[:, 1, :]).unsqueeze(-1) + @staticmethod def _quantize_blockwise_4over6_reference( x: torch.Tensor, @@ -474,13 +534,15 @@ def _quantize_blockwise_4over6_reference( row_scaled_nvfp4: bool, tile_len_y: int, nvfp4_4over6_err_mode: str, + nvfp4_4over6_err_use_fast_math: bool, nvfp4_e4m3_max: int, ) -> Tuple[torch.Tensor, torch.Tensor]: """Quantize NVFP4 with 4over6 candidate selection. This mirrors the CUDA path: map-to-4 uses a 1.5x expanded E4M3 block scale, - the configured error is computed in the original input domain with the - selected global E4M3 denominator, and ties choose map-to-6. 
+ MAE/MSE compute error in the original input domain by default, the + fast-math error path computes error in the E4M3-scaled FP16 product + domain, and ties choose map-to-6. """ m, num_blocks, tile_len_x = x.shape n = num_blocks * tile_len_x @@ -527,41 +589,59 @@ def _quantize_blockwise_4over6_reference( qx_map4 = cast_to_fp4x2(clipped_x_map4) qx_map6 = cast_to_fp4x2(clipped_x_map6) + err_map4 = torch.zeros_like(vec_max) + err_map6 = torch.zeros_like(vec_max) fp4_map4 = cast_from_fp4x2(qx_map4, torch.float32).view(m, num_blocks, tile_len_x) fp4_map6 = cast_from_fp4x2(qx_map6, torch.float32).view(m, num_blocks, tile_len_x) - denom = FLOAT4_E2M1_MAX * GLOBAL_SCALE_E4M3_MAX - sf_map4 = decode_scale_map4.to(torch.float32).squeeze(-1) - sf_map6 = decode_scale_map6.to(torch.float32).squeeze(-1) - if row_scaled_nvfp4: - error_global_amax = global_amax.squeeze(-1) - else: - error_global_amax = global_amax x_float = x.to(torch.float32) - err_map4 = torch.zeros_like(vec_max) - err_map6 = torch.zeros_like(vec_max) - for idx in range(tile_len_x): - val_map4 = fp4_map4[:, :, idx] * sf_map4 - val_map4 = val_map4 * error_global_amax - val_map4 = val_map4 / denom - diff_map4 = val_map4 - x_float[:, :, idx] - if nvfp4_4over6_err_mode == "MSE": - err_map4 = err_map4 + (diff_map4 * diff_map4).unsqueeze(-1) - else: - err_map4 = err_map4 + torch.abs(diff_map4).unsqueeze(-1) - - val_map6 = fp4_map6[:, :, idx] * sf_map6 - val_map6 = val_map6 * error_global_amax - val_map6 = val_map6 / denom - diff_map6 = val_map6 - x_float[:, :, idx] - if nvfp4_4over6_err_mode == "MSE": - err_map6 = err_map6 + (diff_map6 * diff_map6).unsqueeze(-1) + if nvfp4_4over6_err_use_fast_math: + original_scaled = x_float * global_encode_scale + candidate_map4 = NVFP4QuantizerRef._ref_nvfp4_4over6_fp16_candidate( + fp4_map4, decode_scale_map4 + ) + candidate_map6 = NVFP4QuantizerRef._ref_nvfp4_4over6_fp16_candidate( + fp4_map6, decode_scale_map6 + ) + for idx in range(tile_len_x): + diff_map4 = candidate_map4[:, 
:, idx] - original_scaled[:, :, idx] + diff_map6 = candidate_map6[:, :, idx] - original_scaled[:, :, idx] + if nvfp4_4over6_err_mode == "MSE": + err_map4 = err_map4 + (diff_map4 * diff_map4).unsqueeze(-1) + err_map6 = err_map6 + (diff_map6 * diff_map6).unsqueeze(-1) + else: + err_map4 = err_map4 + torch.abs(diff_map4).unsqueeze(-1) + err_map6 = err_map6 + torch.abs(diff_map6).unsqueeze(-1) + else: + denom = FLOAT4_E2M1_MAX * GLOBAL_SCALE_E4M3_MAX + sf_map4 = decode_scale_map4.to(torch.float32).squeeze(-1) + sf_map6 = decode_scale_map6.to(torch.float32).squeeze(-1) + if row_scaled_nvfp4: + error_global_amax = global_amax.squeeze(-1) else: - err_map6 = err_map6 + torch.abs(diff_map6).unsqueeze(-1) + error_global_amax = global_amax + for idx in range(tile_len_x): + val_map4 = fp4_map4[:, :, idx] * sf_map4 + val_map4 = val_map4 * error_global_amax + val_map4 = val_map4 / denom + diff_map4 = val_map4 - x_float[:, :, idx] + if nvfp4_4over6_err_mode == "MSE": + err_map4 = err_map4 + (diff_map4 * diff_map4).unsqueeze(-1) + else: + err_map4 = err_map4 + torch.abs(diff_map4).unsqueeze(-1) + + val_map6 = fp4_map6[:, :, idx] * sf_map6 + val_map6 = val_map6 * error_global_amax + val_map6 = val_map6 / denom + diff_map6 = val_map6 - x_float[:, :, idx] + if nvfp4_4over6_err_mode == "MSE": + err_map6 = err_map6 + (diff_map6 * diff_map6).unsqueeze(-1) + else: + err_map6 = err_map6 + torch.abs(diff_map6).unsqueeze(-1) if tile_len_y == 1: pick_map4 = err_map4 < err_map6 else: - err_map4_blocks = err_map4.view(m // tile_len_y, tile_len_y, num_blocks, 1).sum(dim=1) - err_map6_blocks = err_map6.view(m // tile_len_y, tile_len_y, num_blocks, 1).sum(dim=1) + err_map4_blocks = NVFP4QuantizerRef._sum_4over6_2d_error(err_map4, tile_len_y) + err_map6_blocks = NVFP4QuantizerRef._sum_4over6_2d_error(err_map6, tile_len_y) pick_map4 = (err_map4_blocks < err_map6_blocks).repeat_interleave(tile_len_y, dim=0) qx = torch.where( pick_map4.expand(-1, -1, tile_len_x // 2), @@ -584,6 +664,7 @@ def 
_quantize_blockwise_reference( nvfp4_use_4over6: bool = False, nvfp4_e4m3_max: int = 448, nvfp4_4over6_err_mode: str = "MAE", + nvfp4_4over6_err_use_fast_math: bool = False, eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -653,7 +734,7 @@ def _quantize_blockwise_reference( global_decode_scale = torch.div(1.0, global_encode_scale) if nvfp4_use_4over6: # FourOverSix compares map-to-4 and map-to-6 candidates using - # the configured original input-domain error, while keeping TE-style FP4 + # the configured error mode, while keeping TE-style FP4 # quantization for each candidate. return cls._quantize_blockwise_4over6_reference( x, @@ -664,6 +745,7 @@ def _quantize_blockwise_reference( row_scaled_nvfp4, tile_len_y, nvfp4_4over6_err_mode, + nvfp4_4over6_err_use_fast_math, nvfp4_e4m3_max, ) @@ -830,6 +912,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ nvfp4_use_4over6=self.nvfp4_use_4over6, nvfp4_e4m3_max=self.nvfp4_e4m3_max, nvfp4_4over6_err_mode=self.nvfp4_4over6_err_mode, + nvfp4_4over6_err_use_fast_math=self.nvfp4_4over6_err_use_fast_math, eps=self.eps, ) if transpose_scales: @@ -856,6 +939,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ nvfp4_use_4over6=self.nvfp4_use_4over6, nvfp4_e4m3_max=self.nvfp4_e4m3_max, nvfp4_4over6_err_mode=self.nvfp4_4over6_err_mode, + nvfp4_4over6_err_use_fast_math=self.nvfp4_4over6_err_use_fast_math, eps=self.eps, )