
Add NVTE_BACKWARD_OVERRIDE=high_precision|dequantized #2644

Merged
ptrendx merged 63 commits into NVIDIA:main from zianglih:keep-bwd
Apr 7, 2026

Conversation

@zianglih

@zianglih zianglih commented Feb 3, 2026

Contributor

Description

@HumansAnd

(Superseded) Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var for quantized fprop + high-precision wgrad & dgrad.

(Superseded) Add an NVTE_BACKWARD_MODE=default|unquant|dequant env var.

Add NVTE_BACKWARD_OVERRIDE=high_precision|dequantized env var:

  • Not set: existing default quantization behavior
  • high_precision: quantized fprop + high-precision wgrad & dgrad, computed from the unquantized activation and weight
  • dequantized: quantized fprop + high-precision wgrad & dgrad, computed from the activation and weight dequantized directly from the fprop quantized values

The motivation for the dequantized design is RL. Unlike pre-training, which only needs to preserve the coarse optimization direction and convergence, RL gradients are noisy, and the useful updates are small and delicate. If gradient quantization and chain-rule violations are present, noise dominates the true, fragile update signal and the model collapses. The dequantized design avoids gradient quantization and effectively preserves the chain rule.
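To make the three modes concrete, here is a toy sketch in plain Python. It is not Transformer Engine code: it assumes a single-output linear layer y = sum_i q(x_i) * q(w_i), a uniform rounding grid standing in for FP8-style quantization, and hypothetical helper names (`fake_quantize`, `linear_forward`, `linear_backward`). The unset branch is simplified to "quantize the incoming gradient and reuse the quantized operands"; the real default path has layout details that this toy does not model.

```python
def fake_quantize(values, step=0.25):
    """Round to a uniform grid; a toy stand-in for FP8-style quantization."""
    return [round(v / step) * step for v in values]


def linear_forward(x, w):
    """Quantized fprop: y = sum_i q(x_i) * q(w_i)."""
    xq, wq = fake_quantize(x), fake_quantize(w)
    y = sum(a * b for a, b in zip(xq, wq))
    return y, xq, wq


def linear_backward(grad_y, x, w, xq, wq, override=None):
    """Return (dgrad, wgrad) for the chosen backward mode."""
    if override is None:
        # Toy default: the incoming gradient is quantized as well.
        g, x_op, w_op = fake_quantize([grad_y])[0], xq, wq
    elif override == "high_precision":
        # Use the original, unquantized operands saved during forward.
        g, x_op, w_op = grad_y, x, w
    elif override == "dequantized":
        # Use exactly the values fprop multiplied; the gradient stays unquantized.
        g, x_op, w_op = grad_y, xq, wq
    else:
        raise ValueError(f"unsupported override: {override!r}")
    dgrad = [g * wi for wi in w_op]  # dL/dx_i
    wgrad = [g * xi for xi in x_op]  # dL/dw_i
    return dgrad, wgrad


x, w, grad_y = [0.3, -1.1], [0.8, 0.4], 0.3
y, xq, wq = linear_forward(x, w)
results = {
    mode: linear_backward(grad_y, x, w, xq, wq, override=mode)
    for mode in (None, "high_precision", "dequantized")
}
```

In the `dequantized` branch, the weight gradient is `grad_y * q(x)`, which is the derivative of the function that fprop actually evaluated. That is one way to read the description's claim that this mode preserves the chain rule while avoiding gradient quantization.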

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps

greptile-apps Bot commented Feb 3, 2026

Contributor

Greptile Summary

This PR adds NVTE_BACKWARD_OVERRIDE=high_precision|dequantized support, enabling high-precision backward passes (wgrad & dgrad) in combination with quantized fprop. The high_precision mode saves original unquantized activations/weights; dequantized mode saves quantized fprop tensors and dequantizes them before backward GEMMs to avoid gradient quantization errors during RL fine-tuning.

LayerNormMLP and DelayedScaling recipe intentionally reject NVTE_BACKWARD_OVERRIDE with clear assertion messages and guidance to use LayerNormLinear + Linear as an alternative. All other previously flagged concerns (duplicate recipe fields, recipe None crash, unnecessary saved tensors) appear resolved in this revision.

Confidence Score: 5/5

Safe to merge — all previously flagged blocking issues are resolved; remaining findings are P2 style suggestions.

All previously raised P0/P1 concerns (duplicate recipe fields, recipe None crash, LayerNormMLP assertion message quality, unnecessary saved tensors, DelayedScaling interaction) are addressed in this revision. The feature is guarded by explicit assertions with clear error messages for unsupported combinations. The only new findings are a defensive getattr suggestion in fuser.py and a minor asymmetry in empty-tensor guards for MXFP8 storage, both P2. Comprehensive test coverage was added.

transformer_engine/pytorch/ops/fuser.py (minor: direct attribute access on recipe.backward_override), transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py (minor: asymmetric empty-tensor guard in dequantize vs _FromMXFP8Func.forward)
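The `fuser.py` finding amounts to reading the new field defensively. A minimal sketch, assuming a custom recipe class that predates the field; `LegacyRecipe`, `NewRecipe`, and `fusion_cache_key` are hypothetical names used only for illustration, not Transformer Engine APIs.

```python
class LegacyRecipe:
    """Hypothetical custom recipe that never defined backward_override."""


class NewRecipe:
    """Hypothetical recipe that carries the new field."""

    def __init__(self, backward_override=None):
        self.backward_override = backward_override


def fusion_cache_key(recipe, num_ops):
    """Build a cache key that changes whenever the override changes."""
    # getattr with a default avoids AttributeError on recipes lacking the field.
    override = getattr(recipe, "backward_override", None)
    return (type(recipe).__name__, num_ops, override)


key_legacy = fusion_cache_key(LegacyRecipe(), num_ops=3)
key_default = fusion_cache_key(NewRecipe(), num_ops=3)
key_dequant = fusion_cache_key(NewRecipe("dequantized"), num_ops=3)
```

Because the override is part of the key, switching modes yields a different key, which matches the summary's point that fused op graphs are rebuilt when the override changes.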

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/recipe/__init__.py | Adds backward_override field to all recipe dataclasses. DelayedScaling asserts backward_override is None with a clear error message. Previously flagged duplicate fields in Float8CurrentScaling are gone. |
| transformer_engine/pytorch/module/linear.py | Adds backward_override detection; sets save_original_input=True for high_precision mode and saves unquantized/quantized operands accordingly. Properly asserts against Float8Quantizer (DelayedScaling) when save_original_input is true. |
| transformer_engine/pytorch/ops/basic/basic_linear.py | Functional forward sets columnwise=False for quantized operands when backward_override is not None. op_forward chooses between saving original (high_precision) or quantized (dequantized) tensors. Backward pass dispatches dequantization accordingly. |
| transformer_engine/pytorch/module/layernorm_linear.py | Correctly implements both high_precision and dequantized modes. Disables optimize_for_gemm for MXFP8/NVFP4 dequantized mode. save_for_backward order matches test expectations. |
| transformer_engine/pytorch/module/layernorm_mlp.py | Explicitly asserts backward_override is None with a clear error message directing users to LayerNormLinear + Linear. Intentional documented limitation. |
| transformer_engine/pytorch/module/grouped_linear.py | Adds backward_override support; when set, disables FP8/UB/debug context in backward. Correctly handles both override modes in wgrad/dgrad GEMMs. |
| transformer_engine/pytorch/ops/fuser.py | Adds backward_override to the fusion cache key so fused op graphs are rebuilt when the override changes. Direct attribute access recipe.backward_override could raise AttributeError for custom Recipe subclasses not defining the field. |
| transformer_engine/pytorch/ops/basic/quantize.py | Disables backward quantization when recipe.backward_override is not None. get_fp8_recipe() is guarded by the fp8_enabled check and always returns a recipe, so no None crash. |
| transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py | Adds empty-tensor early returns in both _FromMXFP8Func.forward (rowwise and columnwise) and MXFP8TensorStorage.dequantize (rowwise only). Asymmetry is functionally safe but inconsistent. |
| transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py | Adds empty-tensor early returns consistent with MXFP8 pattern. Functionally safe for the dequantized backward use case. |
| transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py | Adds early return for empty rowwise tensor in dequantize, preventing errors when empty-token grouped-linear chunks are encountered in dequantized mode. |
| tests/pytorch/test_backward_override.py | New comprehensive test file covering both override modes across all recipe types, module types, shapes, and fused op patterns. Layout invariant checks guard against hidden requantization during backward. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Forward Pass - quantized fprop] --> B{NVTE_BACKWARD_OVERRIDE}
    B -->|None| C[Default: save rowwise+columnwise quantized tensors]
    B -->|high_precision| D[Save original unquantized input and weight]
    B -->|dequantized| E[Save rowwise-only quantized tensors]
    C --> F[Backward: quantized dgrad and wgrad GEMMs]
    D --> G[Backward: high-precision dgrad and wgrad using original fp16/bf16/fp32 operands]
    E --> H[Backward: dequantize saved tensors then high-precision GEMMs]
    subgraph Supported
        L[Linear]
        M[LayerNormLinear]
        N[GroupedLinear]
        O[ops.Linear / fused ops]
    end
    subgraph Unsupported - assertion error with clear message
        P[LayerNormMLP]
        Q[DelayedScaling recipe]
    end
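As a rough companion to the flowchart, the following back-of-envelope sketch estimates how many bytes of a saved input each branch keeps. It relies on assumptions that are not stated in the PR: 1 byte per quantized element, 2 bytes per original element (bf16), scale factors ignored, and the flowchart's labels taken at face value. `saved_input_bytes` is a hypothetical helper.

```python
BYTES_QUANTIZED = 1  # assumption: 1 byte per quantized element, scales ignored
BYTES_ORIGINAL = 2   # assumption: original input kept in bf16


def saved_input_bytes(num_elements, override=None):
    """Estimate the bytes saved for backward for one input tensor."""
    if override is None:
        # Flowchart: rowwise + columnwise quantized copies.
        return 2 * num_elements * BYTES_QUANTIZED
    if override == "high_precision":
        # Flowchart: original unquantized input.
        return num_elements * BYTES_ORIGINAL
    if override == "dequantized":
        # Flowchart: rowwise-only quantized copy.
        return num_elements * BYTES_QUANTIZED
    raise ValueError(f"unsupported override: {override!r}")


estimate = {
    mode: saved_input_bytes(1024, mode)
    for mode in (None, "high_precision", "dequantized")
}
```

Real numbers depend on the recipe, scale storage, and what each module actually saves, so this should be read only as an illustration of the three branches.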

Reviews (43): Last reviewed commit: "Merge branch 'main' into keep-bwd"

@greptile-apps greptile-apps Bot left a comment

Contributor

17 files reviewed, 3 comments

@greptile-apps greptile-apps Bot left a comment

Contributor

6 files reviewed, no comments

@zianglih

zianglih commented Feb 3, 2026

Contributor Author

I'll work on potential unit test breakage.

Comment thread transformer_engine/pytorch/ops/fuser.py Outdated
Comment thread transformer_engine/pytorch/module/layernorm_linear.py Outdated

@greptile-apps greptile-apps Bot left a comment

Contributor

5 files reviewed, no comments

@greptile-apps greptile-apps Bot left a comment

Contributor

5 files reviewed, 4 comments

Comment thread transformer_engine/pytorch/quantization.py Outdated
Comment thread transformer_engine/pytorch/module/linear.py Outdated
Comment thread transformer_engine/pytorch/module/layernorm_linear.py
Comment thread transformer_engine/pytorch/module/layernorm_mlp.py

@greptile-apps greptile-apps Bot left a comment

Contributor

4 files reviewed, 1 comment

Comment thread transformer_engine/pytorch/module/linear.py Outdated

@greptile-apps greptile-apps Bot left a comment

Contributor

5 files reviewed, 1 comment

Comment thread transformer_engine/pytorch/quantization.py Outdated

@greptile-apps greptile-apps Bot left a comment

Contributor

5 files reviewed, 2 comments

Comment thread transformer_engine/pytorch/module/layernorm_mlp.py Outdated
Comment thread transformer_engine/pytorch/module/layernorm_linear.py Outdated

@greptile-apps greptile-apps Bot left a comment

Contributor

4 files reviewed, no comments

@greptile-apps greptile-apps Bot left a comment

Contributor

5 files reviewed, no comments

@greptile-apps greptile-apps Bot left a comment

Contributor

5 files reviewed, no comments

@greptile-apps greptile-apps Bot left a comment

Contributor

5 files reviewed, no comments

Comment thread transformer_engine/pytorch/module/linear.py Outdated
  # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
  # requires column-wise usage
- if ctx.grad_output_quantizer is not None:
+ if ctx.grad_output_quantizer is not None and use_fp8_bwd:

Collaborator

this line seems redundant since you already skip the quantization step in base.py grad_output_preprocess?

not ctx.use_bias
and not ctx.requires_wgrad
and ctx.grad_output_quantizer is not None
and use_fp8_bwd

Collaborator

same comment as above

Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
recipe = cls.get_fp8_recipe()
if recipe is not None and recipe.delayed():
    # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used
    return False

Collaborator

Maybe it's better to assert an error for delayed scaling? Okay with both.

Member

I agree. If the user specifies an unsupported combination, I think it's better to fail loudly than to secretly disobey their instructions.
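The fail-loudly behavior discussed in this thread could be sketched as a validation step on the recipe itself. This is an illustration only: `SketchRecipe` and `SketchDelayedScaling` are hypothetical classes, and the sketch raises `ValueError`, whereas the review summary describes the merged code as using assertions.

```python
from dataclasses import dataclass
from typing import Optional

_ALLOWED_OVERRIDES = (None, "high_precision", "dequantized")


@dataclass
class SketchRecipe:
    """Hypothetical base recipe carrying the override field."""

    backward_override: Optional[str] = None

    def __post_init__(self):
        if self.backward_override not in _ALLOWED_OVERRIDES:
            raise ValueError(
                f"backward_override must be one of {_ALLOWED_OVERRIDES}, "
                f"got {self.backward_override!r}"
            )


@dataclass
class SketchDelayedScaling(SketchRecipe):
    """Hypothetical delayed-scaling recipe that rejects any override."""

    def __post_init__(self):
        super().__post_init__()
        if self.backward_override is not None:
            # Refuse the unsupported combination instead of silently ignoring it.
            raise ValueError(
                "delayed scaling does not support backward_override="
                f"{self.backward_override!r}"
            )


supported = SketchRecipe(backward_override="dequantized")
plain_delayed = SketchDelayedScaling()
```

Constructing `SketchDelayedScaling(backward_override="dequantized")` raises immediately, so the user learns about the unsupported combination at configuration time rather than getting a silently different backward pass.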

Comment thread transformer_engine/pytorch/module/layernorm_linear.py Outdated
  # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
  # requires column-wise usage
- if ctx.grad_output_quantizer is not None:
+ if ctx.grad_output_quantizer is not None and use_fp8_bwd:

Collaborator

this seems redundant too if we skip quant in grad_output_preprocess

Comment thread transformer_engine/pytorch/module/layernorm_linear.py Outdated
Comment thread transformer_engine/pytorch/module/layernorm_linear.py Outdated
Comment thread transformer_engine/pytorch/module/layernorm_mlp.py Outdated
Comment thread transformer_engine/pytorch/module/layernorm_mlp.py Outdated
Comment thread transformer_engine/pytorch/module/layernorm_mlp.py Outdated
Comment thread transformer_engine/pytorch/module/layernorm_mlp.py Outdated
Signed-off-by: Ziang Li <ziangli@umich.edu>
…zed`

Signed-off-by: Ziang Li <ziangli@umich.edu>
@zianglih zianglih changed the title Add NVTE_BACKWARD_MODE=default|unquant|dequant Add NVTE_BACKWARD_OVERRIDE=high_precision|dequantized Mar 14, 2026
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
@zianglih

zianglih commented Mar 16, 2026

Contributor Author

> using "dequantized" in bwd still does not preserve the chain rule 100%, as the quantization in fwd and bwd happens along different dims

@victordion I think you are describing the default TE 1d recipe or requantized behavior.

@victordion

> > using "dequantized" in bwd still does not preserve the chain rule 100%, as the quantization in fwd and bwd happens along different dims
>
> @victordion I think you are describing the default TE 1d recipe or requantized behavior.

Right. My mistake. My mental model assumed that requantization was happening. Thanks for responding!

@zianglih zianglih requested review from ksivaman and zhongbozhu March 17, 2026 05:42
@zianglih

Contributor Author

Regarding the env var design: since this feature is mainly used for RL, there has to be a way for the user to override the bwd behavior directly from the RL framework instead of plumbing it all the way through Megatron.
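One way to picture this design is a recipe field whose default is read from the environment, so a launcher can set the variable once without passing an argument through intermediate layers. This is a sketch under assumptions: `backward_override_from_env` and `EnvAwareRecipe` are hypothetical names, and the exact parsing rules in the merged code may differ.

```python
import os
from dataclasses import dataclass, field
from typing import Optional

ENV_VAR = "NVTE_BACKWARD_OVERRIDE"
_VALID_VALUES = ("high_precision", "dequantized")


def backward_override_from_env(environ=None):
    """Return the override named by the env var, or None when it is not set."""
    environ = os.environ if environ is None else environ
    value = environ.get(ENV_VAR)
    if value is None:
        return None  # not set: keep the existing default behavior
    if value not in _VALID_VALUES:
        raise ValueError(f"{ENV_VAR} must be one of {_VALID_VALUES}, got {value!r}")
    return value


@dataclass
class EnvAwareRecipe:
    """Hypothetical recipe whose default override comes from the environment."""

    backward_override: Optional[str] = field(
        default_factory=backward_override_from_env
    )
```

An explicit argument such as `EnvAwareRecipe(backward_override=None)` still takes precedence over the environment in this sketch, because the default factory only runs when the field is omitted.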

@ksivaman

Member

/te-ci L0 L1

@zianglih

Contributor Author

All PyTorch CI passed.

Some failed jax tests are due to FileExistsError: [Errno 17] File exists: '/logs' .

@zhongbozhu

Collaborator

/te-ci L0 L1

@zhongbozhu zhongbozhu left a comment

Collaborator

LGTM, pending CI

@ksivaman ksivaman dismissed their stale review March 31, 2026 19:25

Unblocking

@zhongbozhu

Collaborator

/te-ci L0 L1

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx

ptrendx commented Apr 3, 2026

Member

/te-ci L0 L1

@zianglih

zianglih commented Apr 7, 2026

Contributor Author

The failed JAX CI is unrelated to this PR:

B200:

../../tests/jax/test_permutation.py::TestHighLevelPermutationAPI::test_sort_chunks_by_index[dtype_float32-8-4096-1280] FAILED

L40:

../../tests/jax/test_permutation.py::TestHighLevelPermutationAPI::test_sort_chunks_by_index[dtype_float32-8-4096-1280] FAILED

@ptrendx ptrendx merged commit fdf9fb1 into NVIDIA:main Apr 7, 2026
46 of 52 checks passed
@ptrendx

ptrendx commented Apr 7, 2026

Member

Merged, thank you for your contribution @zianglih!

@YigongQin YigongQin mentioned this pull request Apr 21, 2026
13 tasks
@zianglih zianglih mentioned this pull request May 9, 2026
13 tasks
faradawn pushed a commit to faradawn/TransformerEngine that referenced this pull request May 14, 2026
* Add NVTE_KEEP_BACKWARD_UNQUANTIZED

Signed-off-by: Ziang Li <ziangli@umich.edu>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Disable ub and clean up

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Drop fuser changes

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Replace use_quantized_bwd with use_fp8_bwd

Signed-off-by: Ziang Li <ziangli@umich.edu>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Ignore keep_backward_unquantized if delayed scaling

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Refactor ignoring NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Add back missing ctx.debug

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Refactor changes under fused

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Clean up

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Refactor high-precision overwrite if keep_backward_unquantized

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Clean up

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Drop redundant fp8_recipe_bwd

Signed-off-by: Ziang Li <ziangli@umich.edu>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Drop redundant ub changes

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Drop more redundant ub changes

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Drop redundant delayed scaling changes

Signed-off-by: Ziang Li <ziangli@umich.edu>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Drop unneeded backwards_needs_fc1_input

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Drop and disallow LayerNormMLP implementation

Signed-off-by: Ziang Li <ziangli@umich.edu>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Move interface changes to recipe

Signed-off-by: Ziang Li <ziangli@umich.edu>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Move ub overrides to fwd

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Remove duplication

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Simplify use_fp8_bwd logic in bwd

Signed-off-by: Ziang Li <ziangli@umich.edu>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Set grad quantizers to none if keep bwd unquantized

Signed-off-by: Ziang Li <ziangli@umich.edu>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Drop delayed scaling change

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Simplify env var logic

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Move validation check to recipe

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Simplify effective_enabled

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Fix inverted assertion logic

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Simplify changes under ops

Signed-off-by: Ziang Li <ziangli@umich.edu>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Simplify ctx.keep_backward_unquantized

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Fix missing attribute

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Add unit tests

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Fix bias errors in unit test

Signed-off-by: Ziang Li <ziangli@umich.edu>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add more shapes to unit test

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Refactor interface to `NVTE_BACKWARD_MODE=default|unquant|dequant`

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Fix override and clean up

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Clean up unit test

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Clean up unit test

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Override `ctx.reduce_and_update_bwd_fp8_tensors = False`

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Expand unit test

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Add `test_backward_mode_memory_peak_report`

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Expand test coverage and fix

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Use `numel()`

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Refactor unit test

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Fix grouped linear to override `*_quantizers` instead of `*_quantizer`

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Only save input/weight when `*_requires_grad` on unquant mode

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Fix Blackwell debug ci

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Fix sm89 and sm90 tests

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Fix unquant mode memory saving

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Refactor interface to `NVTE_BACKWARD_OVERRIDE=high_precision|dequantized`

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Rename unit test

Signed-off-by: Ziang Li <ziangli@umich.edu>

* Simplify env var parsing

Signed-off-by: Ziang Li <ziangli@umich.edu>

---------

Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Przemek Tredak <ptredak@nvidia.com>

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.


9 participants