Skip to content
4 changes: 2 additions & 2 deletions docs/envvars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,9 @@ Kernel Configuration

.. envvar:: NVTE_NVFP4_4OVER6_ERR_MODE

:Type: ``str`` (``MAE`` or ``MSE``)
:Type: ``str`` (``MAE``, ``MSE``, ``MAE_FP16``, or ``MSE_FP16``)
:Default: ``MAE``
:Description: Select the input-domain error metric used by NVFP4 4over6 map-to-4 versus map-to-6 candidate selection in the ``NVFP4BlockScaling`` recipe.
:Description: Select the error metric used by NVFP4 4over6 map-to-4 versus map-to-6 candidate selection in the ``NVFP4BlockScaling`` recipe. ``MAE`` and ``MSE`` compare dequantized candidates in the original input domain. ``MAE_FP16`` and ``MSE_FP16`` compare candidates in the E4M3-scaled domain after the E2M1 x E4M3 product is rounded to FP16.

.. envvar:: NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH

Expand Down
24 changes: 20 additions & 4 deletions tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,11 @@ def check_quantization_nvfp4_versus_reference(
@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"])
@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"])
@pytest.mark.parametrize("nvfp4_e4m3_max", [448, 256], ids=["e4m3_448", "e4m3_256"])
@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"])
@pytest.mark.parametrize(
"nvfp4_4over6_err_mode",
["MAE", "MSE", "MAE_FP16", "MSE_FP16"],
ids=["mae_err", "mse_err", "mae_fp16_err", "mse_fp16_err"],
)
def test_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
Expand Down Expand Up @@ -243,7 +247,11 @@ def test_quantization_block_tiling_versus_reference(
)
@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"])
@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"])
@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"])
@pytest.mark.parametrize(
"nvfp4_4over6_err_mode",
["MAE", "MSE", "MAE_FP16", "MSE_FP16"],
ids=["mae_err", "mse_err", "mae_fp16_err", "mse_fp16_err"],
)
def test_nvfp4_quantization_extrema_versus_reference(
x_dtype: torch.dtype,
M: int,
Expand Down Expand Up @@ -360,7 +368,11 @@ def test_nvfp4_quantization_extrema_versus_reference(
)
@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"])
@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"])
@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"])
@pytest.mark.parametrize(
"nvfp4_4over6_err_mode",
["MAE", "MSE", "MAE_FP16", "MSE_FP16"],
ids=["mae_err", "mse_err", "mae_fp16_err", "mse_fp16_err"],
)
def test_nvfp4_quantization_boundary_values(
x_dtype: torch.dtype,
M: int,
Expand Down Expand Up @@ -490,7 +502,11 @@ def test_nvfp4_quantization_boundary_values(
)
@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"])
@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"])
@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"])
@pytest.mark.parametrize(
"nvfp4_4over6_err_mode",
["MAE", "MSE", "MAE_FP16", "MSE_FP16"],
ids=["mae_err", "mse_err", "mae_fp16_err", "mse_fp16_err"],
)
def test_nvfp4_quantization_noncontiguous_inputs(
x_dtype: torch.dtype,
M: int,
Expand Down
6 changes: 5 additions & 1 deletion tests/pytorch/test_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,11 @@ def test_quantizer_update(self, module_class):
["none", "weights", "activations", "all"],
ids=["e4m3_448", "e4m3_256_weights", "e4m3_256_activations", "e4m3_256_all"],
)
@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"])
@pytest.mark.parametrize(
"nvfp4_4over6_err_mode",
["MAE", "MSE", "MAE_FP16", "MSE_FP16"],
ids=["mae_err", "mse_err", "mae_fp16_err", "mse_fp16_err"],
)
def test_nvfp4_row_scaled_quantizer_roles(
nvfp4_4over6, nvfp4_4over6_e4m3_use_256, nvfp4_4over6_err_mode
):
Expand Down
148 changes: 114 additions & 34 deletions transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,27 @@ namespace nvfp4 {

#if FP4_TYPE_SUPPORTED

#define TRANSFORMER_ENGINE_NVFP4_4OVER6_MODE_SWITCH(MODE, MODE_CONST, ...) \
switch (MODE) { \
case kNVTENVFP44Over6MinMAE: { \
constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMAE; \
{ __VA_ARGS__ } \
} break; \
case kNVTENVFP44Over6MinMSE: { \
constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMSE; \
{ __VA_ARGS__ } \
} break; \
default: { \
NVTE_ERROR("Unsupported NVFP4 4over6 mode."); \
} \
#define TRANSFORMER_ENGINE_NVFP4_4OVER6_MODE_SWITCH(MODE, MODE_CONST, ...) \
switch (MODE) { \
case kNVTENVFP44Over6MinMAE: { \
constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMAE; \
{ __VA_ARGS__ } \
} break; \
case kNVTENVFP44Over6MinMSE: { \
constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMSE; \
{ __VA_ARGS__ } \
} break; \
case kNVTENVFP44Over6MinMAEFP16: { \
constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMAEFP16; \
{ __VA_ARGS__ } \
} break; \
case kNVTENVFP44Over6MinMSEFP16: { \
constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMSEFP16; \
{ __VA_ARGS__ } \
} break; \
default: { \
NVTE_ERROR("Unsupported NVFP4 4over6 mode."); \
} \
}

#define TRANSFORMER_ENGINE_NVFP4_4OVER6_E4M3_MAX_SWITCH(E4M3_MAX_VALUE, E4M3_MAX_CONST, ...) \
Expand Down Expand Up @@ -102,13 +110,14 @@ struct ScalePair {
nvfp4_scale_t map6;
float inv_map4;
float inv_map6;
float global_encode_scale;
};

template <NVTENVFP44Over6Mode kMode>
__device__ __forceinline__ float compute_error_rn(const float diff) {
if constexpr (kMode == kNVTENVFP44Over6MinMSE) {
if constexpr (kMode == kNVTENVFP44Over6MinMSE || kMode == kNVTENVFP44Over6MinMSEFP16) {
return __fmul_rn(diff, diff);
} else if constexpr (kMode == kNVTENVFP44Over6MinMAE) {
} else if constexpr (kMode == kNVTENVFP44Over6MinMAE || kMode == kNVTENVFP44Over6MinMAEFP16) {
return fabsf(diff);
} else {
NVTE_DEVICE_ERROR("Unsupported NVFP4 4over6 mode.");
Expand All @@ -118,9 +127,9 @@ __device__ __forceinline__ float compute_error_rn(const float diff) {

template <NVTENVFP44Over6Mode kMode>
__device__ __forceinline__ float compute_error(const float diff) {
if constexpr (kMode == kNVTENVFP44Over6MinMSE) {
if constexpr (kMode == kNVTENVFP44Over6MinMSE || kMode == kNVTENVFP44Over6MinMSEFP16) {
return diff * diff;
} else if constexpr (kMode == kNVTENVFP44Over6MinMAE) {
} else if constexpr (kMode == kNVTENVFP44Over6MinMAE || kMode == kNVTENVFP44Over6MinMAEFP16) {
return fabsf(diff);
} else {
NVTE_DEVICE_ERROR("Unsupported NVFP4 4over6 mode.");
Expand All @@ -147,6 +156,7 @@ __device__ __forceinline__ ScalePair compute_scale_pair(const float block_amax,
fminf(1.0f / (static_cast<float>(scales.map4) * S_dec), detail::TypeExtrema<float>::max);
scales.inv_map6 =
fminf(1.0f / (static_cast<float>(scales.map6) * S_dec), detail::TypeExtrema<float>::max);
scales.global_encode_scale = S_enc;
return scales;
}

Expand Down Expand Up @@ -214,12 +224,67 @@ __device__ __forceinline__ void accumulate_dequant_error(const uint32_t dequant_
}
}

template <typename Cfg, int E4M3_MAX>
__device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error(const float (&x)[8],
const float block_scale_inverse,
// Return the first byte of the scale factor's object representation, i.e. its
// raw bit pattern rather than its numeric value.
// NOTE(review): presumes nvfp4_scale_t is a one-byte FP8 type -- confirm.
__device__ __forceinline__ uint8_t fp8_bits(const nvfp4_scale_t sf) {
  const auto *raw_bytes = reinterpret_cast<const uint8_t *>(&sf);
  return raw_bytes[0];
}

__device__ __forceinline__ float2 e2m1x2_scaled_e4m3_to_float2(const uint32_t e2m1_byte,

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we only compute 2 values in this function? Also, we don't use half of scale_h2 here. Why don't we instead convert the values from both branches here (so both 4 and 6 would be there, and the scaling factors for both branches would be converted in a single instruction)? Ideally we would then reuse those scaling factors rather than recasting them for every element in a block - considering we are math bound here, we need to make sure that we eliminate as many instructions as possible.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried a refactored implementation in 54797b3 but it did not show meaningful speedup:

NVTE_NVFP4_4OVER6=activations \
NVTE_NVFP4_4OVER6_ERR_MODE=<MAE|MSE> \
NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH=1 \
NVTE_NVFP4_DISABLE_RHT=1 \
NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING=1 \
python3 benchmarks/linear/benchmark_grouped_linear.py --recipe nvfp4

Extended Fast-Path Timing Table:

m k n recipe num_gemms baseline_ms old_MAE_fast1_ms refactor_MAE_fast1_ms MAE_refactor_speedup old_MSE_fast1_ms refactor_MSE_fast1_ms MSE_refactor_speedup
16384 7168 2048 nvfp4 4 0.768440 0.979436 0.984238 0.995x 0.978000 0.984061 0.994x
32768 7168 2048 nvfp4 4 1.246045 1.648099 1.643575 1.003x 1.644019 1.645252 0.999x
65536 7168 2048 nvfp4 4 2.226334 2.966817 2.977467 0.996x 2.974658 2.990481 0.995x
98304 7168 2048 nvfp4 4 3.220651 4.395422 4.396148 1.000x 4.401849 4.409521 0.998x
16384 7168 2048 nvfp4 8 0.999235 1.176896 1.182225 0.995x 1.176988 1.258703 0.935x
32768 7168 2048 nvfp4 8 1.428313 1.848563 1.853778 0.997x 1.846523 1.857125 0.994x
65536 7168 2048 nvfp4 8 2.400536 3.148987 3.152150 0.999x 3.150669 3.153690 0.999x
98304 7168 2048 nvfp4 8 3.387845 4.552140 4.549876 1.000x 4.554252 4.560287 0.999x

the refactor-vs-old geomean was:

MAE:      0.9983x
MSE:      0.9889x
combined: 0.9936x

This refactor is essentially common instruction lifting/reuse, and it keeps the same core PTX instructions (cvt.rn.f16x2.e4m3x2, cvt.rn.f16x2.e2m1x2, mul.rn.f16x2) rather than introducing a different PTX operation. I think the compiler can already do this in the old implementation but I am not sure.

@zianglih zianglih Jun 3, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The kernel-level benchmark does not show a speedup either:

mode kernel err metric old us refactor us old/refactor
1d nvfp4 - strict 103.260 103.259 1.000x
1d 4over6 MAE strict 799.852 800.892 0.999x
1d 4over6 MAE fast 288.834 289.100 0.999x
1d 4over6 MSE strict 829.294 829.693 1.000x
1d 4over6 MSE fast 287.426 289.148 0.994x
2d nvfp4 - strict 126.692 126.737 1.000x
2d 4over6 MAE strict 847.700 847.070 1.001x
2d 4over6 MAE fast 306.286 306.012 1.001x
2d 4over6 MSE strict 866.949 867.081 1.000x
2d 4over6 MSE fast 306.748 305.842 1.003x

Script: 83e2308, shape (16384, 6144), with --warmup 20 --iters 2000

@zianglih zianglih Jun 3, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kept the refactoring in beaed67. The perf behavior of the explicit pattern is more robust to compiler optimizations.

const nvfp4_scale_t sf) {
float2 result;
const uint32_t sf_byte = static_cast<uint32_t>(fp8_bits(sf));
asm volatile(
"{\n"
".reg .b8 byte0, byte1, byte2, byte3;\n"
".reg .b16 fp8_pair;\n"
".reg .b16 scale_h, unused_h;\n"
".reg .b16 lo, hi;\n"
".reg .b32 q_h2;\n"
".reg .b32 scale_h2;\n"
".reg .b32 prod_h2;\n"
"mov.b32 {byte0, byte1, byte2, byte3}, %2;\n"
"cvt.rn.f16x2.e2m1x2 q_h2, byte0;\n"
"cvt.u16.u32 fp8_pair, %3;\n"
"cvt.rn.f16x2.e4m3x2 scale_h2, fp8_pair;\n"
"mov.b32 {scale_h, unused_h}, scale_h2;\n"
"mov.b32 scale_h2, {scale_h, scale_h};\n"
"mul.rn.f16x2 prod_h2, q_h2, scale_h2;\n"
"mov.b32 {lo, hi}, prod_h2;\n"
"cvt.f32.f16 %0, lo;\n"
"cvt.f32.f16 %1, hi;\n"
"}"
: "=f"(result.x), "=f"(result.y)
: "r"(e2m1_byte), "r"(sf_byte));
return result;
}

/*! \brief Accumulate the 4over6 selection error of one packed E2M1 pair.
 *
 *  The candidate values come from e2m1x2_scaled_e4m3_to_float2(e2m1_byte, sf)
 *  and are compared against the inputs multiplied by global_encode_scale.
 *  The per-element error metric is chosen by Cfg::mode.
 *
 *  \tparam Cfg  Supplies Cfg::mode (error metric) and Cfg::err_use_fast_math.
 *
 *  \param[in]     e2m1_byte            Byte holding the two quantized E2M1 values.
 *  \param[in]     x0                   Input compared against the first candidate value.
 *  \param[in]     x1                   Input compared against the second candidate value.
 *  \param[in]     sf                   Block scale factor applied to the quantized pair.
 *  \param[in]     global_encode_scale  Factor applied to x0 and x1 before comparison.
 *  \param[in,out] err                  Running error sum; both element errors are added to it.
 */
template <typename Cfg>
__device__ __forceinline__ void accumulate_fp16_scaled_error_pair(const uint32_t e2m1_byte,
                                                                  const float x0, const float x1,
                                                                  const nvfp4_scale_t sf,
                                                                  const float global_encode_scale,
                                                                  float *err) {
  const float2 candidate = e2m1x2_scaled_e4m3_to_float2(e2m1_byte, sf);
  if constexpr (Cfg::err_use_fast_math) {
    // Plain operators: no rounding-mode intrinsics are used on this path.
    const float original0 = x0 * global_encode_scale;
    const float original1 = x1 * global_encode_scale;
    const float diff0 = candidate.x - original0;
    const float diff1 = candidate.y - original1;
    *err += compute_error<Cfg::mode>(diff0);
    *err += compute_error<Cfg::mode>(diff1);
  } else {
    // Explicit round-to-nearest intrinsics for every multiply, subtract and add.
    const float original0 = __fmul_rn(x0, global_encode_scale);
    const float original1 = __fmul_rn(x1, global_encode_scale);
    const float diff0 = __fsub_rn(candidate.x, original0);
    const float diff1 = __fsub_rn(candidate.y, original1);
    *err = __fadd_rn(*err, compute_error_rn<Cfg::mode>(diff0));
    *err = __fadd_rn(*err, compute_error_rn<Cfg::mode>(diff1));
  }
}

template <typename Cfg, int E4M3_MAX>
__device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error(
const float (&x)[8], const float block_scale_inverse, const nvfp4_scale_t sf,
const float global_amax, const float global_encode_scale, float *err) {
uint32_t out = 0;
uint32_t out_dequant_1 = 0;
uint32_t out_dequant_2 = 0;
Expand Down Expand Up @@ -253,15 +318,26 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error(const float (&
"Try recompiling with sm_XXXa instead of sm_XXX.");
}

const float sf_float = static_cast<float>(sf);
accumulate_dequant_error<Cfg, E4M3_MAX, 0>(out_dequant_1, x[0], sf_float, global_amax, err);
accumulate_dequant_error<Cfg, E4M3_MAX, 16>(out_dequant_1, x[1], sf_float, global_amax, err);
accumulate_dequant_error<Cfg, E4M3_MAX, 0>(out_dequant_2, x[2], sf_float, global_amax, err);
accumulate_dequant_error<Cfg, E4M3_MAX, 16>(out_dequant_2, x[3], sf_float, global_amax, err);
accumulate_dequant_error<Cfg, E4M3_MAX, 0>(out_dequant_3, x[4], sf_float, global_amax, err);
accumulate_dequant_error<Cfg, E4M3_MAX, 16>(out_dequant_3, x[5], sf_float, global_amax, err);
accumulate_dequant_error<Cfg, E4M3_MAX, 0>(out_dequant_4, x[6], sf_float, global_amax, err);
accumulate_dequant_error<Cfg, E4M3_MAX, 16>(out_dequant_4, x[7], sf_float, global_amax, err);
if constexpr (Cfg::mode == kNVTENVFP44Over6MinMAEFP16 ||
Cfg::mode == kNVTENVFP44Over6MinMSEFP16) {
accumulate_fp16_scaled_error_pair<Cfg>(out & 0xFFu, x[0], x[1], sf, global_encode_scale, err);
accumulate_fp16_scaled_error_pair<Cfg>((out >> 8) & 0xFFu, x[2], x[3], sf, global_encode_scale,
err);
accumulate_fp16_scaled_error_pair<Cfg>((out >> 16) & 0xFFu, x[4], x[5], sf, global_encode_scale,
err);
accumulate_fp16_scaled_error_pair<Cfg>((out >> 24) & 0xFFu, x[6], x[7], sf, global_encode_scale,
err);
} else {
const float sf_float = static_cast<float>(sf);
accumulate_dequant_error<Cfg, E4M3_MAX, 0>(out_dequant_1, x[0], sf_float, global_amax, err);
accumulate_dequant_error<Cfg, E4M3_MAX, 16>(out_dequant_1, x[1], sf_float, global_amax, err);
accumulate_dequant_error<Cfg, E4M3_MAX, 0>(out_dequant_2, x[2], sf_float, global_amax, err);
accumulate_dequant_error<Cfg, E4M3_MAX, 16>(out_dequant_2, x[3], sf_float, global_amax, err);
accumulate_dequant_error<Cfg, E4M3_MAX, 0>(out_dequant_3, x[4], sf_float, global_amax, err);
accumulate_dequant_error<Cfg, E4M3_MAX, 16>(out_dequant_3, x[5], sf_float, global_amax, err);
accumulate_dequant_error<Cfg, E4M3_MAX, 0>(out_dequant_4, x[6], sf_float, global_amax, err);
accumulate_dequant_error<Cfg, E4M3_MAX, 16>(out_dequant_4, x[7], sf_float, global_amax, err);
}
return out;
}

Expand All @@ -273,13 +349,17 @@ __device__ __forceinline__ CandidatePair make_candidates(const float (&x0)[8], c
candidates.map4.err = 0.0f;
candidates.map6.err = 0.0f;
candidates.map4.packed[0] = cvt_fp32_to_fp4_8x_with_error<Cfg, E4M3_MAX>(
x0, scales.inv_map4, scales.map4, global_amax, &candidates.map4.err);
x0, scales.inv_map4, scales.map4, global_amax, scales.global_encode_scale,
&candidates.map4.err);
candidates.map6.packed[0] = cvt_fp32_to_fp4_8x_with_error<Cfg, E4M3_MAX>(
x0, scales.inv_map6, scales.map6, global_amax, &candidates.map6.err);
x0, scales.inv_map6, scales.map6, global_amax, scales.global_encode_scale,
&candidates.map6.err);
candidates.map4.packed[1] = cvt_fp32_to_fp4_8x_with_error<Cfg, E4M3_MAX>(
x1, scales.inv_map4, scales.map4, global_amax, &candidates.map4.err);
x1, scales.inv_map4, scales.map4, global_amax, scales.global_encode_scale,
&candidates.map4.err);
candidates.map6.packed[1] = cvt_fp32_to_fp4_8x_with_error<Cfg, E4M3_MAX>(
x1, scales.inv_map6, scales.map6, global_amax, &candidates.map6.err);
x1, scales.inv_map6, scales.map6, global_amax, scales.global_encode_scale,
&candidates.map6.err);
return candidates;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,11 @@ enum NVTEScalingMode {
* \brief Method for NVFP4 4over6 quantization.
*/
enum NVTENVFP44Over6Mode {
kNVTENVFP44Over6Disabled = 0, /*!< 4over6 is not applied */
kNVTENVFP44Over6MinMAE = 1, /*!< Select the candidate with lower mean absolute error */
kNVTENVFP44Over6MinMSE = 2, /*!< Select the candidate with lower mean squared error */
kNVTENVFP44Over6Disabled = 0, /*!< 4over6 is not applied */
kNVTENVFP44Over6MinMAE = 1, /*!< Select the candidate with lower mean absolute error */
kNVTENVFP44Over6MinMSE = 2, /*!< Select the candidate with lower mean squared error */
kNVTENVFP44Over6MinMAEFP16 = 3, /*!< Select with lower absolute error in FP16 domain */
kNVTENVFP44Over6MinMSEFP16 = 4, /*!< Select with lower squared error in FP16 domain */
};

/*! \brief TE Tensor type
Expand Down
6 changes: 3 additions & 3 deletions transformer_engine/common/recipe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

_BACKWARD_OVERRIDES = (None, "high_precision", "dequantized")
_NVFP4_4OVER6_SCOPES = ("none", "weights", "activations", "all")
_NVFP4_4OVER6_ERR_MODES = ("MAE", "MSE")
_NVFP4_4OVER6_ERR_MODES = ("MAE", "MSE", "MAE_FP16", "MSE_FP16")


class _FormatHelper(NamedTuple):
Expand Down Expand Up @@ -535,7 +535,7 @@ class NVFP4BlockScaling(Recipe):
Select 4over6 tensors that use 256 as the global E4M3 scale
bound. By default, all 4over6 tensors use 256. Use ``'none'``
to keep the standard NVFP4 448 bound for 4over6 tensors.
nvfp4_4over6_err_mode : {'MAE', 'MSE'}, default = 'MAE'
nvfp4_4over6_err_mode : {'MAE', 'MSE', 'MAE_FP16', 'MSE_FP16'}, default = 'MAE'
Error metric used by NVFP4 4over6 candidate selection.
backward_override : {None, 'high_precision', 'dequantized'}, default = None
Backward precision mode. None does not modify backward behavior,
Expand Down Expand Up @@ -577,7 +577,7 @@ def __post_init__(self) -> None:
), "NVTE_NVFP4_4OVER6_E4M3_USE_256 must be one of: 'none', 'weights', 'activations', 'all'."
assert (
self.nvfp4_4over6_err_mode in _NVFP4_4OVER6_ERR_MODES
), "NVTE_NVFP4_4OVER6_ERR_MODE must be one of: 'MAE', 'MSE'."
), "NVTE_NVFP4_4OVER6_ERR_MODE must be one of: 'MAE', 'MSE', 'MAE_FP16', 'MSE_FP16'."

# Quantization params
# Note: RHT is currently only applied to column-wise usage so that
Expand Down
4 changes: 3 additions & 1 deletion transformer_engine/common/transformer_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1109,7 +1109,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
const auto val = *reinterpret_cast<const uint8_t *>(buf);
NVTE_CHECK(val == static_cast<uint8_t>(kNVTENVFP44Over6Disabled) ||
val == static_cast<uint8_t>(kNVTENVFP44Over6MinMAE) ||
val == static_cast<uint8_t>(kNVTENVFP44Over6MinMSE),
val == static_cast<uint8_t>(kNVTENVFP44Over6MinMSE) ||
val == static_cast<uint8_t>(kNVTENVFP44Over6MinMAEFP16) ||
val == static_cast<uint8_t>(kNVTENVFP44Over6MinMSEFP16),
"Invalid NVFP4 4over6 mode (got ", static_cast<int>(val), ")");
config_.nvfp4_4over6_mode = static_cast<NVTENVFP44Over6Mode>(val);
break;
Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/pytorch/csrc/quantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1740,6 +1740,10 @@ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize
this->nvfp4_4over6_mode = kNVTENVFP44Over6MinMAE;
} else if (nvfp4_4over6_err_mode == "MSE") {
this->nvfp4_4over6_mode = kNVTENVFP44Over6MinMSE;
} else if (nvfp4_4over6_err_mode == "MAE_FP16") {
this->nvfp4_4over6_mode = kNVTENVFP44Over6MinMAEFP16;
} else if (nvfp4_4over6_err_mode == "MSE_FP16") {
this->nvfp4_4over6_mode = kNVTENVFP44Over6MinMSEFP16;
} else {
NVTE_ERROR("Unsupported NVFP4 4over6 error mode: ", nvfp4_4over6_err_mode);
}
Expand Down
Loading