diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py index e1e47ea02cf..b1daa739b7a 100644 --- a/examples/models/llama/norm.py +++ b/examples/models/llama/norm.py @@ -44,9 +44,9 @@ def forward(self, x): class ScalelessRMSNorm(torch.nn.RMSNorm): """RMSNorm with weight hardcoded to ones and not trainable. - Equivalent to a scaleless RMSNorm (no learnable scaling) but implemented as a - torch.nn.RMSNorm so the op composes/decomposes cleanly for backends like QNN - instead of being expressed as a hand-rolled decomposition. + Subclasses torch.nn.RMSNorm so backends (QNN) see a proper RMSNorm op for + lowering, but overrides forward with the explicit fp32 decomposition to + stay numerically identical to the rlformers reference during eager execution. """ def __init__(self, dim: int, eps: float = 1e-6): @@ -56,6 +56,15 @@ def __init__(self, dim: int, eps: float = 1e-6): self.weight.fill_(1.0) self.weight.requires_grad = False + def forward(self, x): + if torch.compiler.is_compiling(): + return super().forward(x) + x_float = x.float() + return ( + x_float + * torch.rsqrt((x_float * x_float).mean(-1, keepdim=True) + self.eps) + ).type_as(x) + class RMSNormCoreML(torch.nn.Module): def __init__(self, dim: int, eps: float = 1e-6):