diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py index 7fc111e2c34..5b3a2d03874 100644 --- a/examples/models/llama/rope.py +++ b/examples/models/llama/rope.py @@ -154,21 +154,32 @@ def hf_precompute_freqs_cis( # Partial rotary embeddings. dim = int(dim * partial_rotary_factor) - # Compute the RoPE table in fp64 to minimize ULP-level drift; cast to fp32 - # once at the end. Phi-4 Mini's narrow decode-time logit margins make the - # exported model sensitive to 1-ULP differences in freqs_cos / freqs_sin - # under sampling, especially on the Vulkan delegate. + # fp64 precompute is required whenever cos/sin will be scaled by a + # non-trivial attention_factor (LongRoPE on Phi-3 / Phi-4 family). There, + # fp32 ULP-level rounding in the table is load-bearing on Vulkan under + # sampling -- a fp32-only regression manifests as decode-time n-gram + # looping, not a unit-test red. For vanilla HF RoPE, fp32 throughout + # produces cos/sin tables bit-identical to the non-HF precompute_freqs_cis + # path, which the static-attention vs MHA parity tests rely on. + # + # If you add a new model that needs cos/sin scaling but does not set + # short_factor / long_factor / attention_factor, extend the gate below. + longrope_active = (short_factor is not None) or (long_factor is not None) + needs_fp64 = longrope_active or ( + attention_factor is not None and attention_factor != 1.0 + ) + compute_dtype = torch.float64 if needs_fp64 else torch.float32 + inv_freq = 1.0 / ( theta ** ( - torch.arange(0, dim, 2, device=device, dtype=torch.int64).to(torch.float64) + torch.arange(0, dim, 2, device=device, dtype=torch.int64).to(compute_dtype) / dim ) ) # LongRoPE: divide inv_freq element-wise by short_factor or long_factor. # Selection mirrors HF: long_factor when seq_len > original_max_position_embeddings. - longrope_active = (short_factor is not None) or (long_factor is not None) if longrope_active: chosen = ( long_factor @@ -178,7 +189,7 @@ def hf_precompute_freqs_cis( if chosen is None: # Fall back to whichever factor was provided. chosen = short_factor if long_factor is None else long_factor - ext_factors = torch.tensor(chosen, dtype=torch.float64, device=device) + ext_factors = torch.tensor(chosen, dtype=compute_dtype, device=device) assert ext_factors.numel() == inv_freq.numel(), ( f"LongRoPE factor length {ext_factors.numel()} must equal dim/2 " f"({inv_freq.numel()})" @@ -200,8 +211,8 @@ def hf_precompute_freqs_cis( ) # pyre-ignore Undefined attribute [16]: `float` has no attribute `device`. - t = torch.arange(end, device=inv_freq.device, dtype=torch.int64).to(torch.float64) - freqs = torch.outer(t, inv_freq).to(torch.float64) # pyre-ignore + t = torch.arange(end, device=inv_freq.device, dtype=torch.int64).to(compute_dtype) + freqs = torch.outer(t, inv_freq).to(compute_dtype) # pyre-ignore emb = torch.cat((freqs, freqs), dim=-1) cos_tab = torch.cos(emb) sin_tab = torch.sin(emb)