diff --git a/examples/apple/coreml/llama/llama_transformer.py b/examples/apple/coreml/llama/llama_transformer.py index ae98c327b45..0175d0deaf6 100644 --- a/examples/apple/coreml/llama/llama_transformer.py +++ b/examples/apple/coreml/llama/llama_transformer.py @@ -14,7 +14,7 @@ import torch import torch.nn.functional as F -from executorch.examples.models.llama.norm import RMSNorm +from executorch.examples.models.llama.norm import RMSNorm, RMSNormCoreML # noqa: F401 from executorch.examples.models.llama.rope import ( hf_apply_rotary_emb, @@ -109,65 +109,6 @@ def __post_init__(self): self.head_dim = self.dim // self.n_heads -class CoreMLRMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - """ - Initialize the RMSNorm normalization layer. - - Args: - dim (int): The dimension of the input tensor. - eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. - - Attributes: - eps (float): A small value added to the denominator for numerical stability. - weight (nn.Parameter): Learnable scaling parameter. - - """ - super().__init__() - self.dim = dim - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - """ - Apply the RMSNorm normalization to the input tensor. - - Args: - x (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The normalized tensor. - - """ - # CoreML ignores casts to FP32, so existing implementation of RMSNorm was not stable - # We instead use (x * sqrt(n)) / norm(x, dim=-1) - # Using torch.norm and preserving this op in CoreML improves stability - # Note, we ignore eps, but could add it by using torch.norm(torch.concat(x, sqrt(n*eps))) in the denominator - # In future, we want to add CoreML support for the functional RMSNorm op - # We have yet to do large scale evaluations on the numeric stability of this solution, but note that - # it appears better than what exists currently (removing FP32 casts and using FP16) - rms_norm_eps0 = ( - x - * torch.sqrt(torch.tensor(self.dim, dtype=x.dtype)) - * torch.reciprocal(torch.linalg.vector_norm(x, dim=-1, keepdim=True)) - ) - return rms_norm_eps0 - - def forward(self, x): - """ - Forward pass through the RMSNorm layer. - - Args: - x (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The output tensor after applying RMSNorm. - - """ - output = self._norm(x) - return output * self.weight - - class Rope(torch.nn.Module): def __init__(self, params: ModelArgs): super().__init__() diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py index e424ee0361a..e1e47ea02cf 100644 --- a/examples/models/llama/norm.py +++ b/examples/models/llama/norm.py @@ -57,6 +57,54 @@ def __init__(self, dim: int, eps: float = 1e-6): self.weight.requires_grad = False +class RMSNormCoreML(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + """ + CoreML-friendly RMSNorm — uses `torch.linalg.vector_norm` so the op is + preserved in the CoreML graph for numerical stability. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): Floor on the L2-norm denominator + (`clamp_min(‖x‖₂, √(dim·eps))`). Prevents `0/0 = NaN` on + zero-padded positions and matches standard RMSNorm's + `rsqrt(mean(x²) + eps)` semantics on a zero input. Must be > 0. + + Attributes: + eps (float): Floor coefficient consumed by `_norm`. + weight (nn.Parameter): Learnable scaling parameter. + """ + super().__init__() + assert eps > 0, ( + "RMSNormCoreML requires eps > 0; eps=0 collapses the denominator " + "floor and produces NaN on zero-padded positions" + ) + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + # Floor the denominator to avoid 0 / 0 = NaN on zero-padded positions + # (chunked prefill in StaticAttentionIOManager pads each chunk to + # input_len with zeros). Use sqrt(dim * eps) so the floor matches + # standard RMSNorm's eps semantics (`rsqrt(mean(x²) + eps)`) and is + # large enough to survive fp16 (1e-6 alone underflows in fp16). + floor_val = torch.sqrt(torch.tensor(self.dim * self.eps, dtype=x.dtype)) + norm_val = torch.clamp_min( + torch.linalg.vector_norm(x, dim=-1, keepdim=True), floor_val + ) + rms_norm_eps0 = ( + x + * torch.sqrt(torch.tensor(self.dim, dtype=x.dtype)) + * torch.reciprocal(norm_val) + ) + return rms_norm_eps0 + + def forward(self, x): + output = self._norm(x) + return output * self.weight + + class RMSNormWithInputScale(torch.nn.Module): def __init__(self, dim: int, eps: float = 1e-5): super().__init__() @@ -83,3 +131,37 @@ def forward(self, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tens hidden_states = self.weight * hidden_states.to(input_dtype) hidden_states = hidden_states * F.silu(gate.to(torch.float32)) return hidden_states.to(input_dtype) + + +def replace_rms_norm_for_coreml_(model: torch.nn.Module) -> torch.nn.Module: + """In-place: walk `model` and swap every RMSNorm-family module for RMSNormCoreML. + + Mirrors the post-construction transform pattern used by torchao's + `quantize_(model, config)`: instead of threading a `use_coreml_norm` flag + through every norm construction site, build the model with the standard + norms and then call this once before CoreML export. Trained scale weights + are preserved. + + Swaps these classes (everything else is left alone): + * `RMSNorm` (this module) + * `ScalelessRMSNorm` (this module — no-op weight) + * `torch.nn.RMSNorm` (used for affine q_norm/k_norm in StaticAttention) + """ + for name, mod in list(model.named_modules()): + if not isinstance(mod, (RMSNorm, ScalelessRMSNorm, torch.nn.RMSNorm)): + continue + # All three carry the normalized dim either as `dim` or in `normalized_shape[-1]`. + dim = getattr(mod, "dim", None) or mod.normalized_shape[-1] + eps = getattr(mod, "eps", 1e-6) or 1e-6 + new = RMSNormCoreML(dim, eps=eps) + # Preserve trained scale (no-op for ScalelessRMSNorm). + if getattr(mod, "weight", None) is not None: + new.weight = mod.weight + # Locate parent module via the dotted name and rebind the attribute. + if "." in name: + parent_name, attr = name.rsplit(".", 1) + parent = model.get_submodule(parent_name) + else: + parent, attr = model, name + setattr(parent, attr, new) + return model