Skip to content

Conversation

@codeflash-ai
Copy link

@codeflash-ai codeflash-ai bot commented Nov 4, 2025

📄 57% (0.57x) speedup for MixtralRotaryEmbedding.forward in src/transformers/models/mixtral/modeling_mixtral.py

⏱️ Runtime : 5.33 milliseconds 3.38 milliseconds (best of 118 runs)

📝 Explanation and details

The optimized code achieves a 57% speedup by eliminating expensive tensor operations in the forward method of MixtralRotaryEmbedding. The key optimizations are:

1. Broadcasting instead of expand + matmul: The original code used .expand() to replicate tensors across dimensions, then performed matrix multiplication. The optimized version uses direct broadcasting (position_ids[:, None, :] * inv_freq[None, :, None]), which is more memory-efficient and computationally faster since PyTorch can optimize element-wise operations better than general matrix multiplication.

2. Eliminated redundant type conversions: The original code called .float() multiple times on the same tensors. The optimization moves all type casting to the beginning, converting inv_freq and position_ids to float32 once and reusing them.

3. Removed unnecessary autocast context: Since the computation is already done in float32, the torch.autocast wrapper adds overhead without benefit. The optimized version computes cos/sin directly and converts to the target dtype only at the end.

4. Simplified tensor reshaping: Instead of complex expand operations followed by transpose, the optimization uses simpler concatenation and a single transpose at the end.

The test results show consistent 70-80% speedup across various input sizes, with the optimization being particularly effective for:

  • Small to medium batch sizes (1-32 batches)
  • Sequence lengths up to 256 tokens
  • Different data types (float16, float32, bfloat16)

The speedup is most pronounced in typical transformer inference scenarios where these operations are called frequently with moderate-sized tensors.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 76 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
import pytest  # used for our unit tests
import torch
from transformers.models.mixtral.modeling_mixtral import MixtralRotaryEmbedding


# function to test
class DummyConfig:
    def __init__(
        self,
        max_position_embeddings=512,
        rope_parameters=None,
        head_dim=None,
        hidden_size=128,
        num_attention_heads=8,
    ):
        self.max_position_embeddings = max_position_embeddings
        self.rope_parameters = rope_parameters or {"rope_type": "default", "rope_theta": 10000.0}
        self.head_dim = head_dim
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
from transformers.models.mixtral.modeling_mixtral import MixtralRotaryEmbedding

# unit tests

# -------------------- BASIC TEST CASES --------------------

def test_forward_basic_shape_and_type():
    """Test that forward returns tensors of correct shape and type for basic input."""
    config = DummyConfig()
    rotary = MixtralRotaryEmbedding(config)
    batch = 2
    seq_len = 4
    dim = config.hidden_size // config.num_attention_heads
    x = torch.zeros((batch, seq_len, dim))
    position_ids = torch.arange(seq_len).repeat(batch, 1)
    cos, sin = rotary.forward(x, position_ids) # 111μs -> 63.5μs (76.0% faster)

def test_forward_values_range():
    """Test that cos values are in [-1, 1] and sin values are in [-1, 1] for basic input."""
    config = DummyConfig()
    rotary = MixtralRotaryEmbedding(config)
    batch = 1
    seq_len = 8
    dim = config.hidden_size // config.num_attention_heads
    x = torch.randn((batch, seq_len, dim))
    position_ids = torch.arange(seq_len).unsqueeze(0)
    cos, sin = rotary.forward(x, position_ids) # 108μs -> 60.4μs (79.9% faster)

def test_forward_zero_position_ids():
    """Test that position_ids of all zeros produces cos=1 and sin=0 everywhere."""
    config = DummyConfig()
    rotary = MixtralRotaryEmbedding(config)
    batch = 3
    seq_len = 5
    dim = config.hidden_size // config.num_attention_heads
    x = torch.ones((batch, seq_len, dim))
    position_ids = torch.zeros((batch, seq_len), dtype=torch.long)
    cos, sin = rotary.forward(x, position_ids) # 118μs -> 66.2μs (79.3% faster)

def test_forward_nonzero_position_ids():
    """Test that nonzero position_ids produce cos/sin not all 1/0."""
    config = DummyConfig()
    rotary = MixtralRotaryEmbedding(config)
    batch = 2
    seq_len = 6
    dim = config.hidden_size // config.num_attention_heads
    x = torch.ones((batch, seq_len, dim))
    position_ids = torch.arange(1, seq_len+1).repeat(batch, 1)
    cos, sin = rotary.forward(x, position_ids) # 111μs -> 62.5μs (78.9% faster)

# -------------------- EDGE TEST CASES --------------------

def test_forward_empty_batch():
    """Test that forward handles batch size zero."""
    config = DummyConfig()
    rotary = MixtralRotaryEmbedding(config)
    batch = 0
    seq_len = 5
    dim = config.hidden_size // config.num_attention_heads
    x = torch.zeros((batch, seq_len, dim))
    position_ids = torch.zeros((batch, seq_len), dtype=torch.long)
    cos, sin = rotary.forward(x, position_ids) # 96.2μs -> 58.3μs (64.8% faster)

def test_forward_empty_sequence():
    """Test that forward handles sequence length zero."""
    config = DummyConfig()
    rotary = MixtralRotaryEmbedding(config)
    batch = 2
    seq_len = 0
    dim = config.hidden_size // config.num_attention_heads
    x = torch.zeros((batch, seq_len, dim))
    position_ids = torch.zeros((batch, seq_len), dtype=torch.long)
    cos, sin = rotary.forward(x, position_ids) # 100μs -> 57.8μs (73.8% faster)

def test_forward_single_dim():
    """Test with head_dim=1 (minimum allowed)."""
    config = DummyConfig(head_dim=1, hidden_size=8, num_attention_heads=8)
    rotary = MixtralRotaryEmbedding(config)
    batch = 1
    seq_len = 3
    dim = config.head_dim
    x = torch.zeros((batch, seq_len, dim))
    position_ids = torch.arange(seq_len).unsqueeze(0)
    cos, sin = rotary.forward(x, position_ids) # 100μs -> 58.3μs (72.9% faster)

def test_forward_large_theta():
    """Test with very large rope_theta parameter."""
    config = DummyConfig(rope_parameters={"rope_type": "default", "rope_theta": 1e12})
    rotary = MixtralRotaryEmbedding(config)
    batch = 1
    seq_len = 4
    dim = config.hidden_size // config.num_attention_heads
    x = torch.zeros((batch, seq_len, dim))
    position_ids = torch.arange(seq_len).unsqueeze(0)
    cos, sin = rotary.forward(x, position_ids) # 106μs -> 59.0μs (80.2% faster)

def test_forward_negative_position_ids():
    """Test that negative position_ids are handled (should not error)."""
    config = DummyConfig()
    rotary = MixtralRotaryEmbedding(config)
    batch = 2
    seq_len = 3
    dim = config.hidden_size // config.num_attention_heads
    x = torch.zeros((batch, seq_len, dim))
    position_ids = torch.tensor([[-1, -2, -3], [-4, -5, -6]])
    cos, sin = rotary.forward(x, position_ids) # 113μs -> 62.5μs (82.1% faster)

def test_forward_dtype_float16():
    """Test that forward works with float16 input."""
    config = DummyConfig()
    rotary = MixtralRotaryEmbedding(config)
    batch = 1
    seq_len = 4
    dim = config.hidden_size // config.num_attention_heads
    x = torch.zeros((batch, seq_len, dim), dtype=torch.float16)
    position_ids = torch.arange(seq_len).unsqueeze(0)
    cos, sin = rotary.forward(x, position_ids) # 109μs -> 63.5μs (72.4% faster)

def test_forward_dtype_bfloat16():
    """Test that forward works with bfloat16 input."""
    config = DummyConfig()
    rotary = MixtralRotaryEmbedding(config)
    batch = 1
    seq_len = 4
    dim = config.hidden_size // config.num_attention_heads
    x = torch.zeros((batch, seq_len, dim), dtype=torch.bfloat16)
    position_ids = torch.arange(seq_len).unsqueeze(0)
    cos, sin = rotary.forward(x, position_ids) # 109μs -> 64.0μs (71.4% faster)

def test_forward_device_cpu_vs_cuda():
    """Test that forward works on both CPU and CUDA (if available)."""
    config = DummyConfig()
    rotary = MixtralRotaryEmbedding(config)
    batch = 1
    seq_len = 2
    dim = config.hidden_size // config.num_attention_heads
    position_ids = torch.arange(seq_len).unsqueeze(0)
    x_cpu = torch.zeros((batch, seq_len, dim))
    cos_cpu, sin_cpu = rotary.forward(x_cpu, position_ids) # 107μs -> 59.3μs (80.8% faster)
    if torch.cuda.is_available():
        rotary_cuda = MixtralRotaryEmbedding(config, device="cuda")
        x_cuda = torch.zeros((batch, seq_len, dim)).cuda()
        position_ids_cuda = position_ids.cuda()
        cos_cuda, sin_cuda = rotary_cuda.forward(x_cuda, position_ids_cuda)


def test_forward_large_batch_and_seq():
    """Test forward with large batch and sequence, but under 100MB."""
    config = DummyConfig(hidden_size=128, num_attention_heads=8)
    rotary = MixtralRotaryEmbedding(config)
    batch = 32
    seq_len = 128
    dim = config.hidden_size // config.num_attention_heads
    x = torch.zeros((batch, seq_len, dim))
    position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch, 1)
    cos, sin = rotary.forward(x, position_ids) # 248μs -> 165μs (50.0% faster)

def test_forward_max_dim_under_100mb():
    """Test with maximum dim, batch, seq allowed under 100MB."""
    # Each float32 is 4 bytes.
    # Let's try batch=16, seq_len=256, dim=64 => 16*256*64*4 = 1,048,576 bytes (~1MB)
    config = DummyConfig(hidden_size=512, num_attention_heads=8)
    rotary = MixtralRotaryEmbedding(config)
    batch = 16
    seq_len = 256
    dim = config.hidden_size // config.num_attention_heads
    x = torch.zeros((batch, seq_len, dim))
    position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch, 1)
    cos, sin = rotary.forward(x, position_ids) # 629μs -> 502μs (25.2% faster)

def test_forward_randomized_large_input():
    """Test with randomized input and position_ids for large input."""
    config = DummyConfig(hidden_size=256, num_attention_heads=8)
    rotary = MixtralRotaryEmbedding(config)
    batch = 64
    seq_len = 32
    dim = config.hidden_size // config.num_attention_heads
    x = torch.randn((batch, seq_len, dim))
    position_ids = torch.randint(0, 100, (batch, seq_len))
    cos, sin = rotary.forward(x, position_ids) # 246μs -> 171μs (44.0% faster)

def test_forward_multiple_calls_consistency():
    """Test that multiple calls with same input produce same output."""
    config = DummyConfig()
    rotary = MixtralRotaryEmbedding(config)
    batch = 8
    seq_len = 16
    dim = config.hidden_size // config.num_attention_heads
    x = torch.randn((batch, seq_len, dim))
    position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch, 1)
    cos1, sin1 = rotary.forward(x, position_ids) # 119μs -> 63.4μs (88.5% faster)
    cos2, sin2 = rotary.forward(x, position_ids) # 58.9μs -> 33.3μs (76.7% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import pytest  # used for our unit tests
import torch
from transformers.models.mixtral.modeling_mixtral import MixtralRotaryEmbedding


# function to test (copied from above)
class DummyConfig:
    def __init__(
        self,
        max_position_embeddings=2048,
        hidden_size=8,
        num_attention_heads=2,
        rope_parameters=None,
        head_dim=None,
    ):
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.rope_parameters = rope_parameters or {"rope_type": "default", "rope_theta": 10000.0}
        self.head_dim = head_dim
from transformers.models.mixtral.modeling_mixtral import MixtralRotaryEmbedding

# unit tests

# ----------- BASIC TEST CASES -----------

def test_forward_basic_shape_and_types():
    # Test that output shapes and types are correct for basic input
    config = DummyConfig(hidden_size=8, num_attention_heads=2)
    rotary = MixtralRotaryEmbedding(config)
    # x: (batch_size, seq_len, dim)
    x = torch.zeros(2, 4, 4, dtype=torch.float32)
    # position_ids: (batch_size, seq_len)
    position_ids = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]])
    cos, sin = rotary.forward(x, position_ids) # 110μs -> 62.4μs (77.8% faster)

def test_forward_basic_values():
    # Test that cos/sin values are correct for position_ids=0
    config = DummyConfig(hidden_size=8, num_attention_heads=2)
    rotary = MixtralRotaryEmbedding(config)
    x = torch.zeros(1, 2, 4)
    position_ids = torch.tensor([[0, 0]])
    cos, sin = rotary.forward(x, position_ids) # 107μs -> 62.9μs (71.6% faster)

def test_forward_dtype_preservation():
    # Test that output dtype matches input dtype
    config = DummyConfig(hidden_size=8, num_attention_heads=2)
    rotary = MixtralRotaryEmbedding(config)
    for dtype in [torch.float32, torch.float16]:
        x = torch.zeros(1, 2, 4, dtype=dtype)
        position_ids = torch.tensor([[0, 1]])
        cos, sin = rotary.forward(x, position_ids) # 162μs -> 96.0μs (69.3% faster)

def test_forward_batch_size_one():
    # Test with batch size 1
    config = DummyConfig(hidden_size=8, num_attention_heads=2)
    rotary = MixtralRotaryEmbedding(config)
    x = torch.zeros(1, 3, 4)
    position_ids = torch.tensor([[0, 1, 2]])
    cos, sin = rotary.forward(x, position_ids) # 106μs -> 60.3μs (76.4% faster)

def test_forward_single_token():
    # Test with sequence length 1
    config = DummyConfig(hidden_size=8, num_attention_heads=2)
    rotary = MixtralRotaryEmbedding(config)
    x = torch.zeros(2, 1, 4)
    position_ids = torch.tensor([[0], [1]])
    cos, sin = rotary.forward(x, position_ids) # 106μs -> 61.7μs (72.9% faster)

# ----------- EDGE TEST CASES -----------

def test_forward_minimal_dim():
    # Test with minimal dimension (head_dim=2)
    config = DummyConfig(hidden_size=4, num_attention_heads=2, head_dim=2)
    rotary = MixtralRotaryEmbedding(config)
    x = torch.zeros(1, 2, 2)
    position_ids = torch.tensor([[0, 1]])
    cos, sin = rotary.forward(x, position_ids) # 103μs -> 62.0μs (66.7% faster)

def test_forward_negative_position_ids():
    # Test with negative position ids (should still compute cos/sin)
    config = DummyConfig(hidden_size=8, num_attention_heads=2)
    rotary = MixtralRotaryEmbedding(config)
    x = torch.zeros(1, 2, 4)
    position_ids = torch.tensor([[0, -1]])
    cos, sin = rotary.forward(x, position_ids) # 109μs -> 62.3μs (75.4% faster)

def test_forward_large_position_ids():
    # Test with very large position ids
    config = DummyConfig(hidden_size=8, num_attention_heads=2)
    rotary = MixtralRotaryEmbedding(config)
    x = torch.zeros(1, 2, 4)
    position_ids = torch.tensor([[100000, 200000]])
    cos, sin = rotary.forward(x, position_ids) # 109μs -> 62.5μs (75.4% faster)

def test_forward_non_contiguous_position_ids():
    # Test with non-contiguous, shuffled position ids
    config = DummyConfig(hidden_size=8, num_attention_heads=2)
    rotary = MixtralRotaryEmbedding(config)
    x = torch.zeros(1, 4, 4)
    position_ids = torch.tensor([[3, 1, 0, 2]])
    cos, sin = rotary.forward(x, position_ids) # 108μs -> 61.3μs (76.7% faster)

def test_forward_nonzero_input():
    # Test that output does not depend on x's values
    config = DummyConfig(hidden_size=8, num_attention_heads=2)
    rotary = MixtralRotaryEmbedding(config)
    x1 = torch.zeros(1, 2, 4)
    x2 = torch.ones(1, 2, 4)
    position_ids = torch.tensor([[0, 1]])
    cos1, sin1 = rotary.forward(x1, position_ids) # 110μs -> 59.9μs (84.1% faster)
    cos2, sin2 = rotary.forward(x2, position_ids) # 48.2μs -> 27.9μs (73.1% faster)

def test_forward_different_device():
    # Test with CUDA if available
    config = DummyConfig(hidden_size=8, num_attention_heads=2)
    rotary = MixtralRotaryEmbedding(config)
    if torch.cuda.is_available():
        x = torch.zeros(1, 2, 4, device="cuda")
        position_ids = torch.tensor([[0, 1]], device="cuda")
        cos, sin = rotary.forward(x, position_ids)

def test_forward_empty_sequence():
    # Test with zero-length sequence
    config = DummyConfig(hidden_size=8, num_attention_heads=2)
    rotary = MixtralRotaryEmbedding(config)
    x = torch.zeros(1, 0, 4)
    position_ids = torch.empty((1, 0), dtype=torch.long)
    cos, sin = rotary.forward(x, position_ids) # 100μs -> 60.3μs (66.9% faster)

def test_forward_empty_batch():
    # Test with zero batch size
    config = DummyConfig(hidden_size=8, num_attention_heads=2)
    rotary = MixtralRotaryEmbedding(config)
    x = torch.zeros(0, 2, 4)
    position_ids = torch.empty((0, 2), dtype=torch.long)
    cos, sin = rotary.forward(x, position_ids) # 96.4μs -> 58.4μs (65.1% faster)

def test_forward_float16_cpu():
    # Test float16 on CPU (should work, but may fallback to float32 internally)
    config = DummyConfig(hidden_size=8, num_attention_heads=2)
    rotary = MixtralRotaryEmbedding(config)
    x = torch.zeros(1, 2, 4, dtype=torch.float16)
    position_ids = torch.tensor([[0, 1]])
    cos, sin = rotary.forward(x, position_ids) # 113μs -> 66.5μs (71.2% faster)

# ----------- LARGE SCALE TEST CASES -----------

def test_forward_large_batch_and_seq():
    # Test with large batch and sequence size, but <100MB
    config = DummyConfig(hidden_size=32, num_attention_heads=4)
    rotary = MixtralRotaryEmbedding(config)
    batch_size = 32
    seq_len = 32
    dim = 8
    x = torch.zeros(batch_size, seq_len, dim)
    position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
    cos, sin = rotary.forward(x, position_ids) # 135μs -> 75.2μs (79.6% faster)

def test_forward_large_dim():
    # Test with large dimension (head_dim=256)
    config = DummyConfig(hidden_size=1024, num_attention_heads=4, head_dim=256)
    rotary = MixtralRotaryEmbedding(config)
    batch_size = 2
    seq_len = 4
    dim = 256
    x = torch.zeros(batch_size, seq_len, dim)
    position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
    cos, sin = rotary.forward(x, position_ids) # 119μs -> 64.7μs (84.2% faster)

def test_forward_maximum_allowed_tensor_size():
    # Test with tensor size close to 100MB
    # Each float32 element is 4 bytes, so 25_000_000 elements is 100MB
    # We'll use (batch, seq, dim) = (10, 100, 250)
    config = DummyConfig(hidden_size=1000, num_attention_heads=4, head_dim=250)
    rotary = MixtralRotaryEmbedding(config)
    batch_size = 10
    seq_len = 100
    dim = 250
    x = torch.zeros(batch_size, seq_len, dim)
    position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
    cos, sin = rotary.forward(x, position_ids) # 579μs -> 477μs (21.3% faster)

def test_forward_randomized_large_input():
    # Test with random input and position_ids for large batch/seq
    config = DummyConfig(hidden_size=64, num_attention_heads=8)
    rotary = MixtralRotaryEmbedding(config)
    batch_size = 16
    seq_len = 32
    dim = 8
    x = torch.randn(batch_size, seq_len, dim)
    position_ids = torch.randint(0, 1000, (batch_size, seq_len))
    cos, sin = rotary.forward(x, position_ids) # 129μs -> 76.3μs (70.4% faster)

# ----------- FUNCTIONALITY / MUTATION TESTS -----------

def test_forward_mutation_detection():
    # Changing the implementation should break this test!
    config = DummyConfig(hidden_size=8, num_attention_heads=2)
    rotary = MixtralRotaryEmbedding(config)
    x = torch.zeros(1, 2, 4)
    position_ids = torch.tensor([[0, 1]])
    cos, sin = rotary.forward(x, position_ids) # 112μs -> 64.4μs (75.2% faster)

def test_forward_repeatability():
    # Test that results are deterministic for same input
    config = DummyConfig(hidden_size=8, num_attention_heads=2)
    rotary = MixtralRotaryEmbedding(config)
    x = torch.zeros(1, 2, 4)
    position_ids = torch.tensor([[0, 1]])
    cos1, sin1 = rotary.forward(x, position_ids) # 110μs -> 62.3μs (76.9% faster)
    cos2, sin2 = rotary.forward(x, position_ids) # 47.9μs -> 27.6μs (73.5% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-MixtralRotaryEmbedding.forward-mhju17pq and push.

Codeflash Static Badge

The optimized code achieves a **57% speedup** by eliminating expensive tensor operations in the `forward` method of `MixtralRotaryEmbedding`. The key optimizations are:

**1. Broadcasting instead of expand + matmul:** The original code used `.expand()` to replicate tensors across dimensions, then performed matrix multiplication. The optimized version uses direct broadcasting (`position_ids[:, None, :] * inv_freq[None, :, None]`), which is more memory-efficient and computationally faster since PyTorch can optimize element-wise operations better than general matrix multiplication.

**2. Eliminated redundant type conversions:** The original code called `.float()` multiple times on the same tensors. The optimization moves all type casting to the beginning, converting `inv_freq` and `position_ids` to float32 once and reusing them.

**3. Removed unnecessary autocast context:** Since the computation is already done in float32, the `torch.autocast` wrapper adds overhead without benefit. The optimized version computes cos/sin directly and converts to the target dtype only at the end.

**4. Simplified tensor reshaping:** Instead of complex expand operations followed by transpose, the optimization uses simpler concatenation and a single transpose at the end.

The test results show consistent **70-80% speedup** across various input sizes, with the optimization being particularly effective for:
- Small to medium batch sizes (1-32 batches)
- Sequence lengths up to 256 tokens  
- Different data types (float16, float32, bfloat16)

The speedup is most pronounced in typical transformer inference scenarios where these operations are called frequently with moderate-sized tensors.
@codeflash-ai codeflash-ai bot requested a review from mashraf-222 November 4, 2025 00:30
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Nov 4, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant