[Pytorch][Common] Hybrid quantization#2817
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR introduces hybrid (per-direction) quantization for PyTorch, enabling a weight to use one quantization format for its rowwise direction and a different format for its columnwise direction. The implementation spans ~2000 lines of new source across a new
Confidence Score: 2/5
The PR contains a method naming collision in In
Important Files Changed
Class Diagram
%%{init: {'theme': 'neutral'}}%%
classDiagram
class HybridQuantizer {
+rowwise_quantizer: Quantizer
+columnwise_quantizer: Quantizer
+quantize_impl(tensor) HybridQuantizedTensor
+make_empty(shape) HybridQuantizedTensor
+update_quantized(src, dst) HybridQuantizedTensorStorage
+supports_only_rowwise_all_gather() bool
}
class HybridQuantizedTensor {
+_rowwise_storage: QuantizedTensorStorage
+_columnwise_storage: QuantizedTensorStorage
+fsdp_pre_all_gather()
+fsdp_post_all_gather()
+detach() HybridQuantizedTensor
+dequantize() Tensor
}
class HybridQuantizedTensorStorage {
+_rowwise_storage
+_columnwise_storage
+update_usage()
+clear()
+prepare_for_saving()
+restore_from_saved()
+dequantize()
}
class QuantizedTensorStorage {
+fsdp_buffer_fields()
+fsdp_extract_buffers()
+fsdp_assign_gathered()
}
class Float8TensorStorage {
+fsdp_buffer_fields() - direction-aware
+fsdp_assign_gathered() - resets _transpose_invalid
}
class MXFP8TensorStorage {
+fsdp_buffer_fields() - rowwise+colwise aware
+fsdp_extract_buffers() - strips scale padding
+fsdp_assign_gathered() - repads scale
}
class Float8BlockwiseQTensorStorage {
+fsdp_buffer_fields() - 2D only
+fsdp_extract_buffers() - transposes colwise to M-major
+fsdp_assign_gathered() - transposes back
}
HybridQuantizer --> HybridQuantizedTensor : produces
HybridQuantizedTensor --|> HybridQuantizedTensorStorage
HybridQuantizedTensorStorage --|> QuantizedTensorStorage
Float8TensorStorage --|> QuantizedTensorStorage
MXFP8TensorStorage --|> QuantizedTensorStorage
Float8BlockwiseQTensorStorage --|> QuantizedTensorStorage
HybridQuantizedTensorStorage "1" o-- "0..1" Float8TensorStorage : rowwise_storage
HybridQuantizedTensorStorage "1" o-- "0..1" Float8TensorStorage : columnwise_storage
Reviews (9): Last reviewed commit: "Enable FSDP2 hybrid protocol for Float8B..."
timmoon10
left a comment
There was a problem hiding this comment.
Overall I think this moves us in a good direction. I see some minor bugs, as well as bugs reported by @greptile-apps.
| rowwise_result = self.rowwise_quantizer.quantize(tensor) | ||
| columnwise_result = self.columnwise_quantizer.quantize(tensor) |
There was a problem hiding this comment.
Do we handle the case where not all usages are needed? I'd expect something like:
| - rowwise_result = self.rowwise_quantizer.quantize(tensor) | |
| - columnwise_result = self.columnwise_quantizer.quantize(tensor) | |
| + rowwise_result = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None | |
| + columnwise_result = self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None |
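The suggestion above amounts to skipping a direction whenever its usage flag is off. A minimal sketch of that control flow follows. It uses stand-in classes rather than the real Transformer Engine quantizers: `ToyQuantizer`, `ToyHybridQuantizer`, and the plain-boolean `rowwise_usage` / `columnwise_usage` fields are assumptions made for illustration only.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyQuantizer:
    """Stand-in for a per-direction quantizer; counts how often it runs."""

    name: str
    calls: int = 0

    def quantize(self, tensor: List[float]) -> Tuple[str, List[int]]:
        self.calls += 1
        # Placeholder "quantization": round each value to the nearest integer.
        return (self.name, [round(x) for x in tensor])


@dataclass
class ToyHybridQuantizer:
    """Hypothetical hybrid wrapper that honors the usage flags."""

    rowwise_quantizer: ToyQuantizer
    columnwise_quantizer: ToyQuantizer
    rowwise_usage: bool = True
    columnwise_usage: bool = True

    def quantize(self, tensor: List[float]):
        # Run a sub-quantizer only for the directions that were requested.
        rowwise = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None
        columnwise = (
            self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None
        )
        return rowwise, columnwise


hybrid = ToyHybridQuantizer(
    ToyQuantizer("fp8"), ToyQuantizer("mxfp8"), columnwise_usage=False
)
row, col = hybrid.quantize([0.4, 1.6, -2.2])
```

With `columnwise_usage=False`, the columnwise quantizer is never invoked and `col` is `None`, so downstream code has to tolerate a missing direction.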
| requires_grad: bool = False, | ||
| pin_memory: bool = False, | ||
| ) -> HybridQuantizedTensor: | ||
| self.rowwise_quantizer.internal = True |
There was a problem hiding this comment.
Could we just set internal=True in the constructor? I don't think we ever need PyTorch tensor functionality in the per-usage data.
There was a problem hiding this comment.
This would not work under FSDP2.
| def factory(role): | ||
| if role == "linear_weight": | ||
| return HybridQuantizer( | ||
| rowwise_quantizer=_make_fp8_quantizer(), | ||
| columnwise_quantizer=_make_mxfp8_quantizer(), | ||
| ) | ||
| if role == "linear_input": | ||
| return HybridQuantizer( | ||
| rowwise_quantizer=_make_fp8_quantizer(), | ||
| columnwise_quantizer=_make_nvfp4_quantizer(), | ||
| ) | ||
| if role in ("linear_grad_output", "linear_grad_input"): | ||
| return HybridQuantizer( | ||
| rowwise_quantizer=_make_mxfp8_quantizer(), | ||
| columnwise_quantizer=_make_nvfp4_quantizer(), | ||
| ) | ||
| return None |
There was a problem hiding this comment.
This is horrifying. Good test.
| outs = [ | ||
| Float8Tensor.make_like( | ||
| tensor, | ||
| data=split_tensor, | ||
| data_transpose=split_transpose_tensor, | ||
| shape=split_tensor.shape, | ||
| shape=( | ||
| split_tensor.shape | ||
| if split_tensor is not None | ||
| else split_transpose_tensor.shape | ||
| ), | ||
| ) | ||
| for split_tensor, split_transpose_tensor in zip(func_out, t_func_out) | ||
| ] |
There was a problem hiding this comment.
When _data is None (columnwise-only sub-storage of a HybridQuantizedTensor on non-Hopper), the split falls back to split_transpose_tensor.shape, which is the transposed layout's shape [K, M/n]. The correct nominal shape for the shard is [M/n, K]. This wrong nominal shape propagates into the HybridQuantizedTensor through fsdp_post_all_gather (which calls _infer_shape on the gathered _transpose buffer to build col_sub), so after the first FSDP2 iteration the assembled full-parameter hybrid's _columnwise_storage reports [K, M] instead of [M, K]. Any Python-side code that calls .size() on that sub-storage (e.g., HybridQuantizedTensorStorage.size() when rowwise is also None, workspace-validity checks, debugging assertions) will see the wrong dimensions.
| outs = [ | |
| Float8Tensor.make_like( | |
| tensor, | |
| data=split_tensor, | |
| data_transpose=split_transpose_tensor, | |
| shape=split_tensor.shape, | |
| shape=( | |
| split_tensor.shape | |
| if split_tensor is not None | |
| else split_transpose_tensor.shape | |
| ), | |
| ) | |
| for split_tensor, split_transpose_tensor in zip(func_out, t_func_out) | |
| ] | |
| outs = [ | |
| Float8Tensor.make_like( | |
| tensor, | |
| data=split_tensor, | |
| data_transpose=split_transpose_tensor, | |
| shape=( | |
| split_tensor.shape | |
| if split_tensor is not None | |
| # _transpose has shape [K, M/n] but the shard's nominal shape | |
| # is [M/n, K]. Recover the correct shard shape by reversing | |
| # the last two dims of the transposed piece. | |
| else (*split_transpose_tensor.shape[1:], split_transpose_tensor.shape[0]) | |
| ), | |
| ) | |
| for split_tensor, split_transpose_tensor in zip(func_out, t_func_out) | |
| ] |
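The shape recovery in the suggestion can be checked in isolation with plain tuples. The sketch below is illustrative only: `nominal_shard_shape` is a hypothetical helper, not part of the PR, and it mirrors the suggested expression, which moves the leading dimension of the transposed piece to the end. For a 2D shard this is the same as swapping the two dimensions.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def nominal_shard_shape(
    data_shape: Optional[Shape], transpose_shape: Optional[Shape]
) -> Shape:
    """Return the shard's nominal shape, e.g. [M/n, K]."""
    if data_shape is not None:
        # Rowwise data is present: its shape already is the nominal shape.
        return tuple(data_shape)
    if transpose_shape is None:
        raise ValueError("need at least one of data_shape or transpose_shape")
    # Columnwise-only: the transposed piece is [K, M/n]; move dim 0 to the end.
    return (*transpose_shape[1:], transpose_shape[0])


M, K, n = 512, 64, 4  # illustrative sizes, not taken from the PR
rowwise_case = nominal_shard_shape((M // n, K), (K, M // n))
columnwise_only_case = nominal_shard_shape(None, (K, M // n))
```

Both cases give `(128, 64)`, whereas falling back to the raw transposed shape would report `(64, 128)` for the columnwise-only shard.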
| # DCP serializes ``CustomRecipe`` via ``pickle``; closure-based qfactories | ||
| # (lambdas, inner functions referencing captured state) are not picklable, | ||
| # so the qfactory must live at module scope. See | ||
| # ``run_fsdp2_fused_adam.py::test_hybrid_dcp_output_parity``. |
There was a problem hiding this comment.
This comment is potentially useful, but I don't think it is in the right place - shouldn't it be closer to the actual implementation?
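The pickling constraint described in the quoted code comment can be reproduced with the standard library alone. This is a rough sketch: `can_pickle`, `make_closure_factory`, and the toy factories are hypothetical names, and the `functools.partial` variant is an alternative picked for this example, not something the PR uses.

```python
import functools
import pickle


def can_pickle(obj) -> bool:
    """Return True if ``obj`` survives a pickle round trip."""
    try:
        pickle.loads(pickle.dumps(obj))
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


def make_closure_factory(fmt):
    # Inner function capturing ``fmt``: pickle stores functions by qualified
    # name and cannot resolve a function defined inside another function.
    def factory(role):
        return {"fmt": fmt, "role": role}

    return factory


closure_factory = make_closure_factory("fp8")
lambda_factory = lambda role: {"fmt": "fp8", "role": role}
# A partial over an importable callable pickles as "callable + bound arguments".
partial_factory = functools.partial(dict, fmt="fp8")

results = {
    "closure": can_pickle(closure_factory),
    "lambda": can_pickle(lambda_factory),
    "partial": can_pickle(partial_factory),
}
```

Only the `partial` entry round-trips; the closure and the lambda are rejected, which is why a factory that must be pickled has to be reachable by name at module scope or be built from such pieces.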
| for param in model.parameters(): | ||
| state = optimizer.state[param] | ||
| assert state["exp_avg"].dtype == torch.float32 | ||
| assert state["exp_avg_sq"].dtype == torch.float32 | ||
| if "master_param" in state: | ||
| assert state["master_param"].dtype == torch.float32 | ||
|
|
||
| assert losses[-1] < losses[0], f"Loss did not decrease: {losses[0]:.4f} -> {losses[-1]:.4f}" |
There was a problem hiding this comment.
That's not a very strict test, is there a way for us to do some numerical correctness comparisons?
There was a problem hiding this comment.
Enabled a check for monotonic loss decrease (still mostly a sanity check), and also enabled a bitwise hybrid vs vanilla recipe comparison, see e.g. test_fused_adam_hybrid_vs_base_recipe_parity.
| # Quantized training may diverge from bf16, but should not be wildly different. | ||
| for step, (h_loss, b_loss) in enumerate(zip(hybrid_losses, bf16_losses)): | ||
| ratio = h_loss / max(b_loss, 1e-10) | ||
| assert 0.1 < ratio < 10.0, ( |
There was a problem hiding this comment.
Hmmm... What are the actual values there? Could we maybe set a seed or something and compare with some more reasonable tolerance?
There was a problem hiding this comment.
The hybrid loss is within a few % of bf16. I tightened the tolerance, please take a look.
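A tighter check than the `0.1 < ratio < 10.0` band can be phrased as a maximum relative gap between the two loss curves. The sketch below is an assumption about how such a check might look, not the test's actual code: `max_relative_gap`, the sample losses, and the 5% bound are all made up for illustration.

```python
import math
from typing import Sequence


def max_relative_gap(
    hybrid_losses: Sequence[float], baseline_losses: Sequence[float]
) -> float:
    """Largest per-step relative difference between two loss curves."""
    if not hybrid_losses or len(hybrid_losses) != len(baseline_losses):
        raise ValueError("loss curves must be non-empty and of equal length")
    gaps = []
    for h_loss, b_loss in zip(hybrid_losses, baseline_losses):
        if not (math.isfinite(h_loss) and math.isfinite(b_loss)):
            raise ValueError("losses must be finite")
        # Guard the denominator the same way the quoted test does.
        gaps.append(abs(h_loss - b_loss) / max(abs(b_loss), 1e-10))
    return max(gaps)


# Made-up loss values; with a fixed seed the real curves would be reproducible.
hybrid_losses = [1.00, 0.82, 0.71]
bf16_losses = [1.02, 0.80, 0.70]
gap = max_relative_gap(hybrid_losses, bf16_losses)
within_tolerance = gap < 0.05  # illustrative 5% bound
```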
| "HybridMXFP8": dict(rtol=0.0, atol=0.0), | ||
| "HybridMixed_MXFP8_FP8": dict(rtol=0.0, atol=0.0), | ||
| } | ||
| tolerance = _TIGHT_TOLERANCE.get(hybrid_recipe_name, dict(rtol=1e-6, atol=1e-6)) |
There was a problem hiding this comment.
Hmm, again, what are the actual values here? I understand the general idea here that when the format is stateful then you will have some difference, but I don't think that 1e-6 would be the right tolerance if that difference actually happened. So if we do not actually test the case that would exhibit this issue then maybe it would be better to just set the tolerances to 0 in all cases to simplify the test code? This would then also be a clear point of failure if somebody added here a recipe that would not exhibit this same property.
There was a problem hiding this comment.
Right, I agree, just setting to 0.
| assert per_rank_out % 32 == 0, "MXFP8 data alignment precondition" | ||
| assert per_rank_out % 128 != 0, "Test precondition: shard must need scale padding" |
There was a problem hiding this comment.
Since these assertion messages are meant purely for the person changing the test itself, they could be more descriptive.
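One possible shape for more descriptive messages is sketched below. The wording is only a suggestion: `check_shard_preconditions` is a hypothetical wrapper, and the explanations of what 32 and 128 represent are inferred from the two original messages rather than confirmed by the PR.

```python
def check_shard_preconditions(per_rank_out: int) -> None:
    """Validate the shard size this padding test depends on."""
    assert per_rank_out % 32 == 0, (
        f"per_rank_out={per_rank_out} is not a multiple of 32, so the shard "
        "violates the MXFP8 data alignment this test assumes; adjust the "
        "output features or world size."
    )
    assert per_rank_out % 128 != 0, (
        f"per_rank_out={per_rank_out} is a multiple of 128, so the shard needs "
        "no scale padding and this test would not exercise the padding path; "
        "pick a size that is a multiple of 32 but not of 128."
    )


check_shard_preconditions(96)  # multiple of 32, not of 128: both checks pass
```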
| if hybrid_recipe_name == "HybridFloat8BlockScaling": | ||
| pytest.xfail( | ||
| "HybridFloat8BlockScaling: Float8BlockwiseQTensor sub-storage loses " | ||
| "quantized type through FSDP2 view(-1)." | ||
| ) | ||
|
|
||
| if hybrid_recipe_name == "HybridFP8CurrentScaling": | ||
| pytest.xfail( | ||
| "HybridFP8CurrentScaling: per-tensor _scale_inv is not preserved " | ||
| "through DCP's tensor-storage-byte serialization path. " | ||
| "HybridQuantizedTensor.__reduce_ex__ correctly round-trips through " | ||
| "pickle (verified by torch.save/torch.load), but DCP bypasses " | ||
| "pickle and serializes the tensor's storage bytes — the scalar " | ||
| "_scale_inv is not enumerated as a separate tensor leaf and gets " | ||
| "lost. Vanilla Float8CurrentScaling avoids this because per-tensor " | ||
| "scale lives in module.fp8_meta (saved as extra_state), not on " | ||
| "the tensor; hybrid uses per-sub-storage scales without that " | ||
| "mirror. Fix path: implement __tensor_flatten__/__tensor_unflatten__ " | ||
| "across the quantized tensor stack so DCP can serialize the " | ||
| "per-leaf tensor fields directly. Loaded model output diverges by " | ||
| "~5e-2." |
There was a problem hiding this comment.
Hmmm, do we intend to do something about that?
There was a problem hiding this comment.
Yes. A proper fix would touch all tensor types, so let's do that in a separate PR. I added a TODO to the PR description.
|
|
||
| def _build_hybrid_model(num_layers, hybrid_recipe, use_meta_device=True): | ||
| """Build a model with quantized_model_init using a hybrid CustomRecipe.""" | ||
| ctx = te.quantized_model_init(enabled=True, recipe=hybrid_recipe) |
There was a problem hiding this comment.
This is quite a strange choice (it was also in the other test files) to separate the ctx definition from its usage. It is not a big deal (basically a nit), but it looks strange: it is better to keep things close to the actual usage site if they are not too big.
| if hybrid_recipe_name == "HybridFloat8BlockScaling": | ||
| pytest.xfail( | ||
| "HybridFloat8BlockScaling: Float8BlockwiseQTensor sub-storage loses " | ||
| "quantized type through FSDP2 view(-1)." | ||
| ) |
There was a problem hiding this comment.
Any plans to address it? Is it a limitation of the underlying recipe? @vthumbe1503 any thoughts here?
There was a problem hiding this comment.
All of the QuantizedTensors have view(-1) implemented to return the dequantized output, so there should be something more going on in Float8BlockScaling that might be causing the test to fail here.
Also, the view(-1) in FSDP2 is done to store a sharded tensor just for the checkpointing logic and isn't used anywhere else.
There was a problem hiding this comment.
Thanks for the insight @vthumbe1503
This is a hybrid-related bug.
After fixing the intermediate view-related bug, it turns out that Float8BlockwiseQTensorStorage does not implement the FSDP2 sub-storage protocol (fsdp_buffer_fields / fsdp_extract_buffers / fsdp_assign_gathered) required by hybrid all-gather.
Fixing it, WIP.
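To make the three-method protocol named above concrete, here is a toy sketch in pure Python. Only the method names come from the thread; the signatures, return types, and the `ToyScaledStorage` / `toy_all_gather` helpers are assumptions, and the real Transformer Engine methods may look quite different. The padding behaviour loosely follows the class diagram's note that the MXFP8 storage strips scale padding on extract and repads on assign.

```python
from typing import Dict, List


class ToyScaledStorage:
    """Toy sub-storage: a payload plus a scale buffer padded to a multiple."""

    PAD_MULTIPLE = 4  # arbitrary value for this sketch

    def __init__(self, data: List[int], scale: List[float]):
        self.data = list(data)
        self.num_scales = len(scale)
        self.scale = self._pad(scale)

    @classmethod
    def _pad(cls, scale: List[float]) -> List[float]:
        pad = (-len(scale)) % cls.PAD_MULTIPLE
        return list(scale) + [0.0] * pad

    def fsdp_buffer_fields(self) -> List[str]:
        # Names of the buffers that take part in the all-gather.
        return ["data", "scale"]

    def fsdp_extract_buffers(self) -> Dict[str, list]:
        # Strip padding so per-rank shards concatenate cleanly.
        return {"data": list(self.data), "scale": self.scale[: self.num_scales]}

    def fsdp_assign_gathered(self, gathered: Dict[str, list]) -> None:
        # Install the gathered buffers and restore the padded scale layout.
        self.data = list(gathered["data"])
        self.num_scales = len(gathered["scale"])
        self.scale = self._pad(gathered["scale"])


def toy_all_gather(shards: List[ToyScaledStorage]) -> Dict[str, list]:
    """Concatenate each named buffer across shards (stand-in for the collective)."""
    fields = shards[0].fsdp_buffer_fields()
    extracted = [shard.fsdp_extract_buffers() for shard in shards]
    return {name: [x for e in extracted for x in e[name]] for name in fields}


shards = [ToyScaledStorage([1, 2], [0.5]), ToyScaledStorage([3, 4], [0.25])]
full = ToyScaledStorage([], [])
full.fsdp_assign_gathered(toy_all_gather(shards))
```

A storage class that lacks any of the three hooks cannot take part in this flow, which matches the failure described for Float8BlockwiseQTensorStorage.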
| hybrid_avg = sum(hybrid_increments) / len(hybrid_increments) | ||
|
|
||
| excess_per_layer = hybrid_avg - bf16_avg | ||
| tolerance_per_layer = 50 * 1024 # 50 KiB |
There was a problem hiding this comment.
What is the base of that tolerance? What are the actual values?
There was a problem hiding this comment.
Added a comment that answers this:
# Basis: forward growth is constant per layer (no accumulation) for both bf16 and
# hybrid; the excess is just hybrid's extra per-layer quantized buffers. Measured
# excess: ~3 KiB (FP8 current) / ~7 KiB (mixed MXFP8+FP8) / ~12 KiB (MXFP8). A
# leaked layer's quantized weights would be hundreds of KiB, so 50 KiB sits above
# the real per-layer overhead and well below a leak.
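The reasoning in that comment can be written out as a small arithmetic check. The measured values are the ones quoted in the comment; the 200 KiB "leak" figure is a made-up stand-in for "hundreds of KiB", and `classify_excess` is a hypothetical helper.

```python
KIB = 1024
TOLERANCE_PER_LAYER = 50 * KIB  # same value as the test's tolerance_per_layer

# Measured per-layer excess over bf16, as quoted in the comment (approximate).
measured_excess_kib = {"FP8 current": 3, "mixed MXFP8+FP8": 7, "MXFP8": 12}


def classify_excess(excess_bytes: int, tolerance: int = TOLERANCE_PER_LAYER) -> str:
    """Label a per-layer memory excess relative to the tolerance."""
    return "ok" if excess_bytes < tolerance else "possible leak"


verdicts = {name: classify_excess(kib * KIB) for name, kib in measured_excess_kib.items()}
leak_verdict = classify_excess(200 * KIB)  # illustrative leaked-layer size
headroom = TOLERANCE_PER_LAYER // (max(measured_excess_kib.values()) * KIB)
```

The largest measured excess (12 KiB) leaves roughly 4x headroom under the 50 KiB bound, while the illustrative leak exceeds the bound by the same factor.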
| ) | ||
|
|
||
| excess = hybrid_bwd_delta - bf16_bwd_delta | ||
| tolerance = 256 * 1024 # 256 KiB |
There was a problem hiding this comment.
Why is this one larger than the previous one?
There was a problem hiding this comment.
50 KiB is a per-layer forward tolerance, whereas 256 KiB covers a whole-model backward + optimizer step.
| loss = F.mse_loss(output, target) | ||
| loss.backward() | ||
| optimizer.step() | ||
| dist_print(f"Hybrid iteration {iteration} completed with loss {loss.item()}") |
There was a problem hiding this comment.
What is this test actually checking? There are no assertions here - if we only check if it does not crash then what is the value of this test vs the other ones?
There was a problem hiding this comment.
Ok, this was just a smoke test. Updated it to assert loss finiteness + strict monotonic decrease, hybrid-type preservation across the optimizer step, and FSDP2 all-gather correctness vs a manual fp32 dequant-then-allgather (the check test_distributed already had and this one was missing).
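The finiteness and strict-monotonic-decrease part of that update could look roughly like the helper below. This is an assumed form, not the test's actual code: `check_loss_trajectory` is a hypothetical name, and the hybrid-type and all-gather checks mentioned in the reply are not covered here.

```python
import math
from typing import Sequence


def check_loss_trajectory(losses: Sequence[float]) -> None:
    """Require every loss to be finite and each step to strictly decrease."""
    if len(losses) < 2:
        raise ValueError("need at least two iterations to check a decrease")
    for step, loss in enumerate(losses):
        if not math.isfinite(loss):
            raise AssertionError(f"loss at iteration {step} is not finite: {loss}")
    for step in range(1, len(losses)):
        if not losses[step] < losses[step - 1]:
            raise AssertionError(
                f"loss did not strictly decrease at iteration {step}: "
                f"{losses[step - 1]:.4f} -> {losses[step]:.4f}"
            )


check_loss_trajectory([0.91, 0.74, 0.60])  # made-up, strictly decreasing losses
```

Strict monotonicity is a strong requirement for noisy training, so it fits short deterministic tests better than long runs.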
| loss = F.mse_loss(output, target) | ||
| loss.backward() | ||
| optimizer.step() | ||
| dist_print(f"Hybrid reshard_after_fwd iter {iteration}, loss {loss.item():.4f}") |
| * ``te.Linear`` column-parallel and row-parallel, with and without | ||
| sequence parallelism. | ||
| * ``te.LayerNormLinear`` column-parallel with sequence parallelism — | ||
| the quantized-AG path that currently unfuses LN+quantize for | ||
| ``HybridQuantizer``. | ||
| * ``te.TransformerLayer`` with ``set_parallel_mode=True`` and SP on — | ||
| integration test hitting LayerNormLinear + DPA + LayerNormMLP + row- | ||
| parallel output projection in one shot. |
There was a problem hiding this comment.
Considering that Transformer layer gives basically everything, what is the value of the other tests? And if there is value in the other tests, then why don't we check the LayerNormMLP on its own?
There was a problem hiding this comment.
Other tests are complementary, not redundant. Transformer layer test is the broad smoke. Standalone tests have additional value (grad checks + extra configs). Following this, adding LayerNormMLP. Also added a bitwise hybrid-vs-vanilla equivalence test on te.Linear.
| numerical signal is clean. Cross-format hybrid adds independent | ||
| numerical variation unrelated to TP/SP and is covered by single-GPU | ||
| tests already. |
There was a problem hiding this comment.
I'm not sure if I agree with this assessment. Cross format is actually the only case where you need to be careful about the allgather cases being different in forward and backward and allgather touches the comm-gemm overlap that would also be potentially affected (e.g. due to wrong buffer sizes taken from forward recipe being used in backward).
Also, a general comment - hybrid recipes that do forward quantized/backward unquantized and vice versa would be very useful to test as well.
There was a problem hiding this comment.
Good point, you are right, AG is format dependent. Adding MXFP8 forward / NVFP4 backward and updating the docstring.
There was a problem hiding this comment.
> Also, a general comment - hybrid recipes that do forward quantized/backward unquantized and vice versa would be very useful to test as well.
Looking into it.
| Tolerances match upstream ``run_numerics.py`` per-format settings (see | ||
| ``_get_tolerances``); they're loose enough to absorb the amax-reduction | ||
| and stochastic numerical asymmetries inherent to distributed FP8, tight | ||
| enough to catch a silent BF16 fallback on the hybrid sub-storage path. |
There was a problem hiding this comment.
Shouldn't the tolerances be effectively 0 if you are doing the non-actually-hybrid recipes only? Since you should be comparing the same underlying implementations.
There was a problem hiding this comment.
These tolerances are for the distributed vs single-node comparison, not hybrid vs vanilla. I will reword this comment to make the two comparisons distinct.
| columnwise_quantizer=_make_mxfp8_quantizer(), | ||
| ) | ||
| if is_linear and role.tensor_type in ("grad_output", "grad_input"): | ||
| return _make_mxfp8_quantizer(fp8_dtype=tex.DType.kFloat8E5M2) |
There was a problem hiding this comment.
E5M2 here is not correct. In general this should just be a single line to return the mxfp8 quantizer for all cases.
| """Default NVFP4Quantizer: no RHT, no stochastic rounding, no 2D | ||
| scaling — matches upstream ``run_numerics.py::nvfp4_vanilla()`` which | ||
| strips the recipe to bare 1D block scaling for distributed TP | ||
| fairness. Same-format hybrid NVFP4 has no E5M2 variant (NVFP4 is a | ||
| single format family — E2M1 only), so grad roles reuse the same | ||
| NVFP4 quantizer.""" |
There was a problem hiding this comment.
Why don't we want to check the full recipe here?
There was a problem hiding this comment.
Switched to the full recipe except 1D for weights; that will be enabled after #3027 merges.
| is_linear = role is not None and role.module_type in ("linear", "grouped_linear") | ||
| if is_linear and role.tensor_type in ("input", "weight", "output"): | ||
| return HybridQuantizer( | ||
| rowwise_quantizer=_make_nvfp4_quantizer(), | ||
| columnwise_quantizer=_make_nvfp4_quantizer(), | ||
| ) | ||
| if is_linear and role.tensor_type in ("grad_output", "grad_input"): | ||
| return _make_nvfp4_quantizer() |
There was a problem hiding this comment.
As written those lines are not needed at all. They would be needed if you did the full recipe.
There was a problem hiding this comment.
Switched to the full recipe
| # quantization (rowwise and columnwise quantizers run independently, so | ||
| # their outputs may differ by ~1 ULP from a single fused-quantize path | ||
| # in edge cases). |
There was a problem hiding this comment.
That does not sound like a good thing if it actually happens in practice - the quantization only should not be affected if you do both at the same time vs one at a time -> the input and the algorithm is the same in both cases. Fusion with the activations could maybe give slightly different results, but I would still like to get an explanation of why that would be.
There was a problem hiding this comment.
You are right. If the algorithm is the same, we do get identical results; the new bitwise linear_vs_vanilla test confirms this. The only place where two-pass and fused quantization differ is NVFP4 with RHT + SR: this activates a separate columnwise RNG (need_separate_columnwise_rng), so the RNG stream is consumed differently. See the comment in _backward_not_bitwise_comparable(). Removed the comment.
Description
Hybrid (per-direction) quantization. Functional.
C++ optimizations (fusions, etc.) will come in the next PRs.
TODO:
__tensor_flatten__/__tensor_unflatten__ across the tensor stack; torch.save/load and FP32-master paths unaffected; covered by the test_hybrid_dcp_output_parity xfail.
Integration
Ecosystem integration (all functional, unit-tested):
Megatron-LM integration status:
--fp{4,8}-param-gather + dist opt (persistent low-precision params via quantized_model_init + sharded-master FP32 → quantized cast via quantize_master_weights.)
- [Done] Per-tensor Float8 hybrid (delayed and/or current, any per-direction combination including same-format, cross-format Float8, single-direction)
- [TODO] Per-block hybrid sub-quantizers (MXFP8, NVFP4, Float8Blockwise): each rejected per-direction by quantize_master_weights; unblocker is TE-side cast-helper / kernel.
--fp{4,8}-param-gather (fix private attribute access)
--fp{4,8}-param-gather
- [Done] TE-side hybrid FSDP2 path works end-to-end for Float8 / MXFP8 / Float8Blockwise sub-storages (TODO: need some minor MLM update)
- [TODO] NVFP4 sub-storage FSDP2 hooks
_hybrid_split_quantize under Megatron MoE)
Review
Total diff +9000
New hybrid source (hybrid_tensor.py, hybrid_tensor_storage.py) ~1000
Adjacent modifications ~1000
Tests are the rest
Surface to review is ~2000 lines