diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index ed5fdd4660..f9d8b45d46 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -9,7 +9,7 @@ """ from __future__ import annotations -from typing import Optional, Tuple, Iterable, Union, List +from typing import Any, Optional, Tuple, Iterable, Union, List import torch import transformer_engine_torch as tex @@ -302,6 +302,7 @@ def quantize( *, out: Optional[Union[torch.Tensor, DebugQuantizedTensor]] = None, dtype: torch.dtype = None, + amax_reduction_group: Optional[Any] = None, ): """Returns DebugQuantizedTensor object.""" import nvdlfw_inspect.api as debug_api @@ -321,7 +322,9 @@ def quantize( rowwise_gemm_tensor, columnwise_gemm_tensor = None, None if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: - quantized_tensor = self.parent_quantizer(tensor) + quantized_tensor = self.parent_quantizer( + tensor, amax_reduction_group=amax_reduction_group + ) # if both rowwise_tensor_plan and columnwise_tensor_plan need to be quantized, # one tensor with columnwise=True and rowwise=True is computed # and both rowwise_tensor_plan and columnwise_tensor_plan point to it. @@ -439,6 +442,7 @@ def update_quantized( dst: QuantizedTensor, *, noop_flag: Optional[torch.Tensor] = None, + amax_reduction_group: Optional[Any] = None, # pylint: disable=unused-argument ) -> QuantizedTensor: """Update quantized tensor - used in weight caching.""" import nvdlfw_inspect.api as debug_api diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 8f42983553..313a7ef857 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -2104,14 +2104,6 @@ def forward( "Amax reduction across TP+CP group is necessary when using context parallelism" " with FP8!" ) - if fp8_recipe.float8_current_scaling() and context_parallel: - all_quantizers = dpa_utils.get_attention_quantizers(fp8, quantizers) - for q in all_quantizers: - if isinstance(q, Float8CurrentScalingQuantizer): - q.with_amax_reduction = True - q.amax_reduction_group = ( - cp_group[0] if cp_comm_type == "a2a+p2p" else cp_group - ) if context_parallel: assert ( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 36847e40ed..5685efddae 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1466,6 +1466,9 @@ def forward( cp_group = cp_group[1] cp_size = get_distributed_world_size(cp_group) rank = get_distributed_rank(cp_group) + # Amax reduction group for current scaling (a2a+p2p vs p2p choice). + fp8_amax_reduction_group = cp_group_a2a if cp_group_a2a is not None else cp_group + ctx.fp8_amax_reduction_group = fp8_amax_reduction_group send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] device_compute_capability = get_device_compute_capability() @@ -1577,7 +1580,12 @@ def forward( # q, k, v: torch.Tensor, dtype=torch.uint8 q_f16 = q q_fp8, k_fp8, v_fp8, qkv_layout, _ = combine_and_quantize( - qkv_layout, q, k, v, QKV_quantizer + qkv_layout, + q, + k, + v, + QKV_quantizer, + amax_reduction_group=fp8_amax_reduction_group, ) if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] @@ -2127,7 +2135,7 @@ def forward( and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) and not fp8_recipe.mxfp8() ): - out_fp8 = O_quantizer(out_f16) + out_fp8 = O_quantizer(out_f16, amax_reduction_group=fp8_amax_reduction_group) out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16 ctx.layer_number = layer_number @@ -2265,7 +2273,7 @@ def backward(ctx, dout, *_args): and not isinstance(dout, QuantizedTensorStorage) and not ctx.fp8_recipe.mxfp8() ): - dout = ctx.dO_quantizer(dout) + dout = ctx.dO_quantizer(dout, amax_reduction_group=ctx.fp8_amax_reduction_group) if ctx.use_fused_attention: dout._data = dout._data.contiguous() elif ctx.use_fused_attention: @@ -2383,7 +2391,7 @@ def backward(ctx, dout, *_args): if isinstance(dout, QuantizedTensorStorage): dout_fp8 = dout elif not ctx.fp8_recipe.mxfp8(): - dout_fp8 = ctx.dO_quantizer(dout) + dout_fp8 = ctx.dO_quantizer(dout, amax_reduction_group=ctx.fp8_amax_reduction_group) if not ctx.fp8_recipe.mxfp8(): dout = dout_fp8._data @@ -2903,7 +2911,14 @@ def backward(ctx, dout, *_args): dv[cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: - dq, dk, dv, _, _ = combine_and_quantize(ctx.qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + dq, dk, dv, _, _ = combine_and_quantize( + ctx.qkv_layout, + dq, + dk, + dv, + ctx.dQKV_quantizer, + amax_reduction_group=ctx.fp8_amax_reduction_group, + ) if ctx.fp8: # print quantizers @@ -3055,6 +3070,9 @@ def forward( cp_size = get_distributed_world_size(cp_group) rank = get_distributed_rank(cp_group) + # Amax reduction group for current scaling (all_gather CP). + fp8_amax_reduction_group = cp_group + ctx.fp8_amax_reduction_group = fp8_amax_reduction_group qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format o_format = qkv_format _, seq_dim_qkv, _ = get_bsh_dims(qkv_format) @@ -3155,7 +3173,12 @@ def forward( fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 if not is_input_fp8 and not fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8, qkv_layout, _ = combine_and_quantize( - qkv_layout, q, k, v, QKV_quantizer + qkv_layout, + q, + k, + v, + QKV_quantizer, + amax_reduction_group=fp8_amax_reduction_group, ) if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] @@ -3379,7 +3402,7 @@ def forward( ) ) if (fp8 and is_output_fp8) or bwd_requires_o_fp8: - out_fp8 = O_quantizer(out_f16) + out_fp8 = O_quantizer(out_f16, amax_reduction_group=fp8_amax_reduction_group) out_ret = out_fp8 if is_output_fp8 else out_f16 # save tensors for backward @@ -3532,7 +3555,7 @@ def backward(ctx, dout, *_args): if isinstance(dout, QuantizedTensorStorage): dout_fp8 = dout elif not ctx.fp8_recipe.mxfp8(): - dout = ctx.dO_quantizer(dout) + dout = ctx.dO_quantizer(dout, amax_reduction_group=ctx.fp8_amax_reduction_group) dout_fp8 = dout if not ctx.fp8_recipe.mxfp8(): dout = dout_fp8._data @@ -3842,7 +3865,14 @@ def backward(ctx, dout, *_args): # quantize if necessary if ctx.fp8 and ctx.is_input_fp8: - dq, dk, dv, _, _ = combine_and_quantize(ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + dq, dk, dv, _, _ = combine_and_quantize( + ctx.dqkv_layout, + dq, + dk, + dv, + ctx.dQKV_quantizer, + amax_reduction_group=ctx.fp8_amax_reduction_group, + ) nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") return ( @@ -3918,6 +3948,9 @@ def forward( nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") cp_size = get_distributed_world_size(cp_group) + # Amax reduction group for current scaling (a2a CP). + fp8_amax_reduction_group = cp_group + ctx.fp8_amax_reduction_group = fp8_amax_reduction_group qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format original_qkv_layout = qkv_layout orig_q_shape, orig_k_shape, orig_v_shape = q.shape, k.shape, v.shape @@ -4023,7 +4056,12 @@ def forward( q_fp8, k_fp8, v_fp8 = q, k, v elif not fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8, qkv_layout, _ = combine_and_quantize( - qkv_layout, q, k, v, QKV_quantizer + qkv_layout, + q, + k, + v, + QKV_quantizer, + amax_reduction_group=fp8_amax_reduction_group, ) if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] @@ -4137,7 +4175,7 @@ def forward( out_f16 = out_ if bwd_requires_o_fp8: if not isinstance(out_, QuantizedTensorStorage): - out_fp8 = O_quantizer(out_) + out_fp8 = O_quantizer(out_, amax_reduction_group=fp8_amax_reduction_group) out_part = out_fp8 if bwd_requires_o_f16: if isinstance(out_, QuantizedTensorStorage): @@ -4212,7 +4250,7 @@ def forward( out_fp8 = Float8Tensor.make_like(out_fp8, data=out_, dtype=fwd_nominal_dtype) if is_output_fp8: if fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8(): - out_fp8 = O_quantizer(out_) + out_fp8 = O_quantizer(out_, amax_reduction_group=fp8_amax_reduction_group) out_f16 = out_ else: if fp8_recipe.delayed(): @@ -4358,7 +4396,7 @@ def backward(ctx, dout, *_args): if isinstance(dout, QuantizedTensorStorage): dout_fp8 = dout elif not ctx.fp8_recipe.mxfp8(): - dout = ctx.dO_quantizer(dout) + dout = ctx.dO_quantizer(dout, amax_reduction_group=ctx.fp8_amax_reduction_group) dout_fp8 = dout if not ctx.fp8_recipe.mxfp8(): dout = dout._data @@ -4576,7 +4614,12 @@ def backward(ctx, dout, *_args): ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8() ) and ctx.is_input_fp8: dq, dk, dv, _, _ = combine_and_quantize( - ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer + ctx.dqkv_layout, + dq, + dk, + dv, + ctx.dQKV_quantizer, + amax_reduction_group=ctx.fp8_amax_reduction_group, ) if ctx.fp8_recipe.delayed(): dq, dk, dv = [ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 8a07d7af79..201aefcc55 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2783,8 +2783,13 @@ def combine_and_quantize( used_in_forward=True, used_in_backward=False, keep_same_data_and_scale_inv_format=False, + amax_reduction_group=None, ): - """Combine Q, K, V tensors based on qkv_layout and quantize them together.""" + """Combine Q, K, V tensors based on qkv_layout and quantize them together. + + ``amax_reduction_group`` is an optional process group used to all-reduce the + amax during quantization (current scaling, context parallelism). + """ if isinstance(qkv_quantizer, MXFP8Quantizer): qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) assert qkv_format in ("bshd", "sbhd"), ( @@ -2885,7 +2890,7 @@ def combine_and_quantize( case 1: dim = qkv_layout.find("3") qkv = combine_tensors([q, k, v], dim) - qkv_fp8 = qkv_quantizer(qkv) + qkv_fp8 = qkv_quantizer(qkv, amax_reduction_group=amax_reduction_group) q_data, k_data, v_data = SplitAlongDim.apply(qkv_fp8._data, dim, [1, 1, 1], True) case 2: dim = qkv_layout.split("_")[1].find("2") @@ -2896,7 +2901,7 @@ def combine_and_quantize( numels = [x.numel() for x in tensors] numels = [sum(numels[:i]) for i in range(num_tensors + 1)] qkv = torch.cat([x.view(-1) for x in tensors], dim=0) - qkv_fp8 = qkv_quantizer(qkv) + qkv_fp8 = qkv_quantizer(qkv, amax_reduction_group=amax_reduction_group) q_data, kv_data = [ qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors) ] @@ -2908,7 +2913,7 @@ def combine_and_quantize( numels = [x.numel() for x in tensors] numels = [sum(numels[:i]) for i in range(num_tensors + 1)] qkv = torch.cat([x.view(-1) for x in tensors], dim=0) - qkv_fp8 = qkv_quantizer(qkv) + qkv_fp8 = qkv_quantizer(qkv, amax_reduction_group=amax_reduction_group) q_data, k_data, v_data = [ qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors) ] diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index d85dcda159..e904a34754 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -77,6 +77,13 @@ std::unique_ptr convert_quantizer(py::handle quantizer) { NVTE_ERROR("Unexpected type for quantizer"); } +c10::intrusive_ptr convert_amax_reduction_group(py::handle amax_reduction_group) { + if (amax_reduction_group.is_none()) { + return {}; + } + return amax_reduction_group.cast>(); +} + transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe) { // if e4m3 or hybrid + forward diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 779b145dd9..5386a7f5c9 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -130,9 +130,16 @@ class Quantizer { virtual std::pair convert_and_update_tensor( py::object tensor) const = 0; - /*! @brief Convert to a quantized data format */ + /*! @brief Convert to a quantized data format + * + * The optional ``amax_reduction_group`` is supplied per call (current scaling + * and NVFP4). When set, the per-tensor amax is all-reduced across the group + * before computing the scale. Quantizers that do not support amax reduction + * ignore it. + */ virtual void quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag = std::nullopt) = 0; + const std::optional& noop_flag = std::nullopt, + c10::intrusive_ptr amax_reduction_group = {}) = 0; virtual ~Quantizer() = default; @@ -172,7 +179,8 @@ class NoneQuantizer : public Quantizer { std::pair convert_and_update_tensor(py::object tensor) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag = std::nullopt) override; + const std::optional& noop_flag = std::nullopt, + c10::intrusive_ptr amax_reduction_group = {}) override; }; class Float8Quantizer : public Quantizer { @@ -206,14 +214,13 @@ class Float8Quantizer : public Quantizer { std::pair convert_and_update_tensor(py::object shape) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag = std::nullopt) override; + const std::optional& noop_flag = std::nullopt, + c10::intrusive_ptr amax_reduction_group = {}) override; }; class Float8CurrentScalingQuantizer : public Quantizer { public: DType dtype; - bool with_amax_reduction; - c10::intrusive_ptr amax_reduction_group; bool force_pow_2_scales = false; float amax_epsilon = 0.0; @@ -245,21 +252,24 @@ class Float8CurrentScalingQuantizer : public Quantizer { std::pair convert_and_update_tensor(py::object shape) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag = std::nullopt) override; + const std::optional& noop_flag = std::nullopt, + c10::intrusive_ptr amax_reduction_group = {}) override; /*! @brief Quantize to FP8, skipping local amax computation * * The provided amax tensor is assumed to already hold the local - * amax. The amax may still be reduced across the amax reduction - * group. + * amax. The amax may still be reduced across ``amax_reduction_group`` + * when one is supplied. */ void quantize_with_amax(TensorWrapper& input, TensorWrapper& out, at::Tensor amax, - const std::optional& noop_flag = std::nullopt); + const std::optional& noop_flag = std::nullopt, + c10::intrusive_ptr amax_reduction_group = {}); private: void quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, bool compute_amax, - at::Tensor amax_buf, at::Tensor scale_buf); + at::Tensor amax_buf, at::Tensor scale_buf, + c10::intrusive_ptr amax_reduction_group = {}); }; class Float8BlockQuantizer : public Quantizer { @@ -300,7 +310,8 @@ class Float8BlockQuantizer : public Quantizer { std::pair convert_and_update_tensor(py::object shape) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag = std::nullopt) override; + const std::optional& noop_flag = std::nullopt, + c10::intrusive_ptr amax_reduction_group = {}) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; @@ -326,16 +337,14 @@ class MXFP8Quantizer : public Quantizer { std::pair convert_and_update_tensor(py::object shape) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag = std::nullopt) override; + const std::optional& noop_flag = std::nullopt, + c10::intrusive_ptr amax_reduction_group = {}) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; class NVFP4Quantizer : public Quantizer { public: - // amax reduction for low precision FP4 AG - bool with_amax_reduction; - c10::intrusive_ptr amax_reduction_group; // random hadamard transform bool with_rht; bool with_post_rht_amax; @@ -379,17 +388,20 @@ class NVFP4Quantizer : public Quantizer { std::pair convert_and_update_tensor(py::object shape) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag = std::nullopt) override; + const std::optional& noop_flag = std::nullopt, + c10::intrusive_ptr amax_reduction_group = {}) override; void quantize_impl(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag, bool compute_amax); + const std::optional& noop_flag, bool compute_amax, + c10::intrusive_ptr amax_reduction_group = {}); /*! @brief Quantize to NVFP4, skipping local amax computation * * The input tensor's amax pointer is assumed to already hold the - * local amax. The amax may still be reduced across the amax - * reduction group. + * local amax. The amax may still be reduced across + * ``amax_reduction_group`` when one is supplied. */ - void quantize_with_amax(TensorWrapper& input, TensorWrapper& out); + void quantize_with_amax(TensorWrapper& input, TensorWrapper& out, + c10::intrusive_ptr amax_reduction_group = {}); std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; @@ -409,6 +421,12 @@ class NVFP4Quantizer : public Quantizer { std::unique_ptr convert_quantizer(py::handle quantizer); +/*! @brief Convert a Python process group handle into a C++ process group. + * + * Returns a null pointer when the handle is ``None`` (i.e. no amax reduction). + */ +c10::intrusive_ptr convert_amax_reduction_group(py::handle amax_reduction_group); + std::vector getTensorShape(const at::Tensor& t); transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 2b4f899e1d..9822c7af75 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -296,7 +296,8 @@ std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, std::vector layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias, float eps, py::object ln_out, py::handle quantizer, DType out_dtype, const int sm_margin, - const bool zero_centered_gamma); + const bool zero_centered_gamma, + py::handle amax_reduction_group); /*************************************************************************************************** * RMSNorm @@ -313,7 +314,8 @@ std::vector rmsnorm_bwd_add(const at::Tensor &dz, const at::Tensor & std::vector rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps, py::object ln_out, py::handle quantizer, DType otype, - const int sm_margin, const bool zero_centered_gamma); + const int sm_margin, const bool zero_centered_gamma, + py::handle amax_reduction_group); /*************************************************************************************************** * Memory allocation @@ -333,7 +335,7 @@ py::object create_empty_quantized_tensor(py::handle quantizer, const std::vector at::ScalarType dtype, at::Device device, bool pin_memory); py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, - std::optional noop_flag); + std::optional noop_flag, py::handle amax_reduction_group); py::object nvfp4_quantize_with_amax(const at::Tensor &tensor, py::handle quantizer, const at::Tensor &rowwise_amax, diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index aab5a87b9a..277305ae9d 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -31,25 +31,17 @@ std::vector get_tensor_shape(const TensorWrapper &tensor) { return std::vector(shape.data, shape.data + shape.ndim); } -void allreduce_nvfp4_amax_tensors(NVFP4Quantizer *nvfp4_quantizer_cpp, - std::vector &&amax_tensors) { - if (!nvfp4_quantizer_cpp->with_amax_reduction || amax_tensors.empty()) { - return; - } - c10d::AllreduceCoalescedOptions opts; - opts.reduceOp = c10d::ReduceOp::MAX; - NVTE_SCOPED_GIL_RELEASE({ - nvfp4_quantizer_cpp->amax_reduction_group->allreduce_coalesced(amax_tensors, opts)->wait(); - }); -} - } // namespace py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, - std::optional noop_flag) { + std::optional noop_flag, py::handle amax_reduction_group) { // Convert quantizer to C++ object auto quantizer_cpp = convert_quantizer(quantizer); + // The amax reduction process group is supplied per call and forwarded to the + // quantize kernel so it can all-reduce the amax (current scaling / NVFP4). + auto amax_reduction_group_cpp = convert_amax_reduction_group(amax_reduction_group); + // Convert input tensor to C++ object auto input_contiguous = tensor.contiguous(); auto input_cpp = makeTransformerEngineTensor(input_contiguous); @@ -72,7 +64,8 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob } // Perform quantization - quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp); + quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp, + std::move(amax_reduction_group_cpp)); // Post-quantize swizzle for quantizers whose kernel does not bake // the GEMM-swizzled scale layout in directly @@ -337,15 +330,6 @@ py::object nvfp4_group_quantize_with_amax(const at::Tensor &tensor, py::handle q grouped_output_py.attr("columnwise_amax") = py::cast(columnwise_amax); } - std::vector amax_tensors; - if (grouped_output_tensor_cpp.get_amax().data_ptr != nullptr) { - amax_tensors.push_back(rowwise_amax); - } - if (grouped_output_tensor_cpp.get_columnwise_amax().data_ptr != nullptr) { - amax_tensors.push_back(columnwise_amax); - } - allreduce_nvfp4_amax_tensors(nvfp4_quantizer_cpp, std::move(amax_tensors)); - if (empty_input_buffer) { return py::reinterpret_borrow(grouped_output_py); } @@ -1474,8 +1458,6 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, const auto &quantizer = *quantizers.front(); NVTE_CHECK(!quantizer.with_2d_quantization, "NVFP4 split-quantize does not support 2D quantization"); - NVTE_CHECK(!quantizer.with_amax_reduction, - "NVFP4 split-quantize does not support amax reduction"); if (quantizer.nvfp4_4over6_mode != kNVTENVFP44Over6Disabled) { NVTE_CHECK(!quantizer.with_rht, "NVFP4 4over6 quantization does not support RHT."); NVTE_CHECK(!quantizer.stochastic_rounding, diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index c3dec944e4..90201eaad0 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -61,7 +61,8 @@ std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, std::vector layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias, float eps, py::object out, py::handle quantizer, DType out_dtype, const int sm_margin, - const bool zero_centered_gamma) { + const bool zero_centered_gamma, + py::handle amax_reduction_group) { using namespace transformer_engine::pytorch::detail; // Ensure that cuDNN handle is created on the correct device, @@ -90,6 +91,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Quantizer auto quantizer_cpp = convert_quantizer(quantizer); + auto amax_reduction_group_cpp = convert_amax_reduction_group(amax_reduction_group); // Choose implementation enum class Impl { @@ -196,15 +198,18 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Quantize output if needed switch (impl) { case Impl::UNFUSED: { - quantizer_cpp->quantize(unquantized_out_nvte, out_nvte); + quantizer_cpp->quantize(unquantized_out_nvte, out_nvte, std::nullopt, + amax_reduction_group_cpp); } break; case Impl::FUSED_NORM_AMAX_FP8: { auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); - fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte, amax_buf); + fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte, amax_buf, std::nullopt, + amax_reduction_group_cpp); } break; case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); - nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte, + amax_reduction_group_cpp); } break; default: { } @@ -304,7 +309,8 @@ std::vector rmsnorm_bwd_add(const at::Tensor &dz, const at::Tensor & std::vector rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps, py::object out, py::handle quantizer, DType out_dtype, - const int sm_margin, const bool zero_centered_gamma) { + const int sm_margin, const bool zero_centered_gamma, + py::handle amax_reduction_group) { using namespace transformer_engine::pytorch::detail; // Ensure that cuDNN handle is created on the correct device, @@ -327,6 +333,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Quantizer auto quantizer_cpp = convert_quantizer(quantizer); + auto amax_reduction_group_cpp = convert_amax_reduction_group(amax_reduction_group); // Choose implementation enum class Impl { @@ -431,15 +438,18 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Quantize output if needed switch (impl) { case Impl::UNFUSED: { - quantizer_cpp->quantize(unquantized_out_nvte, out_nvte); + quantizer_cpp->quantize(unquantized_out_nvte, out_nvte, std::nullopt, + amax_reduction_group_cpp); } break; case Impl::FUSED_NORM_AMAX_FP8: { auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); - fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte, amax_buf); + fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte, amax_buf, std::nullopt, + amax_reduction_group_cpp); } break; case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); - nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte, + amax_reduction_group_cpp); } break; default: { } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d6089b1e01..79870e3dee 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -194,7 +194,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::is_method(dtype_class)); m.def("quantize", transformer_engine::pytorch::quantize, py::arg("tensor"), py::arg("quantizer"), - py::arg("output") = py::none(), py::arg("noop") = py::none()); + py::arg("output") = py::none(), py::arg("noop") = py::none(), + py::arg("amax_reduction_group") = py::none()); m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), py::arg("otype")); m.def("create_empty_quantized_tensor", @@ -331,11 +332,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Other granular functions m.def("layernorm_fwd", &transformer_engine::pytorch::layernorm_fwd, "LayerNorm", py::arg("input"), py::arg("weight"), py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), - py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma")); + py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"), + py::arg("amax_reduction_group") = py::none()); m.def("layernorm_bwd", &transformer_engine::pytorch::layernorm_bwd, "Backward of LayerNorm"); m.def("rmsnorm_fwd", &transformer_engine::pytorch::rmsnorm_fwd, "RMSNorm", py::arg("input"), py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), - py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma")); + py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"), + py::arg("amax_reduction_group") = py::none()); m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm"); m.def("rmsnorm_bwd_add", &transformer_engine::pytorch::rmsnorm_bwd_add, "Fused backward of RMSNorm + add"); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 5fc50953a1..9ab75c972f 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -256,7 +256,8 @@ std::pair NoneQuantizer::convert_and_update_tensor( } void NoneQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag) { + const std::optional& noop_flag, + c10::intrusive_ptr amax_reduction_group) { NVTE_ERROR("NoneQuantizer does not support quantization"); } @@ -554,7 +555,8 @@ std::pair Float8Quantizer::convert_and_update_tensor( } void Float8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag) { + const std::optional& noop_flag, + c10::intrusive_ptr amax_reduction_group) { if (input.numel() == 0) { return; } @@ -571,18 +573,6 @@ Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& q : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); - // Get amax reduction group if needed - const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast(); - c10::intrusive_ptr amax_reduction_group; - if (with_amax_reduction) { - auto group = quantizer.attr("_canonicalized_amax_reduction_group")(); - NVTE_CHECK(!group.is_none(), - "Float8CurrentScalingQuantizer could not canonicalize amax reduction group"); - amax_reduction_group = group.cast>(); - } - this->with_amax_reduction = with_amax_reduction; - this->amax_reduction_group = amax_reduction_group; - // fp8 current scaling specific quantization params this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast(); this->amax_epsilon = quantizer.attr("amax_epsilon").cast(); @@ -883,10 +873,10 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ return {std::move(out_cpp), std::move(tensor)}; } -void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag, - bool compute_amax, at::Tensor amax_buf, - at::Tensor scale_buf) { +void Float8CurrentScalingQuantizer::quantize_impl( + const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, + bool compute_amax, at::Tensor amax_buf, at::Tensor scale_buf, + c10::intrusive_ptr amax_reduction_group) { out.set_amax(amax_buf.data_ptr(), DType::kFloat32, std::vector{1}); out.set_scale(scale_buf.data_ptr(), DType::kFloat32, std::vector{1}); @@ -916,8 +906,8 @@ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, Te { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); }); } - // Perform amax reduction if needed - if (with_amax_reduction) { + // Perform amax reduction if a process group was supplied for this call + if (amax_reduction_group) { // allreduce amax tensor c10d::AllreduceOptions opts; opts.reduceOp = c10d::ReduceOp::MAX; @@ -937,19 +927,23 @@ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, Te out.set_scale(nullptr, DType::kFloat32, out.defaultShape); } -void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag) { +void Float8CurrentScalingQuantizer::quantize( + const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, + c10::intrusive_ptr amax_reduction_group) { const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); at::Tensor amax_and_scale = at::empty({2}, opts); - this->quantize_impl(input, out, noop_flag, true, amax_and_scale[0], amax_and_scale[1]); + this->quantize_impl(input, out, noop_flag, true, amax_and_scale[0], amax_and_scale[1], + std::move(amax_reduction_group)); } void Float8CurrentScalingQuantizer::quantize_with_amax( TensorWrapper& input, TensorWrapper& out, at::Tensor amax, - const std::optional& noop_flag) { + const std::optional& noop_flag, + c10::intrusive_ptr amax_reduction_group) { const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); input.set_amax(nullptr, DType::kFloat32, input.defaultShape); - this->quantize_impl(input, out, noop_flag, false, std::move(amax), at::empty({1}, opts)); + this->quantize_impl(input, out, noop_flag, false, std::move(amax), at::empty({1}, opts), + std::move(amax_reduction_group)); } Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { @@ -1305,7 +1299,8 @@ std::pair Float8BlockQuantizer::convert_and_update_te } void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag) { + const std::optional& noop_flag, + c10::intrusive_ptr amax_reduction_group) { if (input.numel() == 0) { return; } @@ -1694,7 +1689,8 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( } void MXFP8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag) { + const std::optional& noop_flag, + c10::intrusive_ptr amax_reduction_group) { if (input.numel() == 0) { return; } @@ -1760,17 +1756,6 @@ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize } this->row_scaled_nvfp4 = quantizer.attr("row_scaled_nvfp4").cast(); - // Get amax reduction group if needed for NVFP4 AG - const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast(); - c10::intrusive_ptr amax_reduction_group; - if (with_amax_reduction) { - auto group = quantizer.attr("_canonicalized_amax_reduction_group")(); - NVTE_CHECK(!group.is_none(), "NVFP4Quantizer could not canonicalize amax reduction group"); - amax_reduction_group = group.cast>(); - } - this->with_amax_reduction = with_amax_reduction; - this->amax_reduction_group = amax_reduction_group; - this->rht_matrix_random_sign_mask_t = quantizer.attr("rht_matrix_random_sign_mask_t").cast(); this->rht_matrix = quantizer.attr("rht_matrix").cast(); } @@ -2337,10 +2322,10 @@ void NVFP4Quantizer::quantize_with_rht_unfused_helper( } void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag, - bool compute_amax) { + const std::optional& noop_flag, bool compute_amax, + c10::intrusive_ptr amax_reduction_group) { auto reduce_amaxes = [&]() { - if (!this->with_amax_reduction) { + if (!amax_reduction_group) { return; } @@ -2365,7 +2350,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou c10d::AllreduceCoalescedOptions opts; opts.reduceOp = c10d::ReduceOp::MAX; NVTE_SCOPED_GIL_RELEASE( - { this->amax_reduction_group->allreduce_coalesced(amax_tensors, opts)->wait(); }); + { amax_reduction_group->allreduce_coalesced(amax_tensors, opts)->wait(); }); }; // Nothing to be done if input is empty @@ -2406,7 +2391,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou "Row-scaled NVFP4 quantization does not support 2D quantization."); NVTE_CHECK(!this->stochastic_rounding, "Row-scaled NVFP4 quantization does not support stochastic rounding."); - NVTE_CHECK(!this->with_amax_reduction, + NVTE_CHECK(!amax_reduction_group, "Row-scaled NVFP4 quantization does not support amax reduction."); NVTE_CHECK(cols % 16 == 0, "Row-scaled NVFP4 quantization requires last dim divisible by 16."); } @@ -2565,11 +2550,13 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } void NVFP4Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag) { - this->quantize_impl(input, out, noop_flag, true); + const std::optional& noop_flag, + c10::intrusive_ptr amax_reduction_group) { + this->quantize_impl(input, out, noop_flag, true, std::move(amax_reduction_group)); } -void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out) { +void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out, + c10::intrusive_ptr amax_reduction_group) { NVTE_CHECK(!out.get_row_scaled_nvfp4(), "quantize_with_amax is not supported for row-scaled NVFP4 quantization."); // Update output tensor amaxes with input tensor amax @@ -2592,7 +2579,7 @@ void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out input.set_amax(nullptr, DType::kFloat32, input.defaultShape); // Perform quantization - this->quantize_impl(input, out, std::nullopt, false); + this->quantize_impl(input, out, std::nullopt, false, std::move(amax_reduction_group)); } std::vector NVFP4Quantizer::get_scale_shape(const std::vector& shape, diff --git a/transformer_engine/pytorch/custom_recipes/quantization_ref_current_scaling.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_current_scaling.py index ecbb667ecf..a248827180 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_ref_current_scaling.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_current_scaling.py @@ -6,7 +6,7 @@ import dataclasses import math -from typing import Optional, Tuple, Iterable +from typing import Any, Optional, Tuple, Iterable import torch @@ -225,6 +225,18 @@ def __getstate__(self): state["amax_reduction_group"] = None return state + def _resolve_amax_reduction_group(self, amax_reduction_group): + """Resolve the effective amax reduction group. + + The group passed in (per-call arg or destination tensor) takes precedence; + otherwise fall back to the deprecated group stored on the quantizer. + """ + if amax_reduction_group is not None: + return amax_reduction_group + if self.with_amax_reduction: + return self.amax_reduction_group + return None + @property def custom(self) -> bool: """Flag to indicate this quantizer is custom.""" @@ -260,7 +272,12 @@ def compute_scale( pow_2_scales=pow_2_scales, ) - def _quantize(self, tensor: torch.Tensor) -> Tuple[ + def _quantize( + self, + tensor: torch.Tensor, + *, + amax_reduction_group: Optional[Any] = None, + ) -> Tuple[ Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], @@ -283,12 +300,8 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ - qx_t: quantized data in column-major order (if columnwise_usage), None otherwise - sx_t: empty scale tensor for qx_t (if columnwise_usage), None otherwise """ - # Handle amax reduction if enabled - if self.with_amax_reduction: - assert ( - self.amax_reduction_group is not None - ), "amax_reduction_group must be set when with_amax_reduction is True" - + # Handle amax reduction if a process group is supplied + if amax_reduction_group is not None: # Compute local amax if tensor.numel() == 0: amax = torch.empty(1, dtype=torch.float32, device=tensor.device) @@ -297,7 +310,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ # Reduce amax across all ranks torch.distributed.all_reduce( - amax, group=self.amax_reduction_group, op=torch.distributed.ReduceOp.MAX + amax, group=amax_reduction_group, op=torch.distributed.ReduceOp.MAX ) # Compute scale using the global amax @@ -347,7 +360,10 @@ def quantize( if tensor.ndim > 2: tensor = tensor.view(-1, tensor.shape[-1]) - qx, sx, qx_t, sx_t = self._quantize(tensor) + amax_reduction_group = self._resolve_amax_reduction_group( + kwargs.get("amax_reduction_group") + ) + qx, sx, qx_t, sx_t = self._quantize(tensor, amax_reduction_group=amax_reduction_group) return CurrentScalingTensorRef( data=qx, @@ -449,6 +465,7 @@ def update_quantized( dst: QuantizedTensorStorage, *, noop_flag: Optional[torch.Tensor] = None, + amax_reduction_group: Optional[Any] = None, ) -> QuantizedTensorStorage: """Update the quantized tensor with the given tensor in-place @@ -460,6 +477,10 @@ def update_quantized( Destination ExperimentalQuantizedTensor to update noop_flag: torch.Tensor, optional float32 flag indicating whether to avoid performing update + amax_reduction_group: optional + Process group for the amax all-reduce. Falls back to ``dst.amax_reduction_group`` + and then to the deprecated group stored on the quantizer; ``None`` means + no reduction. """ # Handle noop flag if noop_flag is not None and noop_flag.item() != 0: @@ -474,7 +495,10 @@ def update_quantized( if src.ndim > 2: src = src.view(-1, src.shape[-1]) - qx, sx, qx_t, sx_t = self._quantize(src) + if amax_reduction_group is None: + amax_reduction_group = getattr(dst, "amax_reduction_group", None) + amax_reduction_group = self._resolve_amax_reduction_group(amax_reduction_group) + qx, sx, qx_t, sx_t = self._quantize(src, amax_reduction_group=amax_reduction_group) # Update the destination with new data dst.data = qx diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 670eecaa5e..068639ea2a 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -986,6 +986,7 @@ def _all_gather_fp8( async_op: bool = False, quantizer: Optional[Quantizer] = None, out_shape: Optional[list[int]] = None, + amax_reduction_group: Optional[dist_group_type] = None, ) -> tuple[Float8TensorStorage, Optional[torch.distributed.Work]]: """All-gather FP8 tensor along first dimension.""" world_size = get_distributed_world_size(process_group) @@ -1016,7 +1017,7 @@ def _all_gather_fp8( init_rowwise_usage = quantizer.rowwise_usage init_columnwise_usage = quantizer.columnwise_usage quantizer.set_usage(rowwise=True, columnwise=False) - inp = quantizer(inp) + inp = quantizer(inp, amax_reduction_group=amax_reduction_group) quantizer.set_usage( rowwise=init_rowwise_usage, columnwise=init_columnwise_usage, @@ -1328,8 +1329,15 @@ def _all_gather_nvfp4( async_op: bool = False, quantizer: NVFP4Quantizer, out_shape: Optional[list[int]] = None, + amax_reduction_group: Optional[dist_group_type] = None, ) -> tuple[NVFP4TensorStorage, Optional[torch.distributed.Work]]: - """All-gather NVFP4 tensor along first dimension.""" + """All-gather NVFP4 tensor along first dimension. + + ``amax_reduction_group`` is an optional process group used to all-reduce the + amax during the local quantization so that all shards share the same global + scale (current scaling / sequence parallelism). It is forwarded to the + quantizer per call. + """ # Input tensor attributes in_shape: Iterable[int] = None @@ -1388,12 +1396,12 @@ def _all_gather_nvfp4( memory_format=torch.contiguous_format, ) torch.distributed.all_gather_into_tensor(out, inp, group=process_group) - out = quantizer(out) + out = quantizer(out, amax_reduction_group=amax_reduction_group) return out, None # Cast input tensor to NVFP4 with required data if not isinstance(inp, NVFP4TensorStorage): - inp = quantizer(inp) + inp = quantizer(inp, amax_reduction_group=amax_reduction_group) elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( quantizer.columnwise_usage and inp._columnwise_data is None ): @@ -1401,7 +1409,7 @@ def _all_gather_nvfp4( "Input and quantizer do not have matching usages. " "Dequantizing and requantizing to NVFP4." ) - inp = quantizer(inp.dequantize(dtype=dtype)) + inp = quantizer(inp.dequantize(dtype=dtype), amax_reduction_group=amax_reduction_group) # Construct NVFP4 output tensor out = quantizer.make_empty(out_shape, dtype=dtype, device=device) @@ -1642,16 +1650,21 @@ def gather_along_first_dim( process_group: dist_group_type, async_op: bool = False, quantizer: Optional[Quantizer] = None, + amax_reduction_group: Optional[dist_group_type] = None, ) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: """ All-gather tensors and concatenate along first dimension. + + ``amax_reduction_group`` is an optional process group used to all-reduce the + amax during the local quantization (current scaling). It is forwarded to the + quantizer per call. """ # Return immediately if no communication is required world_size = get_distributed_world_size(process_group) if world_size == 1: if quantizer is not None and not isinstance(inp, QuantizedTensorStorage): - inp = quantizer(inp) + inp = quantizer(inp, amax_reduction_group=amax_reduction_group) return inp, None # Debug case - call gather_along_first_dim on each tensor @@ -1706,6 +1719,7 @@ def gather_along_first_dim( async_op=async_op, quantizer=quantizer, out_shape=out_shape, + amax_reduction_group=amax_reduction_group, ) # FP8 block scaling case, block length = 128 @@ -1746,6 +1760,7 @@ def gather_along_first_dim( async_op=async_op, quantizer=quantizer, out_shape=out_shape, + amax_reduction_group=amax_reduction_group, ) # High-precision communication for quantized tensors diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index bf5a230e84..6aae45aea7 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -42,8 +42,13 @@ def apply_normalization( normalization: str, fwd_ln_sm_margin: int, zero_centered_gamma: bool, + amax_reduction_group=None, ): - """Apply normalization to input.""" + """Apply normalization to input. + + ``amax_reduction_group`` is an optional distributed process group used to + all-reduce the amax in the fused norm+quantize path (current scaling). + """ normalization_func = _get_normalization_func(normalization, True) inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias) @@ -56,6 +61,7 @@ def apply_normalization( TE_DType[output_dtype] if output_dtype in TE_DType else output_dtype, fwd_ln_sm_margin, zero_centered_gamma, + amax_reduction_group, ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6c7ba8a8ab..dfdd53da24 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1575,6 +1575,10 @@ def grad_output_preprocess( grad_output = grad_output.contiguous() gather_grad_output = row_parallel_mode and ctx.sequence_parallel use_fp8_bwd = ctx.fp8 and ctx.backward_override is None + # Amax reduction group for current scaling: the grad output amax is + # reduced across the tensor-parallel group when it is gathered for a + # row-parallel layer with sequence parallelism. + grad_output_amax_reduction_group = ctx.tp_group if gather_grad_output else None # Non-FP8 case: bgrad is fused with wgrad for this case. if not use_fp8_bwd and not ctx.debug: @@ -1607,7 +1611,9 @@ def grad_output_preprocess( Float8BlockwiseQTensorStorage, ), ): - grad_output = quantizer(grad_output) + grad_output = quantizer( + grad_output, amax_reduction_group=grad_output_amax_reduction_group + ) # Copy into communication buffer, and replace original gradient with it grad_output, _ = fill_userbuffers_buffer_for_all_gather( @@ -1621,6 +1627,7 @@ def grad_output_preprocess( grad_output, ctx.tp_group, quantizer=quantizer, + amax_reduction_group=grad_output_amax_reduction_group, ) return grad_output, grad_bias @@ -1717,6 +1724,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: raise RuntimeError("Weight quantizer has not been initialized") quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) quantizer.internal = False + amax_reduction_group = None if is_dtensor and isinstance( quantizer, (Float8CurrentScalingQuantizer, NVFP4Quantizer) ): @@ -1726,10 +1734,9 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: if device_mesh.ndim > 1 else device_mesh.get_group() ) - quantizer.amax_reduction_group = amax_reduction_group - quantizer.with_amax_reduction = True - # Quantize parameter - param = quantizer(param) + # Quantize parameter. The amax reduction group (Float8 current + # scaling and NVFP4) is supplied per call. + param = quantizer(param, amax_reduction_group=amax_reduction_group) # Redo parameter wrap in case we broke it above # NOTE: Currently this can only be broken when primary weights are in Fp8 but diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 7fc96d4779..bf9bc9eb27 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -157,6 +157,9 @@ def forward( nvtx_label = f"{nvtx_label}.{ub_name}" with_input_all_gather = parallel_mode == "column" and sequence_parallel + # Amax reduction group for the input/norm-output quantizer (current + # scaling + column-parallel sequence parallel). + input_amax_reduction_group = tp_group if with_input_all_gather else None # Make sure input dimensions are compatible out_features, in_features = weight.shape @@ -241,6 +244,7 @@ def forward( normalization, fwd_ln_sm_margin, zero_centered_gamma, + amax_reduction_group=input_amax_reduction_group, ) nvtx_range_pop(f"{nvtx_label}.norm") @@ -263,16 +267,20 @@ def forward( ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_return = ln_out_total if fp8 or debug: - ln_out = input_quantizer(ln_out) + ln_out = input_quantizer( + ln_out, amax_reduction_group=input_amax_reduction_group + ) input_quantizer.set_usage(rowwise=True, columnwise=False) - ln_out_total = input_quantizer(ln_out_total) + ln_out_total = input_quantizer( + ln_out_total, amax_reduction_group=input_amax_reduction_group + ) else: quantizer = None if fp8 or debug: quantizer = input_quantizer # custom recipe doesn't need to support quantized AG if not with_quantized_norm and not custom: - ln_out = quantizer(ln_out) + ln_out = quantizer(ln_out, amax_reduction_group=input_amax_reduction_group) quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( @@ -286,10 +294,11 @@ def forward( ln_out, tp_group, quantizer=quantizer, + amax_reduction_group=input_amax_reduction_group, ) else: if (fp8 or debug) and not with_quantized_norm: - ln_out = input_quantizer(ln_out) + ln_out = input_quantizer(ln_out, amax_reduction_group=input_amax_reduction_group) ln_out_total = ln_out nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm") # ------------------------------------------------------ @@ -679,6 +688,14 @@ def backward( # Note: Cast to expected dtype and perform tensor-parallel communication # -------------------------------------------------- + # Amax reduction groups for current scaling. + input_amax_reduction_group = ( + ctx.tp_group if (ctx.sequence_parallel and ctx.parallel_mode == "column") else None + ) + grad_output_amax_reduction_group = ( + ctx.tp_group if (ctx.sequence_parallel and ctx.parallel_mode == "row") else None + ) + # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage @@ -746,6 +763,7 @@ def backward( ctx.tp_group, async_op=True, quantizer=quantizer, + amax_reduction_group=input_amax_reduction_group, ) nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") else: @@ -947,14 +965,18 @@ def backward( ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) - ln_out_total = ctx.input_quantizer(ln_out_total) + ln_out_total = ctx.input_quantizer( + ln_out_total, amax_reduction_group=input_amax_reduction_group + ) if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - grad_output = ctx.grad_output_quantizer(grad_output) + grad_output = ctx.grad_output_quantizer( + grad_output, amax_reduction_group=grad_output_amax_reduction_group + ) # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD @@ -1552,8 +1574,6 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: recipe = FP8GlobalStateManager.get_fp8_recipe() if recipe.float8_current_scaling(): self._customize_quantizers_float8_current_scaling(fwd, recipe) - elif recipe.nvfp4(): - self._customize_quantizers_nvfp4(fwd, recipe) def get_quantizer_roles( self, @@ -1919,54 +1939,14 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_fwd"][ FP8FwdTensorIdx.GEMM1_WEIGHT ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon - # parallel related - if self.sequence_parallel and self.parallel_mode == "column": - # set input_quantizer with amax reduction TP group - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].with_amax_reduction = True - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].amax_reduction_group = self.tp_group else: - # set grad_output_quantizer with amax epsilon and power_2_scale (no amax reduction here) + # set grad_output_quantizer with amax epsilon and power_2_scale self.quantizers["scaling_bwd"][ FP8BwdTensorIdx.GRAD_OUTPUT1 ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale self.quantizers["scaling_bwd"][ FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon - # parallel related - if self.sequence_parallel and self.parallel_mode == "row": - # customize grad_output_quantizer with amax reduction TP group - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT1 - ].with_amax_reduction = True - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT1 - ].amax_reduction_group = self.tp_group - - def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: - """Customize quantizers based on current scaling recipe + layernorm_linear.""" - assert recipe.nvfp4(), "Incorrect recipe." - if fwd: - if self.sequence_parallel and self.parallel_mode == "column": - # set input_quantizer with amax reduction TP group - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].with_amax_reduction = True - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].amax_reduction_group = self.tp_group - else: - if self.sequence_parallel and self.parallel_mode == "row": - # customize grad_output_quantizer with amax reduction TP group - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT1 - ].with_amax_reduction = True - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT1 - ].amax_reduction_group = self.tp_group def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6c6cca74ef..754018e89f 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -415,6 +415,12 @@ def _forward( and not custom ) + # Amax reduction group for the fc1 input/norm-output quantizer (current + # scaling, column-parallel + sequence parallel). + fc1_input_amax_reduction_group = ( + tp_group if (sequence_parallel and set_parallel_mode) else None + ) + # Apply normalization ln_out, mu, rsigma = apply_normalization( inputmat, @@ -427,6 +433,7 @@ def _forward( normalization, fwd_ln_sm_margin, zero_centered_gamma, + amax_reduction_group=fc1_input_amax_reduction_group, ) ln_out_return = None @@ -447,16 +454,22 @@ def _forward( ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_return = ln_out_total if fp8 or debug: - ln_out = fc1_input_quantizer(ln_out) + ln_out = fc1_input_quantizer( + ln_out, amax_reduction_group=fc1_input_amax_reduction_group + ) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) - ln_out_total = fc1_input_quantizer(ln_out_total) + ln_out_total = fc1_input_quantizer( + ln_out_total, amax_reduction_group=fc1_input_amax_reduction_group + ) else: quantizer = None if fp8 or debug: quantizer = fc1_input_quantizer # custom recipe doesn't need to support quantized AG if not with_quantized_norm and not custom: - ln_out = fc1_input_quantizer(ln_out) + ln_out = fc1_input_quantizer( + ln_out, amax_reduction_group=fc1_input_amax_reduction_group + ) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag: # Copy into Userbuffers buffer @@ -473,10 +486,13 @@ def _forward( ln_out, tp_group, quantizer=quantizer, + amax_reduction_group=fc1_input_amax_reduction_group, ) else: if (fp8 or debug) and not with_quantized_norm: - ln_out = fc1_input_quantizer(ln_out) + ln_out = fc1_input_quantizer( + ln_out, amax_reduction_group=fc1_input_amax_reduction_group + ) ln_out_total = ln_out # Cast weights to expected dtype @@ -1127,6 +1143,16 @@ def backward( ctx.ub_bulk_dgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_dgrad ctx.ub_bulk_wgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_wgrad + # Amax reduction groups for current scaling. The fc1 input + # (column-parallel) and fc2 grad output (row-parallel) amaxes are + # reduced across the tensor-parallel group with sequence parallelism. + fc1_input_amax_reduction_group = ( + ctx.tp_group if (ctx.sequence_parallel and ctx.set_parallel_mode) else None + ) + fc2_grad_output_amax_reduction_group = ( + ctx.tp_group if (ctx.sequence_parallel and ctx.set_parallel_mode) else None + ) + # Configure quantizer for FC2 grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage @@ -1180,6 +1206,7 @@ def backward( ctx.tp_group, async_op=True, quantizer=quantizer, + amax_reduction_group=fc1_input_amax_reduction_group, ) else: ln_out_total = ln_out @@ -1347,7 +1374,10 @@ def backward( grad_output.update_usage(columnwise_usage=True) else: ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - grad_output = ctx.fc2_grad_output_quantizer(grad_output) + grad_output = ctx.fc2_grad_output_quantizer( + grad_output, + amax_reduction_group=fc2_grad_output_amax_reduction_group, + ) # Whether to set grad arg in general_gemm grad_arg = True @@ -1604,7 +1634,9 @@ def fc2_wgrad_gemm( ln_out_total.update_usage(columnwise_usage=True) else: ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True) - ln_out_total = ctx.fc1_input_quantizer(ln_out_total) + ln_out_total = ctx.fc1_input_quantizer( + ln_out_total, amax_reduction_group=fc1_input_amax_reduction_group + ) # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and @@ -2165,8 +2197,6 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: recipe = FP8GlobalStateManager.get_fp8_recipe() if recipe.float8_current_scaling(): self._customize_quantizers_float8_current_scaling(fwd, recipe) - elif recipe.nvfp4(): - self._customize_quantizers_nvfp4(fwd, recipe) def get_quantizer_roles( self, @@ -2676,15 +2706,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_fwd"][ FP8FwdTensorIdx.GEMM2_WEIGHT ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon - # parallel related - if self.sequence_parallel and self.set_parallel_mode: - # fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].with_amax_reduction = True - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].amax_reduction_group = self.tp_group else: # fc2_grad_output_quantizer: set configs about amax epsilon and power_2_scale for fc2_grad_output_quantizer self.quantizers["scaling_bwd"][ @@ -2700,36 +2721,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_bwd"][ FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon - if self.sequence_parallel and self.set_parallel_mode: - # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT2 - ].with_amax_reduction = True - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT2 - ].amax_reduction_group = self.tp_group - - def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: - """Customize quantizers based on current scaling recipe + layernorm_mlp.""" - assert recipe.nvfp4(), "Incorrect recipe." - if fwd: - if self.sequence_parallel and self.set_parallel_mode: - # fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].with_amax_reduction = True - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].amax_reduction_group = self.tp_group - else: - if self.sequence_parallel and self.set_parallel_mode: - # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT2 - ].with_amax_reduction = True - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT2 - ].amax_reduction_group = self.tp_group def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6c2d98d160..e7e80a37ec 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -301,6 +301,12 @@ def _linear_forward_impl( with_input_all_gather_nccl = ( parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop ) + # Amax reduction process group for the input quantizer (current scaling + + # column-parallel sequence parallel). Quantizers that do not support amax + # reduction ignore it. + input_amax_reduction_group = ( + tp_group if (sequence_parallel and parallel_mode == "column") else None + ) # Configure Userbuffers communication (comm+GEMM overlap) ub_obj = None @@ -347,7 +353,9 @@ def _linear_forward_impl( # tensor will not be cached for backward pass input_quantizer.set_usage(columnwise=False) own_quantized_input = False - inputmat = input_quantizer(inputmat) + inputmat = input_quantizer( + inputmat, amax_reduction_group=input_amax_reduction_group + ) else: inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP @@ -385,7 +393,9 @@ def _linear_forward_impl( and backward_override is None ), ) - inputmat = input_quantizer(inputmat) + inputmat = input_quantizer( + inputmat, amax_reduction_group=input_amax_reduction_group + ) own_quantized_input = True else: inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP @@ -743,6 +753,20 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. grad_weight_quantizer = args.grad_weight_quantizer grad_output_quantizer = args.grad_output_quantizer + # Amax reduction process groups for current scaling. The input amax is + # reduced for a column-parallel layer and the grad output amax for a + # row-parallel layer (both with sequence parallelism). + input_amax_reduction_group = ( + bwd_args.tp_group + if (bwd_args.sequence_parallel and bwd_args.parallel_mode == "column") + else None + ) + grad_output_amax_reduction_group = ( + bwd_args.tp_group + if (bwd_args.sequence_parallel and bwd_args.parallel_mode == "row") + else None + ) + # NVTX label for profiling nvtx_label = "transformer_engine._Linear.backward" if bwd_args.ub_name is not None: @@ -888,7 +912,7 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. ) else: quantizer.set_usage(rowwise=False, columnwise=True) - inputmat = quantizer(inputmat) + inputmat = quantizer(inputmat, amax_reduction_group=input_amax_reduction_group) else: if isinstance(inputmat, QuantizedTensorStorage): inputmat = inputmat.dequantize(dtype=bwd_args.activation_dtype) @@ -1088,7 +1112,9 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. inputmat_total.update_usage(columnwise_usage=True) else: input_quantizer.set_usage(rowwise=False, columnwise=True) - inputmat_total = input_quantizer(inputmat_total) + inputmat_total = input_quantizer( + inputmat_total, amax_reduction_group=input_amax_reduction_group + ) # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and @@ -1130,7 +1156,9 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. grad_output.update_usage(columnwise_usage=True) else: grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - grad_output = grad_output_quantizer(grad_output) + grad_output = grad_output_quantizer( + grad_output, amax_reduction_group=grad_output_amax_reduction_group + ) # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD @@ -1746,8 +1774,6 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: recipe = FP8GlobalStateManager.get_fp8_recipe() if recipe.float8_current_scaling(): self._customize_quantizers_float8_current_scaling(fwd, recipe) - elif recipe.nvfp4(): - self._customize_quantizers_nvfp4(fwd, recipe) def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) @@ -2111,15 +2137,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_fwd"][ FP8FwdTensorIdx.GEMM1_WEIGHT ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon - # paralle related - if self.sequence_parallel and self.parallel_mode == "column": - # customize input_quantizer with amax reduction TP group - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].with_amax_reduction = True - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].amax_reduction_group = self.tp_group else: # set grad_output_quantizer with amax epsilon and power_2_scale self.quantizers["scaling_bwd"][ @@ -2128,37 +2145,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self.quantizers["scaling_bwd"][ FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon - # parallel related - if self.sequence_parallel and self.parallel_mode == "row": - # customize grad_output_quantizer with amax reduction TP group - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT1 - ].with_amax_reduction = True - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT1 - ].amax_reduction_group = self.tp_group - - def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: - """Customize quantizers based on current scaling recipe + linear.""" - assert recipe.nvfp4(), "Incorrect recipe." - if fwd: - if self.sequence_parallel and self.parallel_mode == "column": - # customize input_quantizer with amax reduction TP group - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].with_amax_reduction = True - self.quantizers["scaling_fwd"][ - FP8FwdTensorIdx.GEMM1_INPUT - ].amax_reduction_group = self.tp_group - else: - if self.sequence_parallel and self.parallel_mode == "row": - # customize grad_output_quantizer with amax reduction TP group - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT1 - ].with_amax_reduction = True - self.quantizers["scaling_bwd"][ - FP8BwdTensorIdx.GRAD_OUTPUT1 - ].amax_reduction_group = self.tp_group def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 6b17d66fcd..4604d169a2 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -401,23 +401,6 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_bwd_grad.amax_epsilon - if getattr(self, "sequence_parallel", False): - tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None) - if tensor_parallel_mode == "column": - input_quantizer.with_amax_reduction = True - input_quantizer.amax_reduction_group = self.tensor_parallel_group - elif tensor_parallel_mode == "row": - grad_output_quantizer.with_amax_reduction = True - grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group - if recipe.nvfp4(): - if getattr(self, "sequence_parallel", False): - tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None) - if tensor_parallel_mode == "column": - input_quantizer.with_amax_reduction = True - input_quantizer.amax_reduction_group = self.tensor_parallel_group - elif tensor_parallel_mode == "row": - grad_output_quantizer.with_amax_reduction = True - grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group # Update quantizer in quantized weight tensor if weight_quantizer is not None and is_quantized_tensor(weight): @@ -537,6 +520,9 @@ def _functional_forward( x = None x_async = None with_x_all_gather = tensor_parallel_mode == "column" and sequence_parallel + # Amax reduction group for the input quantizer (current scaling + + # column-parallel sequence parallel), supplied per quantize call. + input_amax_reduction_group = tensor_parallel_group if with_x_all_gather else None if with_quantized_compute: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") @@ -551,10 +537,13 @@ def _functional_forward( tensor_parallel_group, async_op=True, quantizer=input_quantizer, + amax_reduction_group=input_amax_reduction_group, ) else: if not is_quantized_tensor(x_local): - x_local = input_quantizer(x_local) + x_local = input_quantizer( + x_local, amax_reduction_group=input_amax_reduction_group + ) x = x_local else: x_local = maybe_dequantize(x_local, dtype) @@ -781,6 +770,9 @@ def _functional_backward( dy = None dy_async = None with_dy_all_gather = tensor_parallel_mode == "row" and sequence_parallel + # Amax reduction group for the grad output quantizer (current scaling + + # row-parallel sequence parallel), supplied per quantize call. + grad_output_amax_reduction_group = tensor_parallel_group if with_dy_all_gather else None if with_quantized_compute: if grad_output_quantizer is None: raise ValueError("Missing quantizer for grad output tensor") @@ -794,10 +786,13 @@ def _functional_backward( tensor_parallel_group, async_op=True, quantizer=grad_output_quantizer, + amax_reduction_group=grad_output_amax_reduction_group, ) else: if not is_quantized_tensor(dy_local): - dy_local = grad_output_quantizer(dy_local) + dy_local = grad_output_quantizer( + dy_local, amax_reduction_group=grad_output_amax_reduction_group + ) else: dy_local.update_usage( rowwise_usage=input_requires_grad, @@ -824,6 +819,7 @@ def _functional_backward( raise ValueError("Input tensor is required to compute weight grad") x_local = input with_x_all_gather = tensor_parallel_mode == "column" and sequence_parallel + input_amax_reduction_group = tensor_parallel_group if with_x_all_gather else None if with_quantized_compute: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") @@ -834,12 +830,15 @@ def _functional_backward( tensor_parallel_group, async_op=True, quantizer=input_quantizer, + amax_reduction_group=input_amax_reduction_group, ) else: if is_quantized_tensor(x_local): x_local.update_usage(columnwise_usage=True) else: - x_local = input_quantizer(x_local) + x_local = input_quantizer( + x_local, amax_reduction_group=input_amax_reduction_group + ) x = x_local else: x_local = maybe_dequantize(x_local, dtype) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 404796fd63..f22484c963 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -271,6 +271,7 @@ def update_quantized( dst: QuantizedTensor, *, noop_flag: Optional[torch.Tensor] = None, + amax_reduction_group: Optional[Any] = None, # pylint: disable=unused-argument ) -> QuantizedTensor: """Quantize tensor in-place""" raise NotImplementedError( @@ -283,15 +284,27 @@ def quantize( *, out: Optional[QuantizedTensor] = None, dtype: Optional[torch.dtype] = None, # pylint: disable=unused-argument # used by override + amax_reduction_group: Optional[Any] = None, ) -> QuantizedTensor: - """Quantize tensor""" + """Quantize tensor + + The optional ``amax_reduction_group`` is a distributed process + group used to all-reduce the amax (max) before computing the + scale. Quantizers that do not support amax reduction ignore it. + + """ if out is not None: - return self.update_quantized(tensor, out) + return self.update_quantized(tensor, out, amax_reduction_group=amax_reduction_group) if (not self.internal) and torch.is_grad_enabled(): - return _QuantizeFunc.apply(tensor, self.quantize_impl) - return _QuantizeFunc.forward(None, tensor, self.quantize_impl) + return _QuantizeFunc.apply(tensor, self.quantize_impl, amax_reduction_group) + return _QuantizeFunc.forward(None, tensor, self.quantize_impl, amax_reduction_group) - def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + def quantize_impl( + self, + tensor: torch.Tensor, + *, + amax_reduction_group: Optional[Any] = None, # pylint: disable=unused-argument + ) -> QuantizedTensor: """Quantize tensor implementation""" raise NotImplementedError( f"{self.__class__.__name__} class does not implement quantize_impl function" @@ -304,9 +317,16 @@ def multi_quantize(self, list_of_tensors): list_of_output_tensors.append(self.quantize(tensor)) return list_of_output_tensors - def __call__(self, tensor: torch.Tensor) -> QuantizedTensor: + def __call__( + self, + tensor: torch.Tensor, + *, + amax_reduction_group: Optional[Any] = None, + ) -> QuantizedTensor: """Quantize tensor""" - return self.quantize(tensor) + if amax_reduction_group is None: + return self.quantize(tensor) + return self.quantize(tensor, amax_reduction_group=amax_reduction_group) def make_empty( self, diff --git a/transformer_engine/pytorch/tensor/_quantization_helpers.py b/transformer_engine/pytorch/tensor/_quantization_helpers.py index 56cf503630..df9ba47da5 100644 --- a/transformer_engine/pytorch/tensor/_quantization_helpers.py +++ b/transformer_engine/pytorch/tensor/_quantization_helpers.py @@ -24,9 +24,10 @@ def forward( _ctx: Optional[torch.autograd.function.FunctionCtx], # unused tensor: torch.Tensor, quantize_impl: Callable, + amax_reduction_group: Optional[Any] = None, ) -> QuantizedTensor: # pylint: disable=missing-function-docstring - return quantize_impl(tensor) + return quantize_impl(tensor, amax_reduction_group=amax_reduction_group) @staticmethod def backward( @@ -35,7 +36,7 @@ def backward( ) -> Tuple[Optional[torch.Tensor], ...]: # pylint: disable=missing-function-docstring # Assume that we want gradients in full precision - return grad, None + return grad, None, None class _IdentityFunc(torch.autograd.Function): diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index ba46508d74..ae6507da24 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -75,6 +75,7 @@ def update_quantized( dst: QuantizedTensor, *, noop_flag: Optional[torch.Tensor] = None, + amax_reduction_group: Optional[Any] = None, # pylint: disable=unused-argument ) -> QuantizedTensor: """Update the quantized tensor with data from the source tensor. @@ -114,7 +115,12 @@ def update_quantized( dst._fp8_dtype = self.dtype return dst - def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + def quantize_impl( + self, + tensor: torch.Tensor, + *, + amax_reduction_group: Optional[Any] = None, # pylint: disable=unused-argument + ) -> QuantizedTensor: """Quantize tensor implementation""" return tex.quantize(tensor, self) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 4de8d82217..2b89b0628b 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -89,6 +89,7 @@ def update_quantized( dst: QuantizedTensor, *, noop_flag: Optional[torch.Tensor] = None, + amax_reduction_group: Optional[Any] = None, # pylint: disable=unused-argument ) -> QuantizedTensor: if not isinstance(dst, Float8Tensor): raise ValueError("Float8Quantizer can only update Float8Tensor") @@ -107,7 +108,12 @@ def update_quantized( return dst - def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + def quantize_impl( + self, + tensor: torch.Tensor, + *, + amax_reduction_group: Optional[Any] = None, # pylint: disable=unused-argument + ) -> QuantizedTensor: """Quantize tensor implementation""" return tex.quantize(tensor, self) @@ -201,16 +207,21 @@ class Float8CurrentScalingQuantizer(Quantizer): parameters are accepted but unused. They are kept for backward compatibility with existing callers. + Note: ``with_amax_reduction`` and ``amax_reduction_group`` are deprecated. + Prefer supplying the amax reduction process group per call to :meth:`quantize`, + or setting it on the destination tensor for in-place :meth:`update_quantized` / + :meth:`quantize_`. A group stored on the quantizer is still honored as a fallback. + """ """FP8 datatype""" dtype: DType - """amax reduction options""" - with_amax_reduction: bool - amax_reduction_group: Optional[dist_group_type] """Options about how to quantize the tensor""" force_pow_2_scales: bool amax_epsilon: float + """Deprecated amax reduction state (kept only for backward compatibility)""" + with_amax_reduction: bool + amax_reduction_group: Optional[dist_group_type] def __init__( self, @@ -220,12 +231,12 @@ def __init__( rowwise: bool = True, columnwise: bool = True, use_existing_amax: bool = False, - with_amax_reduction: bool = False, - amax_reduction_group: Optional[dist_group_type] = None, force_pow_2_scales: bool = False, amax_epsilon: float = 0.0, scale: Optional[torch.Tensor] = None, amax: Optional[torch.Tensor] = None, + with_amax_reduction: bool = False, + amax_reduction_group: Optional[Any] = None, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) if use_existing_amax or scale is not None or amax is not None: @@ -235,15 +246,38 @@ def __init__( DeprecationWarning, stacklevel=2, ) + if with_amax_reduction or amax_reduction_group is not None: + warnings.warn( + "Storing `with_amax_reduction`/`amax_reduction_group` on " + "Float8CurrentScalingQuantizer is deprecated; pass the reduction group " + "per call to `quantize`, or set it on the destination tensor for in-place " + "`update_quantized`/`quantize_`. The stored group is still honored as a " + "fallback for now, but this support will be removed.", + DeprecationWarning, + stacklevel=2, + ) del device, use_existing_amax, scale, amax # Kept for backward compatibility self.dtype = DType.cast(fp8_dtype) - self.with_amax_reduction = with_amax_reduction - self.amax_reduction_group = amax_reduction_group self.force_pow_2_scales = force_pow_2_scales self.amax_epsilon = amax_epsilon + self.with_amax_reduction = with_amax_reduction + self.amax_reduction_group = amax_reduction_group - def __getstate__(self): - """Exclude unpicklable process group from serialized state.""" + def _resolve_amax_reduction_group(self, amax_reduction_group: Optional[Any]) -> Optional[Any]: + """Resolve the effective amax reduction group for a quantize call. + + The group passed in (already resolved from the per-call arg and the + destination tensor) takes precedence. If none is given, fall back to the + deprecated group stored on the quantizer, canonicalizing it. + """ + if amax_reduction_group is not None: + return amax_reduction_group + if getattr(self, "with_amax_reduction", False): + return canonicalize_process_group(self.amax_reduction_group) + return None + + def __getstate__(self) -> dict: + """Exclude the unpicklable process group from serialized state.""" state = self.__dict__.copy() state["amax_reduction_group"] = None return state @@ -256,13 +290,14 @@ def copy(self) -> Float8CurrentScalingQuantizer: device=0, rowwise=self.rowwise_usage, columnwise=self.columnwise_usage, - with_amax_reduction=self.with_amax_reduction, - amax_reduction_group=self.amax_reduction_group, force_pow_2_scales=self.force_pow_2_scales, amax_epsilon=self.amax_epsilon, ) quantizer.internal = self.internal quantizer.optimize_for_gemm = self.optimize_for_gemm + # Propagate deprecated amax reduction fallback (no warning on internal copies) + quantizer.with_amax_reduction = self.with_amax_reduction + quantizer.amax_reduction_group = self.amax_reduction_group return quantizer @@ -272,6 +307,7 @@ def update_quantized( dst: QuantizedTensor, *, noop_flag: Optional[torch.Tensor] = None, + amax_reduction_group: Optional[dist_group_type] = None, ) -> QuantizedTensor: if not isinstance(dst, Float8Tensor): raise ValueError("Float8CurrentScalingQuantizer can only update Float8Tensor") @@ -282,17 +318,28 @@ def update_quantized( if not src.is_contiguous(): src = src.contiguous() - # Launch cast kernel - tex.quantize(src, self, dst, noop_flag) + # Launch cast kernel. Resolve the amax reduction group with priority: + # per-call arg > group stored on the destination tensor (e.g. by FSDP2) > + # deprecated group stored on the quantizer. + if amax_reduction_group is None: + amax_reduction_group = getattr(dst, "amax_reduction_group", None) + amax_reduction_group = self._resolve_amax_reduction_group(amax_reduction_group) + tex.quantize(src, self, dst, noop_flag, amax_reduction_group) # Update FP8 dtype dst._fp8_dtype = self.dtype return dst - def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + def quantize_impl( + self, + tensor: torch.Tensor, + *, + amax_reduction_group: Optional[dist_group_type] = None, + ) -> QuantizedTensor: """Quantize tensor implementation""" - return tex.quantize(tensor, self) + amax_reduction_group = self._resolve_amax_reduction_group(amax_reduction_group) + return tex.quantize(tensor, self, amax_reduction_group=amax_reduction_group) def calibrate(self, tensor: torch.Tensor) -> None: # current scaling don't need to calibrate @@ -365,10 +412,6 @@ def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor: out = out.to(tensor.dtype) return out - def _canonicalized_amax_reduction_group(self) -> dist_group_type: - """Get process group for amax reduction""" - return canonicalize_process_group(self.amax_reduction_group) - def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return Float8CurrentScaling @@ -411,6 +454,11 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor): """ + # Optional process group for the amax all-reduce during in-place + # (re)quantization. Read by ``update_quantized``/``quantize_``/``_set_data`` + # (set e.g. by FSDP2 in ``fsdp_pre_all_gather``); ``None`` means no reduction. + amax_reduction_group: Optional[dist_group_type] = None + def __repr__(self, *, tensor_contents=None): return ( "Float8Tensor(" @@ -787,10 +835,9 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m if isinstance(self._quantizer, Float8CurrentScalingQuantizer) and mesh is not None: # When sharded weight is updated after reduce scattering the gradients in FSDP2, # we need to do amax reduction across the mesh to make sure all weight shards are - # updated with same scale inverse. Setting the state below in the quantizer will make - # sure that updated Quantized weight tensor have same scale inverse across all shards. - self._quantizer.amax_reduction_group = mesh.get_group() - self._quantizer.with_amax_reduction = True + # updated with same scale inverse. The reduction group is stored on the tensor and + # read when the data is re-quantized in ``_set_data``. + self.amax_reduction_group = mesh.get_group() fsdp_state = _get_module_fsdp_state(module) param_group = fsdp_state._fsdp_param_group @@ -995,7 +1042,11 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Quantize to FP8 assert self._quantizer is not None, "Can't quantize without a quantizer" self._quantizer.internal = False - self.data = self._quantizer.quantize(tensor) + # FSDP2 sets ``amax_reduction_group`` on this tensor in + # ``fsdp_pre_all_gather`` so that all weight shards get the same scale inverse. + self.data = self._quantizer.quantize( + tensor, amax_reduction_group=getattr(self, "amax_reduction_group", None) + ) if self.requires_grad != tensor.requires_grad: self.requires_grad_(requires_grad=tensor.requires_grad) diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index d759aaf5c4..b7fd1a8197 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -63,6 +63,7 @@ def update_quantized( dst: QuantizedTensor, *, noop_flag: Optional[torch.Tensor] = None, + amax_reduction_group: Optional[Any] = None, # pylint: disable=unused-argument ) -> QuantizedTensor: assert isinstance(dst, MXFP8Tensor), f"Cannot store quantized MXFP8 in {type(dst)} type." @@ -81,7 +82,12 @@ def update_quantized( return dst - def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + def quantize_impl( + self, + tensor: torch.Tensor, + *, + amax_reduction_group: Optional[Any] = None, # pylint: disable=unused-argument + ) -> QuantizedTensor: """Quantize tensor implementation""" return tex.quantize(tensor, self) diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 5a2765b9f5..ce621545bb 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -7,7 +7,7 @@ from collections.abc import Iterable import math import warnings -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import functools import torch @@ -111,15 +111,18 @@ def get_random_sign_mask_for_rht(with_random_sign_mask: bool, device: int) -> in class NVFP4Quantizer(Quantizer): - """Builder class for NVFP4 tensors with NV block scaling""" + """Builder class for NVFP4 tensors with NV block scaling + + Note: ``with_amax_reduction`` and ``amax_reduction_group`` are deprecated. + Prefer supplying the amax reduction process group per call to :meth:`quantize`, + or setting it on the destination tensor for in-place :meth:`update_quantized` / + :meth:`quantize_`. A group stored on the quantizer is still honored as a fallback. + """ dtype: DType """Random Hadamard Transform""" with_rht: bool with_post_rht_amax: bool - """amax reduction options""" - with_amax_reduction: bool - amax_reduction_group: Optional[dist_group_type] """2D block scaling, only applicable for weights.""" with_2d_quantization: bool @@ -138,15 +141,16 @@ class NVFP4Quantizer(Quantizer): """RHT matrix random sign mask""" rht_matrix_random_sign_mask_t: int - rht_matrix: torch.Tensor + + """Deprecated amax reduction state (kept only for backward compatibility)""" + with_amax_reduction: bool + amax_reduction_group: Optional[dist_group_type] def __init__( self, fp4_dtype: Union[DType, tex.DType] = DType.kFloat4E2M1, rowwise: bool = True, columnwise: bool = True, - with_amax_reduction: bool = False, - amax_reduction_group: Optional[dist_group_type] = None, with_rht: bool = False, with_post_rht_amax: bool = False, with_2d_quantization: bool = False, @@ -156,13 +160,25 @@ def __init__( nvfp4_e4m3_max: int = 448, nvfp4_4over6_err_mode: str = "MAE", with_random_sign_mask: bool = True, + with_amax_reduction: bool = False, + amax_reduction_group: Optional[Any] = None, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) + if with_amax_reduction or amax_reduction_group is not None: + warnings.warn( + "Storing `with_amax_reduction`/`amax_reduction_group` on NVFP4Quantizer is " + "deprecated; pass the reduction group per call to `quantize`, or set it on " + "the destination tensor for in-place `update_quantized`/`quantize_`. The " + "stored group is still honored as a fallback for now, but this support will " + "be removed.", + DeprecationWarning, + stacklevel=2, + ) + self.with_amax_reduction = with_amax_reduction + self.amax_reduction_group = amax_reduction_group self.dtype = DType.cast(fp4_dtype) self.with_rht = with_rht self.with_post_rht_amax = with_post_rht_amax - self.with_amax_reduction = with_amax_reduction - self.amax_reduction_group = amax_reduction_group self.with_2d_quantization = with_2d_quantization self.stochastic_rounding = stochastic_rounding self.row_scaled_nvfp4 = row_scaled_nvfp4 @@ -173,13 +189,36 @@ def __init__( self.nvfp4_4over6_err_mode = nvfp4_4over6_err_mode.upper() if self.nvfp4_4over6_err_mode not in ("MAE", "MSE"): raise ValueError("nvfp4_4over6_err_mode must be 'MAE' or 'MSE'.") + # The RHT matrix and its sign mask are derived from ``with_random_sign_mask`` + # and the device via module-level ``functools.lru_cache`` helpers, so they + # are effectively process-global. We keep only the (lightweight) seed flag on + # the quantizer and expose the matrix through the ``rht_matrix`` property, + # rather than storing a per-quantizer tensor. + self._with_random_sign_mask = with_random_sign_mask self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht( with_random_sign_mask, torch.cuda.current_device() ) - self.rht_matrix = get_rht_matrix(with_random_sign_mask, torch.cuda.current_device()) - def __getstate__(self): - """Exclude unpicklable process group from serialized state.""" + @property + def rht_matrix(self) -> torch.Tensor: + """RHT matrix (fetched from the process-global cache, not stored per quantizer).""" + return get_rht_matrix(self._with_random_sign_mask, torch.cuda.current_device()) + + def _resolve_amax_reduction_group(self, amax_reduction_group: Optional[Any]) -> Optional[Any]: + """Resolve the effective amax reduction group for a quantize call. + + The group passed in (already resolved from the per-call arg and the + destination tensor) takes precedence. If none is given, fall back to the + deprecated group stored on the quantizer, canonicalizing it. + """ + if amax_reduction_group is not None: + return amax_reduction_group + if getattr(self, "with_amax_reduction", False): + return canonicalize_process_group(self.amax_reduction_group) + return None + + def __getstate__(self) -> dict: + """Exclude the unpicklable process group from serialized state.""" state = self.__dict__.copy() state["amax_reduction_group"] = None return state @@ -190,6 +229,7 @@ def update_quantized( dst: QuantizedTensor, *, noop_flag: Optional[torch.Tensor] = None, + amax_reduction_group: Optional[Any] = None, ) -> QuantizedTensor: assert isinstance(dst, NVFP4Tensor), f"Cannot store quantized NVFP4 in {type(dst)} type." @@ -200,8 +240,14 @@ def update_quantized( if not src.is_contiguous(): src = src.contiguous() - # Launch cast kernel - tex.quantize(src, self, dst, noop_flag) + # Launch cast kernel. The amax reduction process group is used for the + # second-level per-tensor amax all-reduce in TP/SP. Resolve with priority: + # per-call arg > group on the destination tensor (e.g. by FSDP2) > + # deprecated group stored on the quantizer. + if amax_reduction_group is None: + amax_reduction_group = getattr(dst, "amax_reduction_group", None) + amax_reduction_group = self._resolve_amax_reduction_group(amax_reduction_group) + tex.quantize(src, self, dst, noop_flag, amax_reduction_group) return dst @@ -212,8 +258,6 @@ def copy(self) -> NVFP4Quantizer: fp4_dtype=self.dtype, rowwise=self.rowwise_usage, columnwise=self.columnwise_usage, - with_amax_reduction=self.with_amax_reduction, - amax_reduction_group=self.amax_reduction_group, with_rht=self.with_rht, with_post_rht_amax=self.with_post_rht_amax, with_2d_quantization=self.with_2d_quantization, @@ -222,17 +266,29 @@ def copy(self) -> NVFP4Quantizer: nvfp4_use_4over6=self.nvfp4_use_4over6, nvfp4_e4m3_max=self.nvfp4_e4m3_max, nvfp4_4over6_err_mode=self.nvfp4_4over6_err_mode, + with_random_sign_mask=self._with_random_sign_mask, ) quantizer.internal = self.internal quantizer.optimize_for_gemm = self.optimize_for_gemm - quantizer.rht_matrix = self.rht_matrix - quantizer.rht_matrix_random_sign_mask_t = self.rht_matrix_random_sign_mask_t + # Propagate deprecated amax reduction fallback (no warning on internal copies) + quantizer.with_amax_reduction = self.with_amax_reduction + quantizer.amax_reduction_group = self.amax_reduction_group return quantizer - def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: - """Quantize tensor implementation""" - return tex.quantize(tensor, self) + def quantize_impl( + self, + tensor: torch.Tensor, + *, + amax_reduction_group: Optional[Any] = None, + ) -> QuantizedTensor: + """Quantize tensor implementation + + The optional ``amax_reduction_group`` is supplied per call and forwarded + to the fused cast kernel for the TP/SP per-tensor amax all-reduce. + """ + amax_reduction_group = self._resolve_amax_reduction_group(amax_reduction_group) + return tex.quantize(tensor, self, None, None, amax_reduction_group) def is_quantizable(self, inp: torch.Tensor) -> bool: """Returns whether or not given inp can be quantized""" @@ -318,10 +374,6 @@ def convert_shape_for_fp4(shape: Iterable[int]) -> Tuple[int, ...]: def calibrate(self, tensor: torch.Tensor) -> None: pass # Calibration is no-op - def _canonicalized_amax_reduction_group(self) -> dist_group_type: - """Get process group for amax reduction""" - return canonicalize_process_group(self.amax_reduction_group) - def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return NVFP4BlockScaling @@ -359,6 +411,11 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor): Nominal tensor datatype, used in dequantize. """ + # Optional process group for the amax all-reduce during in-place + # (re)quantization. Read by ``update_quantized``/``quantize_``/``_set_data`` + # (set e.g. by FSDP2 in ``fsdp_pre_all_gather``); ``None`` means no reduction. + amax_reduction_group: Optional[dist_group_type] = None + # NOTE: We reorder the *args so that we can instantiate a NVFP4TensorStorage with positional args, # which significantly reduces the Pybind11 overhead when calling the constructor from C++. def __new__( @@ -513,6 +570,13 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m "FSDP2 is not supported for NVFP4Tensors with GEMM-swizzled scales." ) + if mesh is not None: + # When the sharded weight is re-quantized after the FSDP2 optimizer + # step, every shard must derive the same global (per-tensor) amax so + # that the all-gathered scale (passed via metadata) is consistent. + # The reduction group is stored on the tensor and read in ``_set_data``. + self.amax_reduction_group = mesh.get_group() + shard_M = math.prod(self.shape[:-1]) assert shard_M % NVFP4_BLOCK_SCALING_SIZE == 0, ( @@ -890,6 +954,9 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Quantize to FP8 assert self._quantizer is not None, "Can't quantize without a quantizer" + # FSDP2 sets ``amax_reduction_group`` on this tensor in + # ``fsdp_pre_all_gather`` so that all weight shards get the same (global) + # scale; ``update_quantized`` reads it from the destination tensor. self._quantizer.update_quantized(tensor, self) if self.requires_grad != tensor.requires_grad: self.requires_grad_(requires_grad=tensor.requires_grad)