Skip to content

[Pytorch][Common] Hybrid quantization#2817

Open
negvet wants to merge 23 commits into
NVIDIA:mainfrom
negvet:hybrid_quantization
Open

[Pytorch][Common] Hybrid quantization#2817
negvet wants to merge 23 commits into
NVIDIA:mainfrom
negvet:hybrid_quantization

Conversation

@negvet

@negvet negvet commented Mar 31, 2026

Copy link
Copy Markdown
Collaborator

Description

Hybrid (per-direction) quantization. Functional.
C++ optimizations (fusions, etc.) will come in the next PRs.

TODO:

  • Double quantization
  • Non-hybrid convergence of base recipes (validation)
  • DCP/torch_dist doesn't preserve hybrid current-scaling primary-weight _scale_inv; fix via __tensor_flatten__/__tensor_unflatten__ across the tensor stack — torch.save/load and FP32-master paths unaffected; covered by the test_hybrid_dcp_output_parity xfail.

Integration

Ecosystem integration (all functional, unit-tested):

  • [Done] quantized_model_init
  • [Done] FSDP2 (TODO: optimize communication buffers)
  • [Done] CPU offloading
  • [Done] Activation recomputation
  • [Done] TP/SP (TODO: enable quantized AG)

Megatron-LM integration status:

  • [Done] 1 GPU baseline
  • [Done] DP + distributed optimizer
  • [TODO] quantized_model_init + --fp{4,8}-param-gather + dist opt (persistent low-precision params via quantized_model_init + sharded-master FP32 → quantized cast via quantize_master_weights.)
    - [Done] Per-tensor Float8 hybrid (delayed and/or current, any per-direction combination
    including same-format, cross-format Float8, single-direction)
    - [TODO] Per-block hybrid sub-quantizers (MXFP8, NVFP4, Float8Blockwise) — each rejected per-direction by quantize_master_weights; unblocker is TE-side cast-helper / kernel.
  • [TODO] Megatron-FSDP + --fp{4,8}-param-gather (fix private attribute access)
  • [TODO] Torch FSDP2 + --fp{4,8}-param-gather
    - [Done] TE-side hybrid FSDP2 path works end-to-end for Float8 / MXFP8 / Float8Blockwise sub-storages (TODO: need some minor MLM update)
    - [TODO] NVFP4 sub-storage FSDP2 hooks
  • [Done] Activation recompute
  • [Done] CPU offload
  • [Done] TP/SP/PP
  • [Done] MoE + EP + grouped GEMM (qwen3 MoE; _hybrid_split_quantize under Megatron MoE)

Review

Total diff +9000
New hybrid source (hybrid_tensor.py, hybrid_tensor_storage.py) ~1000
Adjacent modifications ~1000
Tests are the rest

Surface to review is ~2000 lines

Suggested reading order

  1. Foundation — 7553e6a: Python containers + quantize/gemm dispatch/unwrap
  • tensor/hybrid_tensor.py — HybridQuantizer + HybridQuantizedTensor
  • tensor/storage/hybrid_tensor_storage.py
  • cpp_extensions/gemm.py — _unwrap_hybrid_A/B
  • common/transpose/quantize_transpose_square_blockwise.cu - Block FP8 columnwise-only null-checks
  • Module hooks in module/{base,grouped_linear,layernorm_linear,layernorm_mlp}.py
  • Tests: TestHybridQuantizer*, TestHybridGemmBitwiseIdentical* (proves zero-overhead vs vanilla recipes when both formats match), TestHybridDirectionUnwrap*, TestHybridGroupedLinear*
  1. quantized_model_init + FusedAdam — f80f5d0
  • hybrid_tensor.py::HybridQuantizer.update_quantized — delegates to each sub-quantizer; unblocks workspace-cache quantize_() and FusedAdam writeback
  • module/base.py workspace-cache invalidation
  • Tests: TestHybridQuantizedModelInit, TestHybridFusedAdam, TestHybridQuantizedParamsEndToEnd, TestHybridCheckpoint, TestQuantizedParamsEquivalence*
  1. FSDP2 support — 2185b30
  • New base FSDP2 buffer protocol on QuantizedTensorStorage: fsdp_buffer_fields / fsdp_extract_buffers / fsdp_assign_gathered. Generic, reusable beyond hybrid.
  • Per-format overrides on Float8TensorStorage (direction-aware) and MXFP8TensorStorage (trips/re-applies scale alignment padding around the gather)
  • hybrid_tensor.py::fsdp_pre/post_all_gather + torch_dispatch for the FSDP2 op set (view, split, as_strided, slice, copy_, new_zeros, clone, detach)
  • Non-safety in float8_tensor.py and mxfp8_tensor.py for single-direction sub-storages (columnwise-only on Hopper/L40)
  • Tests: TestHybridTorchDispatchFSDP2Ops, TestHybridFsdpPreAllGatherProtocol, TestHybridFsdpRoundtrip (bitwise-exact against manual all_gather(dequantize(shard))), plus tests/pytorch/distributed/fsdp2_tests/
  1. CPU offloading — 103fffe
  • hybrid_tensor_storage.py::clear() (v1 path) + prepare_for_saving / restore_from_saved chain (v2 path)
  • hybrid_tensor.py::detach() re-wraps each sub-storage via make_like (required by cpu_offload_v2's detach → prepare_for_saving pattern; sharing sub-storage objects would null-out fields on the original)
  • TestHybridCpuOffloadPushPop, plus updates to test_cpu_offloading*.py
  1. Activation recomputation — 16fb371
  • Uses existing QuantizedTensorStorage::prepare_for_saving / restore_from_saved protocol, preserving ordering across both sub-storages
  • Tests: 20 bitwise tests in TestHybridActivationRecompute
  1. TP/SP — a50fd63
  • hybrid_tensor.py::HybridQuantizer.supports_only_rowwise_all_gather — overrides to handle the NVFP4 columnwise-dequantize gap in the BF16 fallback path
  • distributed.py::gather_along_first_dim — hybrid branch re-quantizes with both directions after AG (since hybrid has no _create_transpose synthesis path)
  • Tests: 9 distributed tests in run_hybrid_tp_sp.py / test_hybrid_tp_sp.py
  1. Megatron-LM integration — a164cd3
  • tensor/utils.py::_route_hybrid_to_buckets — per-direction dispatch for quantize_master_weights: iterates both sub-storages, routes each independently into the per-format bucket matching its own sub-quantizer type
  • Hybrid branches in replace_raw_data and post_all_gather_processing
  • Today: per-tensor Float8 sub-quantizers (delayed + current) work in any per-direction combination. Per-block sub-quantizers raise per-direction with in-code TODOs naming the unblocker.
  • Tests: TestHybridQuantizeMasterWeights, TestHybridPostAllGatherProcessing

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps

greptile-apps Bot commented Mar 31, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR introduces hybrid (per-direction) quantization for PyTorch, enabling a weight to use one quantization format for its rowwise direction and a different format for its columnwise direction. The implementation spans ~2000 lines of new source across a new HybridQuantizer/HybridQuantizedTensor/HybridQuantizedTensorStorage stack, GEMM dispatch unwrapping helpers, and extensions to FSDP2, CPU offloading, activation recompute, TP/SP, and the distributed optimizer.

  • New core types (hybrid_tensor.py, hybrid_tensor_storage.py): HybridQuantizer composes two sub-quantizers pinned to their respective directions; HybridQuantizedTensor wraps two sub-storages and implements the full __torch_dispatch__ surface required by FSDP2 (split, clone, copy_, new_zeros, as_strided, slice, view).
  • FSDP2 buffer protocol (quantized_tensor.py, per-format storage overrides): New fsdp_buffer_fields / fsdp_extract_buffers / fsdp_assign_gathered protocol allows each sub-storage to declare and transform its own gather buffers independently; Float8Blockwise handles the columnwise→M-major transpose, MXFP8 handles scale-alignment unpadding/repadding.
  • CUDA kernel null-checks (quantize_transpose_square_blockwise.cu): output_c and tile_scales_inv_c are now guarded for nullptr to support columnwise-only hybrid sub-storages that do not write rowwise output.

Confidence Score: 2/5

The PR contains a method naming collision in Float8BlockwiseQTensorStorage that makes FSDP2 all-gather crash immediately for any hybrid Float8Blockwise weight, and a split-dispatch bug for non-square hybrid MXFP8 columnwise sub-storages.

In float8_blockwise_tensor_storage.py, the method intended as fsdp_extract_buffers was given the name fsdp_buffer_fields, silently overwriting the correct field-name method defined three lines earlier. Every call to fsdp_buffer_fields() on a Float8BlockwiseQTensorStorage — including the pre-flight check in HybridQuantizedTensor.fsdp_pre_all_gather — immediately recurses infinitely and raises RecursionError. Separately, the aten.split.Tensor dispatch in HybridQuantizedTensor.__torch_dispatch__ forwards the M-derived split_size at dim=0 to MXFP8 columnwise sub-storages stored in [K, M] layout; for FFN-style non-square weights where K ≠ M the split produces a different number of column pieces than row pieces, causing an IndexError when assembling the per-shard HybridQuantizedTensor list.

transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py (duplicate method name) and transformer_engine/pytorch/tensor/hybrid_tensor.py (split dispatch for transposed columnwise sub-storages).

Important Files Changed

Filename Overview
transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py Adds FSDP2 buffer protocol for Float8Blockwise sub-storages, but fsdp_extract_buffers is accidentally named fsdp_buffer_fields, silently shadowing the correct field-name method and causing infinite recursion on any call.
transformer_engine/pytorch/tensor/hybrid_tensor.py New HybridQuantizer + HybridQuantizedTensor implementation; aten.split.Tensor dispatch passes the M-derived split_size to transposed [K,M] MXFP8 columnwise sub-storages, producing wrong piece counts for non-square weights.
transformer_engine/pytorch/tensor/storage/hybrid_tensor_storage.py New HybridQuantizedTensorStorage mixin; correctly implements FSDP2 protocol, CPU offload, and activation recompute delegation to sub-storages.
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py Adds direction-aware fsdp_buffer_fields (correctly returns _transpose for columnwise-only sub-storages), fsdp_assign_gathered with _transpose_invalid reset, and a null guard in _create_transpose for hybrid sub-storages with _data=None.
transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py Adds FSDP2 buffer protocol for MXFP8; handles rowwise-only and columnwise-only cases including scale unpadding/repadding around the all-gather.
transformer_engine/pytorch/module/grouped_linear.py Adds _is_hybrid_quantizer_list and _hybrid_split_quantize helpers; correctly routes forward/backward grouped GEMMs through the hybrid two-pass quantize path.
transformer_engine/pytorch/tensor/utils.py Adds _route_hybrid_to_buckets for distopt per-direction routing and hybrid branches in replace_raw_data / post_all_gather_processing; per-block sub-quantizers raise NotImplementedError with clear TODO blockers.
transformer_engine/pytorch/distributed.py Hybrid override in gather_along_first_dim saves/restores quantizer usage around a full-direction re-quantization, avoiding the broken synthesis path that vanilla float8 uses post-BF16 AG.
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu Adds output_c != nullptr / tile_scales_inv_c != nullptr null checks to support columnwise-only mode needed by hybrid sub-storages that skip rowwise output.
transformer_engine/pytorch/cpp_extensions/gemm.py Adds _unwrap_hybrid_A/B helpers that extract the direction-appropriate sub-storage from a HybridQuantizedTensorStorage before dispatching to the C++ GEMM; no-op for non-hybrid tensors.

Class Diagram

%%{init: {'theme': 'neutral'}}%%
classDiagram
    class HybridQuantizer {
        +rowwise_quantizer: Quantizer
        +columnwise_quantizer: Quantizer
        +quantize_impl(tensor) HybridQuantizedTensor
        +make_empty(shape) HybridQuantizedTensor
        +update_quantized(src, dst) HybridQuantizedTensorStorage
        +supports_only_rowwise_all_gather() bool
    }
    class HybridQuantizedTensor {
        +_rowwise_storage: QuantizedTensorStorage
        +_columnwise_storage: QuantizedTensorStorage
        +fsdp_pre_all_gather()
        +fsdp_post_all_gather()
        +detach() HybridQuantizedTensor
        +dequantize() Tensor
    }
    class HybridQuantizedTensorStorage {
        +_rowwise_storage
        +_columnwise_storage
        +update_usage()
        +clear()
        +prepare_for_saving()
        +restore_from_saved()
        +dequantize()
    }
    class QuantizedTensorStorage {
        +fsdp_buffer_fields()
        +fsdp_extract_buffers()
        +fsdp_assign_gathered()
    }
    class Float8TensorStorage {
        +fsdp_buffer_fields() - direction-aware
        +fsdp_assign_gathered() - resets _transpose_invalid
    }
    class MXFP8TensorStorage {
        +fsdp_buffer_fields() - rowwise+colwise aware
        +fsdp_extract_buffers() - strips scale padding
        +fsdp_assign_gathered() - repads scale
    }
    class Float8BlockwiseQTensorStorage {
        +fsdp_buffer_fields() - 2D only
        +fsdp_extract_buffers() - transposes colwise to M-major
        +fsdp_assign_gathered() - transposes back
    }
    HybridQuantizer --> HybridQuantizedTensor : produces
    HybridQuantizedTensor --|> HybridQuantizedTensorStorage
    HybridQuantizedTensorStorage --|> QuantizedTensorStorage
    Float8TensorStorage --|> QuantizedTensorStorage
    MXFP8TensorStorage --|> QuantizedTensorStorage
    Float8BlockwiseQTensorStorage --|> QuantizedTensorStorage
    HybridQuantizedTensorStorage "1" o-- "0..1" Float8TensorStorage : rowwise_storage
    HybridQuantizedTensorStorage "1" o-- "0..1" Float8TensorStorage : columnwise_storage
Loading

Reviews (9): Last reviewed commit: "Enable FSDP2 hybrid protocol for Float8B..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/module/grouped_linear.py
Comment thread transformer_engine/pytorch/tensor/storage/hybrid_tensor_storage.py
Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py Outdated

@timmoon10 timmoon10 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall I think this moves us in a good direction. I see some minor bugs, as well as bugs reported by @greptile-apps.

Comment on lines +52 to +53
rowwise_result = self.rowwise_quantizer.quantize(tensor)
columnwise_result = self.columnwise_quantizer.quantize(tensor)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we handle the case where not all usages are needed? I'd expect something like:

Suggested change
rowwise_result = self.rowwise_quantizer.quantize(tensor)
columnwise_result = self.columnwise_quantizer.quantize(tensor)
rowwise_result = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None
columnwise_result = self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

requires_grad: bool = False,
pin_memory: bool = False,
) -> HybridQuantizedTensor:
self.rowwise_quantizer.internal = True

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we just set internal=True in the constructor? I don't think we ever need PyTorch tensor functionality in the per-usage data.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would not work under FSDP2.

Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py Outdated
Comment on lines +1339 to +1355
def factory(role):
if role == "linear_weight":
return HybridQuantizer(
rowwise_quantizer=_make_fp8_quantizer(),
columnwise_quantizer=_make_mxfp8_quantizer(),
)
if role == "linear_input":
return HybridQuantizer(
rowwise_quantizer=_make_fp8_quantizer(),
columnwise_quantizer=_make_nvfp4_quantizer(),
)
if role in ("linear_grad_output", "linear_grad_input"):
return HybridQuantizer(
rowwise_quantizer=_make_mxfp8_quantizer(),
columnwise_quantizer=_make_nvfp4_quantizer(),
)
return None

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is horrifying. Good test.

negvet and others added 10 commits April 6, 2026 10:26
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py
negvet and others added 2 commits April 29, 2026 16:02
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Comment thread transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
negvet added 3 commits May 13, 2026 12:34
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
@negvet negvet requested a review from ksivaman as a code owner May 21, 2026 13:53
Comment on lines 665 to 677
outs = [
Float8Tensor.make_like(
tensor,
data=split_tensor,
data_transpose=split_transpose_tensor,
shape=split_tensor.shape,
shape=(
split_tensor.shape
if split_tensor is not None
else split_transpose_tensor.shape
),
)
for split_tensor, split_transpose_tensor in zip(func_out, t_func_out)
]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 When _data is None (columnwise-only sub-storage of a HybridQuantizedTensor on non-Hopper), the split falls back to split_transpose_tensor.shape, which is the transposed layout's shape [K, M/n]. The correct nominal shape for the shard is [M/n, K]. This wrong nominal shape propagates into the HybridQuantizedTensor through fsdp_post_all_gather (which calls _infer_shape on the gathered _transpose buffer to build col_sub), so after the first FSDP2 iteration the assembled full-parameter hybrid's _columnwise_storage reports [K, M] instead of [M, K]. Any Python-side code that calls .size() on that sub-storage (e.g., HybridQuantizedTensorStorage.size() when rowwise is also None, workspace-validity checks, debugging assertions) will see the wrong dimensions.

Suggested change
outs = [
Float8Tensor.make_like(
tensor,
data=split_tensor,
data_transpose=split_transpose_tensor,
shape=split_tensor.shape,
shape=(
split_tensor.shape
if split_tensor is not None
else split_transpose_tensor.shape
),
)
for split_tensor, split_transpose_tensor in zip(func_out, t_func_out)
]
outs = [
Float8Tensor.make_like(
tensor,
data=split_tensor,
data_transpose=split_transpose_tensor,
shape=(
split_tensor.shape
if split_tensor is not None
# _transpose has shape [K, M/n] but the shard's nominal shape
# is [M/n, K]. Recover the correct shard shape by reversing
# the last two dims of the transposed piece.
else (*split_transpose_tensor.shape[1:], split_transpose_tensor.shape[0])
),
)
for split_tensor, split_transpose_tensor in zip(func_out, t_func_out)
]

Comment on lines +27 to +30
# DCP serializes ``CustomRecipe`` via ``pickle``; closure-based qfactories
# (lambdas, inner functions referencing captured state) are not picklable,
# so the qfactory must live at module scope. See
# ``run_fsdp2_fused_adam.py::test_hybrid_dcp_output_parity``.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is potentially useful, but I don't think it is in the right place - shouldn't it be closer to the actual implementation?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Comment on lines +1177 to +1184
for param in model.parameters():
state = optimizer.state[param]
assert state["exp_avg"].dtype == torch.float32
assert state["exp_avg_sq"].dtype == torch.float32
if "master_param" in state:
assert state["master_param"].dtype == torch.float32

assert losses[-1] < losses[0], f"Loss did not decrease: {losses[0]:.4f} -> {losses[-1]:.4f}"

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not a very strict test, is there a way for us to do some numerical correctness comparisons?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enabled check for the monotonic loss decrease (still mostly sanity), and also enabled hybrid vs vanilla bitwise recipe comparizon, see e.g. test_fused_adam_hybrid_vs_base_recipe_parity.

# Quantized training may diverge from bf16, but should not be wildly different.
for step, (h_loss, b_loss) in enumerate(zip(hybrid_losses, bf16_losses)):
ratio = h_loss / max(b_loss, 1e-10)
assert 0.1 < ratio < 10.0, (

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm... What are the actual values there? Could we maybe set a seed or something and compare with some more reasonable tolerance?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a few % from bf16, tightened the tolerance, please take a look

"HybridMXFP8": dict(rtol=0.0, atol=0.0),
"HybridMixed_MXFP8_FP8": dict(rtol=0.0, atol=0.0),
}
tolerance = _TIGHT_TOLERANCE.get(hybrid_recipe_name, dict(rtol=1e-6, atol=1e-6))

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, again, what are the actual values here? I understand the general idea here that when the format is stateful then you will have some difference, but I don't think that 1e-6 would be the right tolerance if that difference actually happened. So if we do not actually test the case that would exhibit this issue then maybe it would be better to just set the tolerances to 0 in all cases to simplify the test code? This would then also be a clear point of failure if somebody added here a recipe that would not exhibit this same property.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I agree, just setting to 0.

Comment on lines +1403 to +1404
assert per_rank_out % 32 == 0, "MXFP8 data alignment precondition"
assert per_rank_out % 128 != 0, "Test precondition: shard must need scale padding"

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those error messages, since they are purely meant for the person changing the test itself, could be more descriptive in the error message.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Comment on lines +1462 to +1482
if hybrid_recipe_name == "HybridFloat8BlockScaling":
pytest.xfail(
"HybridFloat8BlockScaling: Float8BlockwiseQTensor sub-storage loses "
"quantized type through FSDP2 view(-1)."
)

if hybrid_recipe_name == "HybridFP8CurrentScaling":
pytest.xfail(
"HybridFP8CurrentScaling: per-tensor _scale_inv is not preserved "
"through DCP's tensor-storage-byte serialization path. "
"HybridQuantizedTensor.__reduce_ex__ correctly round-trips through "
"pickle (verified by torch.save/torch.load), but DCP bypasses "
"pickle and serializes the tensor's storage bytes — the scalar "
"_scale_inv is not enumerated as a separate tensor leaf and gets "
"lost. Vanilla Float8CurrentScaling avoids this because per-tensor "
"scale lives in module.fp8_meta (saved as extra_state), not on "
"the tensor; hybrid uses per-sub-storage scales without that "
"mirror. Fix path: implement __tensor_flatten__/__tensor_unflatten__ "
"across the quantized tensor stack so DCP can serialize the "
"per-leaf tensor fields directly. Loaded model output diverges by "
"~5e-2."

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, do we intend to do something about that?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, a proper fix would touch all tensors, so let's do that in a separate PR, added a TODO at the description above.


def _build_hybrid_model(num_layers, hybrid_recipe, use_meta_device=True):
"""Build a model with quantized_model_init using a hybrid CustomRecipe."""
ctx = te.quantized_model_init(enabled=True, recipe=hybrid_recipe)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This quite strange choice (it was also in the other test files) to separate the ctx definition from the usage. It is not a big deal (basically a nit), but it looks strange - it is better to have things close to the actual usage site if they are not too big.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Comment on lines +490 to +494
if hybrid_recipe_name == "HybridFloat8BlockScaling":
pytest.xfail(
"HybridFloat8BlockScaling: Float8BlockwiseQTensor sub-storage loses "
"quantized type through FSDP2 view(-1)."
)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any plans to address it? Is it a limitation of the underlying recipe? @vthumbe1503 any thoughts here?

@vthumbe1503 vthumbe1503 Jun 3, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of the QuantizedTensors have view(-1) implemented to return the dequantized output, so there should be something more going on in Float8BlockScaling that might be causing the test to fail here.

Also the view(-1) in FSDP2 is done to store a sharded tensor just for checkpointing logic and isnt used anywhere.

@negvet negvet Jun 3, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the insight @vthumbe1503

This is a hybrid-related bug.
After fixing the intermediate view-related bug, it turns out that Float8BlockwiseQTensorStorage does not implement the FSDP2 sub-storage protocol (fsdp_buffer_fields / fsdp_extract_buffers / fsdp_assign_gathered) required by hybrid all-gather.

Fixing it, WIP.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hybrid_avg = sum(hybrid_increments) / len(hybrid_increments)

excess_per_layer = hybrid_avg - bf16_avg
tolerance_per_layer = 50 * 1024 # 50 KiB

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the base of that tolerance? What are the actual values?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment that replies:

# Basis: forward growth is constant per layer (no accumulation) for both bf16 and
# hybrid; the excess is just hybrid's extra per-layer quantized buffers. Measured
# excess: ~3 KiB (FP8 current) / ~7 KiB (mixed MXFP8+FP8) / ~12 KiB (MXFP8). A
# leaked layer's quantized weights would be hundreds of KiB, so 50 KiB sits above
# the real per-layer overhead and well below a leak.

)

excess = hybrid_bwd_delta - bf16_bwd_delta
tolerance = 256 * 1024 # 256 KiB

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this one larger than the previous one?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

50 is a per-layer forward, whereas 256 is a whole-model backward + optimizer step

loss = F.mse_loss(output, target)
loss.backward()
optimizer.step()
dist_print(f"Hybrid iteration {iteration} completed with loss {loss.item()}")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this test actually checking? There are no assertions here - if we only check if it does not crash then what is the value of this test vs the other ones?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, this was just a smoke test. Updated it to assert loss finiteness + strict monotonic decrease, hybrid-type preservation across the optimizer step, and FSDP2 all-gather correctness vs a manual fp32 dequant-then-allgather (the check test_distributed already had and this one was missing).

loss = F.mse_loss(output, target)
loss.backward()
optimizer.step()
dist_print(f"Hybrid reshard_after_fwd iter {iteration}, loss {loss.item():.4f}")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same.

Comment on lines +18 to +25
* ``te.Linear`` column-parallel and row-parallel, with and without
sequence parallelism.
* ``te.LayerNormLinear`` column-parallel with sequence parallelism —
the quantized-AG path that currently unfuses LN+quantize for
``HybridQuantizer``.
* ``te.TransformerLayer`` with ``set_parallel_mode=True`` and SP on —
integration test hitting LayerNormLinear + DPA + LayerNormMLP + row-
parallel output projection in one shot.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Considering that Transformer layer gives basically everything, what is the value of the other tests? And if there is value in the other tests, then why don't we check the LayerNormMLP on its own?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other tests are complementary, not redundant. Transformer layer test is the broad smoke. Standalone tests have additional value (grad checks + extra configs). Following this, adding LayerNormMLP. Also added a bitwise hybrid-vs-vanilla equivalence test on te.Linear.

Comment on lines +29 to +31
numerical signal is clean. Cross-format hybrid adds independent
numerical variation unrelated to TP/SP and is covered by single-GPU
tests already.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if I agree with this assessment. Cross format is actually the only case where you need to be careful about the allgather cases being different in forward and backward and allgather touches the comm-gemm overlap that would also be potentially affected (e.g. due to wrong buffer sizes taken from forward recipe being used in backward).

Also, a general comment - hybrid recipes that do forward quantized/backward unquantized and vice versa would be very useful to test as well.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, you are right, AG is format dependent. Adding MXFP8 forward / NVFP4 backward. and updating the docstring.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, a general comment - hybrid recipes that do forward quantized/backward unquantized and vice versa would be very useful to test as well.

Looking into it.

Comment on lines +33 to +36
Tolerances match upstream ``run_numerics.py`` per-format settings (see
``_get_tolerances``); they're loose enough to absorb the amax-reduction
and stochastic numerical asymmetries inherent to distributed FP8, tight
enough to catch a silent BF16 fallback on the hybrid sub-storage path.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the tolerances be effectively 0 if you are doing the non-actually-hybrid recipes only? Since you should be comparing the same underlying implementations.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tolerances are for the distributed vs single-node comparison, not hybrid vs vanilla. I will reword this comment to make the two comparisons distinct.

columnwise_quantizer=_make_mxfp8_quantizer(),
)
if is_linear and role.tensor_type in ("grad_output", "grad_input"):
return _make_mxfp8_quantizer(fp8_dtype=tex.DType.kFloat8E5M2)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E5M2 here is not correct. In general this should just be a single line to return the mxfp8 quantizer for all cases.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, fixed

Comment on lines +126 to +131
"""Default NVFP4Quantizer: no RHT, no stochastic rounding, no 2D
scaling — matches upstream ``run_numerics.py::nvfp4_vanilla()`` which
strips the recipe to bare 1D block scaling for distributed TP
fairness. Same-format hybrid NVFP4 has no E5M2 variant (NVFP4 is a
single format family — E2M1 only), so grad roles reuse the same
NVFP4 quantizer."""

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we want to check the full recipe here?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switched to the full recipe except 1D for weights, will enabled after #3027 merge

Comment on lines +136 to +143
is_linear = role is not None and role.module_type in ("linear", "grouped_linear")
if is_linear and role.tensor_type in ("input", "weight", "output"):
return HybridQuantizer(
rowwise_quantizer=_make_nvfp4_quantizer(),
columnwise_quantizer=_make_nvfp4_quantizer(),
)
if is_linear and role.tensor_type in ("grad_output", "grad_input"):
return _make_nvfp4_quantizer()

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As written those lines are not needed at all. They would be needed if you did the full recipe.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switched to the full recipe

Comment on lines +166 to +168
# quantization (rowwise and columnwise quantizers run independently, so
# their outputs may differ by ~1 ULP from a single fused-quantize path
# in edge cases).

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That does not sound like a good thing if it actually happens in practice - the quantization only should not be affected if you do both at the same time vs one at a time -> the input and the algorithm is the same in both cases. Fusion with the activations could maybe give slightly different results, but I would still like to get an explanation of why that would be.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. If the algorithm is the same, we are indeed getting identical results. New bitwise linear_vs_vanilla test confirms this. The only place where two pass and fused differ is NVFP4 with RHT + SR. This activates a separate columnwise RNG (need_separate_columnwise_rng), and RNG stream consumed differently. see comment in _backward_not_bitwise_comparable(). Removed the comment.

negvet and others added 4 commits June 1, 2026 08:47
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants