Restore ScalelessRMSNorm hand-rolled forward for eager numerical parity#19654
Restore ScalelessRMSNorm hand-rolled forward for eager numerical parity#19654billmguo wants to merge 1 commit into
Conversation
Summary: D104258950 changed `ScalelessRMSNorm` from a hand-rolled fp32 decomposition to a `torch.nn.RMSNorm` subclass so that QNN and other backends see a proper RMSNorm op for lowering. However, removing the custom `forward` meant eager execution now uses `torch.nn.RMSNorm`'s fused CUDA kernel, which has different internal precision handling than the hand-rolled `x.float() * rsqrt(mean(x^2) + eps)` decomposition used by the rlformers reference model. This caused both `test_llm_backbone_correctness_cuda` and `test_llm_backbone_correctness_decode` to fail: - **fp32 case**: SNR dropped from `inf` to 67-85 dB (same decoded text, different logits) - **quantized case**: SNR dropped to 1-35 dB with negative per-step values and divergent decoded text, because the precision difference was amplified by quantization noise The fix restores the original hand-rolled `forward` override on `ScalelessRMSNorm` while keeping `torch.nn.RMSNorm` as the base class. A `torch.compiler.is_compiling()` guard ensures that during `torch.export` (for QNN, XNNPACK, or any backend), the fused `torch.nn.RMSNorm` op is used instead — preserving the export-path fix from D104258950. Differential Revision: D105593738
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19654
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 1 New Failure, 2 Unclassified FailuresAs of commit c93f866 with merge base 7c495fa ( NEW FAILURE - The following job has failed:
UNCLASSIFIED FAILURES - DrCI could not classify the following jobs because the workflow did not run on the merge base. The failures may be pre-existing on trunk or introduced by this PR:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@billmguo has exported this pull request. If you are a Meta employee, you can view the originating Diff in D105593738. |
This PR needs a
|
Summary:
D104258950 changed
ScalelessRMSNormfrom a hand-rolled fp32 decomposition to atorch.nn.RMSNormsubclass so that QNN and other backends see a proper RMSNorm op for lowering. However, removing the customforwardmeant eager execution now usestorch.nn.RMSNorm's fused CUDA kernel, which has different internal precision handling than the hand-rolledx.float() * rsqrt(mean(x^2) + eps)decomposition used by the rlformers reference model.This caused both
test_llm_backbone_correctness_cudaandtest_llm_backbone_correctness_decodeto fail:infto 67-85 dB (same decoded text, different logits)The fix restores the original hand-rolled
forwardoverride onScalelessRMSNormwhile keepingtorch.nn.RMSNormas the base class. Atorch.compiler.is_compiling()guard ensures that duringtorch.export(for QNN, XNNPACK, or any backend), the fusedtorch.nn.RMSNormop is used instead — preserving the export-path fix from D104258950.Differential Revision: D105593738