[PyTorch][torch.compile] Remove process group from quantizers#3104

Open
pggPL wants to merge 13 commits into
NVIDIA:mainfrom
pggPL:remove_process_group_from_quantizers

@pggPL pggPL commented Jun 8, 2026

Collaborator

Description

Move amax reduction process group handling out of quantizer state and pass it per quantization call instead. Quantizers then no longer store process groups, which makes adding torch.compile support much easier. The deprecated stored-group fallback is kept for compatibility.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality not to work as expected)
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Pass amax_reduction_group through quantize/module call paths instead of storing it on quantizers.
  • Preserve deprecated constructor/state fallback for existing callers, excluding process groups from serialization.
  • Update FP8/NVFP4/MXFP8/blockwise tensor quantization paths and C++ bindings to resolve reduction groups per call.
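The first two changes can be sketched in plain Python. This is a minimal illustration, not the Transformer Engine code: `SketchQuantizer`, `FakeProcessGroup`, the `_legacy_amax_reduction_group` attribute, and the scaling placeholder are all hypothetical names chosen for the example.

```python
class FakeProcessGroup:
    """Hypothetical stand-in for a torch.distributed process group."""


class SketchQuantizer:
    """Illustrative quantizer; not the Transformer Engine class."""

    def __init__(self, scale, amax_reduction_group=None):
        self.scale = scale
        # Deprecated path: the stored group is kept only for legacy callers.
        self._legacy_amax_reduction_group = amax_reduction_group

    def __getstate__(self):
        # pickle calls this hook; the process group is dropped from the
        # serialized state so the quantizer stays serializable.
        state = self.__dict__.copy()
        state["_legacy_amax_reduction_group"] = None
        return state

    def quantize(self, values, *, amax_reduction_group=None):
        # The group arrives with the call instead of living on the quantizer.
        group = amax_reduction_group
        if group is None:
            group = self._legacy_amax_reduction_group
        # Placeholder math: a real quantizer would reduce amax over `group`.
        return [v * self.scale for v in values], group


quantizer = SketchQuantizer(2.0)
quantized, used_group = quantizer.quantize([1.0, 2.0], amax_reduction_group="tp_group")
```

Because the group is an argument rather than an attribute, the same quantizer object can be reused across call sites that reduce over different groups.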

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL added 8 commits June 3, 2026 17:56
…ing on quantizer

Remove the persistent amax reduction process group from quantizer objects
(Float8CurrentScalingQuantizer, NVFP4Quantizer) and thread it as a call-time
argument through quantize/quantize_impl/update_quantized down to the fused
C++ kernels (tex.quantize, layernorm_fwd, rmsnorm_fwd). The C++ constructors
no longer read with_amax_reduction/amax_reduction_group from Python; the group
is attached transiently to the C++ quantizer per call.

All call sites (linear, layernorm_linear, layernorm_mlp, ops/basic_linear,
base.grad_output_preprocess, distributed gather, attention context-parallel,
FSDP) now compute and forward the reduction group explicitly.

Backward compatibility: with_amax_reduction/amax_reduction_group are still
accepted (constructor or attribute) and honored as a fallback when no per-call
group is given, with a DeprecationWarning, preserving 1:1 numerics for legacy
callers. Also drop the per-quantizer RHT matrix tensor in favor of a property
backed by the process-global cache.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…uantizer

Clean up explanatory comments and docstrings that narrated the migration of
the amax reduction process group off the quantizer (e.g. "passed per call
instead of being stored on the quantizer", "no longer stored on the Python
quantizer"). No functional changes.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
… state

Thread the amax reduction process group as a per-call argument to the C++
Quantizer::quantize / quantize_impl / quantize_with_amax methods and drop the
with_amax_reduction / amax_reduction_group members from the C++ quantizer
objects. The group is converted once via convert_amax_reduction_group and
forwarded to the kernels, so the quantizer carries no distributed state.

Also remove an orphaned comment in attention backends that described amax
reduction group handling happening elsewhere.

No public API change: these are internal transformer_engine::pytorch symbols;
the pybind-exposed quantize already takes amax_reduction_group.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The grouped NVFP4 quantize-with-amax path never receives a reduction group
(the FC2 input is a local per-token activation, not an all-gathered shard), so
allreduce_nvfp4_amax_tensors was always a no-op. Drop the helper and the
amax-tensor list it consumed.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Drop dangling comments at the end of the per-module _customize_quantizers
if/else blocks that only narrated that amax reduction is "supplied per quantize
call" (and a "no amax reduction here" aside). They annotated no code and were
migration artifacts; the if/else twins missed in the first cleanup pass.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…s.partial

Pass amax_reduction_group as an explicit argument to the _QuantizeFunc autograd
function (and forward it to quantize_impl) rather than binding it via a
functools.partial on every quantize() call. Avoids a per-call allocation on the
hot path and is friendlier to torch.compile.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Drop the "(not on the quantizer, which must stay free of distributed state)"
aside in fsdp_pre_all_gather, matching the earlier trim of the same phrase in
the _stash_fsdp_amax_group docstring. The remaining comment still explains why
the reduction group is stashed on the tensor.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL commented Jun 8, 2026

Collaborator Author

/te-ci pytorch L1

greptile-apps Bot commented Jun 8, 2026

Contributor

Greptile Summary

This refactoring moves the amax reduction process group out of quantizer state and passes it per quantization call instead, eliminating mutable distributed state on quantizers and making torch.compile support significantly easier. Deprecated stored-group constructors are preserved with DeprecationWarning for backward compatibility, and process groups are excluded from pickle serialization.

  • Core change: amax_reduction_group is now a keyword argument to Quantizer.quantize, quantize_impl, update_quantized, and __call__ across Python and C++ layers; call sites in linear, layernorm_linear, layernorm_mlp, attention CP, and distributed gather functions derive the group inline from the local parallelism context.
  • FSDP2: the group is stored directly on Float8Tensor and NVFP4Tensor as a class attribute (set in fsdp_pre_all_gather), read by _set_data / update_quantized, eliminating the stash/pop pattern and the quantizer mutations that blocked torch.compile.
  • _QuantizeFunc backward: correctly updated to return an extra None matching the new forward argument, avoiding a gradient-count mismatch.
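The point about the backward return count follows from how `torch.autograd.Function` works: `backward` must return one value per `forward` argument, with `None` for non-differentiable inputs. A minimal sketch, assuming a placeholder scaling operation in place of real quantization (`SketchQuantizeFunc` is a hypothetical name, and the group is ignored here):

```python
import torch


class SketchQuantizeFunc(torch.autograd.Function):
    """Illustrative stand-in for _QuantizeFunc; the math is a placeholder."""

    @staticmethod
    def forward(ctx, tensor, scale, amax_reduction_group=None):
        # The group is a plain non-tensor argument passed explicitly.
        # A real implementation would reduce amax over it; this sketch does not.
        ctx.scale = scale
        return tensor * scale

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward argument: tensor, scale, group.
        return grad_output * ctx.scale, None, None


x = torch.ones(3, requires_grad=True)
y = SketchQuantizeFunc.apply(x, 2.0, None)
y.sum().backward()
```

Returning only two values from `backward` after adding the third forward argument would make autograd raise an error about the number of gradients, which is the mismatch the summary refers to.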

Confidence Score: 5/5

Safe to merge. The refactoring is mechanically consistent across all 25 files: the per-call group threading correctly replaces stored-group mutations, deprecated fallbacks are preserved, and the _QuantizeFunc backward return count is properly updated.

All production quantization paths correctly thread amax_reduction_group per call. The FSDP2 tensor-level attribute design cleanly replaces the old stash/pop patterns. The only issue found is in CurrentScalingQuantizerRef._resolve_amax_reduction_group which omits canonicalize_process_group, affecting only the deprecated fallback path in a reference implementation.

transformer_engine/pytorch/custom_recipes/quantization_ref_current_scaling.py

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/quantized_tensor.py | Adds amax_reduction_group parameter to Quantizer.quantize, quantize_impl, update_quantized, and __call__; threads the group through _QuantizeFunc and updates the backward return to match the added forward argument. |
| transformer_engine/pytorch/tensor/float8_tensor.py | Moves FSDP2 amax reduction group from the quantizer to a class-level attribute on Float8Tensor; _set_data reads it via getattr and forwards it to quantize; fsdp_pre_all_gather stores it on the tensor instead of mutating the quantizer. |
| transformer_engine/pytorch/tensor/nvfp4_tensor.py | Refactors NVFP4Quantizer to accept amax_reduction_group per call; converts rht_matrix to a property over the process-global LRU cache; adds _resolve_amax_reduction_group fallback; adds amax_reduction_group class attribute to NVFP4Tensor for FSDP2. |
| transformer_engine/pytorch/csrc/quantizer.cpp | Removes stored with_amax_reduction/amax_reduction_group members from C++ quantizer classes; threads a per-call amax_reduction_group parameter through every quantize and quantize_impl override. |
| transformer_engine/pytorch/csrc/extensions/cast.cpp | Adds amax_reduction_group parameter to the quantize Python binding entry point; removes the allreduce_nvfp4_amax_tensors helper and its call from nvfp4_group_quantize_with_amax (no per-call replacement for that function). |
| transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | Replaces per-quantizer mutation in backends.py with a per-call fp8_amax_reduction_group computed at forward/backward entry in each CP attention function and threaded through all quantizer call sites. |
| transformer_engine/pytorch/module/linear.py | Removes _customize_quantizers_nvfp4 and the stored-group mutations; derives input_amax_reduction_group / grad_output_amax_reduction_group inline at each call site in forward and backward. |
| transformer_engine/pytorch/custom_recipes/quantization_ref_current_scaling.py | Adds _resolve_amax_reduction_group fallback; inconsistency: unlike the production quantizers, the deprecated-fallback branch returns the stored group directly without calling canonicalize_process_group. |
| transformer_engine/pytorch/tensor/_quantization_helpers.py | Adds amax_reduction_group to _QuantizeFunc.forward and fixes backward to return an extra None matching the new forward argument count. |
| transformer_engine/pytorch/distributed.py | Adds amax_reduction_group parameter to _all_gather_fp8, _all_gather_nvfp4, and gather_along_first_dim; forwards it to every internal quantizer call site. |

Reviews (4): Last reviewed commit: "Carry amax reduction group on QuantizedT..."

Comment on lines 326 to +329
      """Quantize tensor"""
    - return self.quantize(tensor)
    + if amax_reduction_group is None:
    +     return self.quantize(tensor)
    + return self.quantize(tensor, amax_reduction_group=amax_reduction_group)

Contributor

P2 The None guard here is redundant: self.quantize(tensor) and self.quantize(tensor, amax_reduction_group=None) are identical because quantize defaults the argument to None. The branch just adds noise.

Suggested change
    - """Quantize tensor"""
    - return self.quantize(tensor)
    - if amax_reduction_group is None:
    -     return self.quantize(tensor)
    - return self.quantize(tensor, amax_reduction_group=amax_reduction_group)
    + """Quantize tensor"""
    + return self.quantize(tensor, amax_reduction_group=amax_reduction_group)
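The equivalence claimed in this review comment can be checked with a small stand-in class. This is a sketch under the review's stated assumption that `quantize` defaults `amax_reduction_group` to `None`; `GuardDemoQuantizer` and its method names are hypothetical.

```python
class GuardDemoQuantizer:
    """Hypothetical quantizer that records which group each call received."""

    def quantize(self, tensor, *, amax_reduction_group=None):
        return (tensor, amax_reduction_group)

    def call_with_guard(self, tensor, amax_reduction_group=None):
        # Shape of the code under review: branch on None first.
        if amax_reduction_group is None:
            return self.quantize(tensor)
        return self.quantize(tensor, amax_reduction_group=amax_reduction_group)

    def call_without_guard(self, tensor, amax_reduction_group=None):
        # Shape of the suggested change: always forward the argument.
        return self.quantize(tensor, amax_reduction_group=amax_reduction_group)


demo = GuardDemoQuantizer()
same_for_none = demo.call_with_guard("x") == demo.call_without_guard("x")
same_for_group = (
    demo.call_with_guard("x", "tp_group") == demo.call_without_guard("x", "tp_group")
)
```

The two forms would diverge only if some `quantize` override used a non-`None` default or did not accept the keyword at all.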

Comment on lines +200 to +203
    @property
    def rht_matrix(self) -> torch.Tensor:
        """RHT matrix (fetched from the process-global cache, not stored per quantizer)."""
        return get_rht_matrix(self._with_random_sign_mask, torch.cuda.current_device())

Contributor

P2 Deserialization break for old pickled NVFP4Quantizer instances

rht_matrix is now a property that reads self._with_random_sign_mask, but _with_random_sign_mask is a new field that did not exist in pickled state produced before this change. When Python's default __setstate__ (i.e., self.__dict__.update(state)) loads an old pickle, _with_random_sign_mask is absent, so any access to the rht_matrix property raises AttributeError. A __setstate__ that infers _with_random_sign_mask from the old stored rht_matrix (or supplies a safe default) would preserve backward compatibility for serialized quantizers.
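One possible shape for such a `__setstate__` is sketched below. Everything here is illustrative: `SketchNVFP4Quantizer` is a hypothetical class, the property returns a placeholder instead of calling `get_rht_matrix`, and defaulting the flag to `True` is an assumption, since the comment does not say which default is safe.

```python
class SketchNVFP4Quantizer:
    """Hypothetical stand-in showing only the state-restoration logic."""

    def __init__(self, with_random_sign_mask=True):
        self._with_random_sign_mask = with_random_sign_mask

    @property
    def rht_matrix(self):
        # Placeholder for the process-global cache lookup in the real code.
        return ("rht", self._with_random_sign_mask)

    def __setstate__(self, state):
        state = dict(state)
        # Old pickles stored the matrix itself; it is now a read-only property.
        state.pop("rht_matrix", None)
        # Assumed default for pickles that predate the new field.
        state.setdefault("_with_random_sign_mask", True)
        self.__dict__.update(state)


# Simulate loading state written before the field existed.
restored = SketchNVFP4Quantizer.__new__(SketchNVFP4Quantizer)
restored.__setstate__({"rht_matrix": object()})
```

Inferring the flag from the old stored matrix, as the comment also suggests, would need knowledge of how the matrix encodes the sign mask, which this sketch does not assume.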

pggPL and others added 3 commits June 8, 2026 15:44
The amax-reduction refactor moved the reduction group from quantizer
state to a per-call `amax_reduction_group` argument, but three call
paths were not updated, so they silently skipped amax reduction (or
raised):

- DebugQuantizer.quantize()/update_quantized() did not accept
  `amax_reduction_group`, raising TypeError under sequence parallelism.
- gather_along_first_dim() forwarded the group to the FP8 all-gather
  but not to _all_gather_nvfp4(), so NVFP4 row-parallel + sequence
  parallel quantized with a local amax (exact dgrad mismatch).
- NVFP4Tensor had no FSDP2 amax-group stash/pop (unlike Float8Tensor),
  so FSDP2 weight shards re-quantized with inconsistent global scales.

Mirror the existing Float8 handling for NVFP4 and forward the group
through the NVFP4 all-gather and the debug quantizer.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
FSDP2 re-quantizes a sharded weight after the optimizer step through
QuantizedTensor.__torch_dispatch__ -> quantize_ -> update_quantized (and
via the Tensor.data setter, _set_data). On this branch the amax reduction
group is no longer stored on the quantizer, but quantize_ did not thread a
group through, so the weight shards re-quantized with a local per-rank amax
instead of a global one. The all-gathered FP8/NVFP4 weight then dequantized
to different values per shard, failing _check_fp8_fsdp2_allgather for
Float8CurrentScaling and NVFP4BlockScaling with fp8_init=True.

Add an explicit per-call `amax_reduction_group` argument to the `quantize_`
hierarchy (QuantizedTensorStorage / QuantizedTensor / Float8/NVFP4/MXFP8/
Float8Blockwise), keeping it consistent with quantize()/update_quantized().
When no group is supplied, Float8Tensor/NVFP4Tensor fall back to the group
stashed on the tensor by fsdp_pre_all_gather, so in-place optimizer
re-quantization stays globally scaled without every caller having to thread
the group. The FSDP stash is now kept (read, not popped) and refreshed each
iteration, matching the previous behavior where the group lived on the
quantizer.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
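The failure described in this commit comes down to simple arithmetic. The sketch below assumes a simplified current-scaling rule, scale = FP8 max / amax, with 448 as the largest finite FP8 E4M3 value; the helper names are hypothetical and a Python `max` stands in for the all-reduce.

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


def current_scaling_scale(amax):
    # Simplified rule assumed for this sketch.
    return FP8_E4M3_MAX / amax


def shard_scales(shards, reduce_amax):
    """Return the quantization scale each rank would use for its shard."""
    amaxes = [max(abs(v) for v in shard) for shard in shards]
    if reduce_amax:
        # Stand-in for an all-reduce (MAX) over the amax reduction group.
        global_amax = max(amaxes)
        amaxes = [global_amax] * len(shards)
    return [current_scaling_scale(a) for a in amaxes]


shards = [[0.5, -2.0], [8.0, 1.0]]  # two ranks, one weight shard each
local_scales = shard_scales(shards, reduce_amax=False)
global_scales = shard_scales(shards, reduce_amax=True)
```

With local amax values the two ranks quantize with scales 224.0 and 56.0, so an all-gathered tensor interpreted with a single scale dequantizes the shards inconsistently. With the reduced amax both ranks use 56.0.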
pggPL commented Jun 8, 2026

Collaborator Author

Blocked by FSDP bug, refactor in progress.

I plan to store .amax_reduction_group in QuantizedTensor.

…rough quantize_

Float8Tensor/NVFP4Tensor carry an optional ``amax_reduction_group`` attribute,
set by FSDP2 in fsdp_pre_all_gather and kept (refreshed every iteration) rather
than popped. ``update_quantized`` reads it from the destination tensor when no
group is supplied per call, so in-place re-quantization via quantize_/_set_data
(e.g. the FSDP2 optimizer step) stays globally scaled without threading a group
through the quantize_ signature. The deprecated quantizer-stored group is still
honored as a lowest-priority fallback. The current-scaling reference quantizer
gains the same per-call + destination-tensor support.

Supersedes the alternative of adding an explicit amax_reduction_group argument to
the quantize_ hierarchy.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
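The lookup order described in this commit (per-call group, then the destination tensor's attribute, then the deprecated quantizer-stored group) can be sketched as a small helper. The function name and its signature are hypothetical, and `SimpleNamespace` objects stand in for tensors and quantizers.

```python
import warnings
from types import SimpleNamespace


def resolve_amax_reduction_group(per_call_group=None, destination=None, quantizer=None):
    """Pick the reduction group in priority order (illustrative only)."""
    # 1. An explicit per-call group always wins.
    if per_call_group is not None:
        return per_call_group
    # 2. Otherwise use the group carried on the destination tensor, if any.
    tensor_group = getattr(destination, "amax_reduction_group", None)
    if tensor_group is not None:
        return tensor_group
    # 3. Lowest priority: the deprecated group stored on the quantizer.
    legacy_group = getattr(quantizer, "amax_reduction_group", None)
    if legacy_group is not None:
        warnings.warn(
            "Storing amax_reduction_group on the quantizer is deprecated",
            DeprecationWarning,
            stacklevel=2,
        )
    return legacy_group


fsdp_weight = SimpleNamespace(amax_reduction_group="fsdp_group")
legacy_quantizer = SimpleNamespace(amax_reduction_group="legacy_group")
chosen = resolve_amax_reduction_group(
    per_call_group=None, destination=fsdp_weight, quantizer=legacy_quantizer
)
```

In this usage the tensor-carried group is chosen, so no deprecation warning is emitted; the warning fires only when the quantizer-stored group is the one actually used.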