Conversation

@codeflash-ai codeflash-ai bot commented Nov 2, 2025

📄 8% (0.08x) speedup for GraphormerGraphNodeFeature.forward in src/transformers/models/deprecated/graphormer/modeling_graphormer.py

⏱️ Runtime : 1.46 milliseconds -> 1.36 milliseconds (best of 45 runs)

📝 Explanation and details

The optimization replaces repeat() with expand() when creating the graph token feature tensor. This is a memory optimization that provides a 7% speedup.

Key Change:

  • Original: self.graph_token.weight.unsqueeze(0).repeat(n_graph, 1, 1)
  • Optimized: self.graph_token.weight.unsqueeze(0).expand(n_graph, -1, -1)

Why This Is Faster:

  • repeat() allocates new memory and physically copies data n_graph times, creating a fully materialized tensor
  • expand() creates a memory-efficient view that shares the underlying data, avoiding memory allocation and copying
  • Since the graph token feature is identical across all graphs in the batch, expand() is semantically equivalent but computationally cheaper, as the sketch below demonstrates
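
A minimal standalone sketch of the difference (illustrative only, not the Graphormer source; the graph_token name mirrors the discussion above):

import torch

hidden_size, n_graph = 8, 4
graph_token = torch.nn.Embedding(1, hidden_size)  # weight shape: [1, hidden_size]

# repeat() materializes n_graph physical copies of the token row
repeated = graph_token.weight.unsqueeze(0).repeat(n_graph, 1, 1)

# expand() returns a zero-copy, stride-0 view over the same storage
expanded = graph_token.weight.unsqueeze(0).expand(n_graph, -1, -1)

# Identical shape and identical values...
torch.testing.assert_close(repeated, expanded)

# ...but only the expanded view shares memory with the embedding weight
assert expanded.data_ptr() == graph_token.weight.data_ptr()
assert repeated.data_ptr() != graph_token.weight.data_ptr()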

Performance Impact:
The line profiler shows the graph token creation time improved from 389,196 ns to 224,186 ns (42% faster on that line), contributing to the overall 7% speedup. This optimization is particularly effective for:

  • Larger batch sizes (more graphs = more repeated copies saved)
  • Frequent forward passes where memory allocation overhead accumulates
  • Memory-constrained environments where avoiding unnecessary allocations helps

The test results confirm consistent 6-15% improvements across various scenarios, with the largest gains seen in cases with empty graphs (15.2% faster) and single nodes (12-14% faster), where the relative cost of tensor operations is highest.
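
For context, here is a hedged sketch of the surrounding forward pattern (simplified; the argument names are illustrative, not the exact source). The key safety point is that torch.cat reads its inputs and writes a freshly allocated output, so nothing ever writes in place through the expanded view:

import torch
from torch import nn

def forward_sketch(graph_token: nn.Embedding, node_feature: torch.Tensor) -> torch.Tensor:
    # node_feature: [n_graph, n_node, hidden_size], already summed from the
    # atom/in-degree/out-degree encoders as described above
    n_graph = node_feature.size(0)
    # Zero-copy view in place of the old .repeat(n_graph, 1, 1)
    graph_token_feature = graph_token.weight.unsqueeze(0).expand(n_graph, -1, -1)
    # torch.cat copies both inputs into new storage, so the shared view is
    # never mutated and the embedding weight stays untouched
    return torch.cat([graph_token_feature, node_feature], dim=1)

The usual expand() caveat (never write in place through the view) therefore does not apply here, which is what makes the substitution safe as well as fast.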

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 49 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
import pytest  # used for our unit tests
import torch
from transformers.models.deprecated.graphormer.modeling_graphormer import \
    GraphormerGraphNodeFeature


# function to test
class DummyConfig:
    """Minimal config for testing."""
    def __init__(
        self,
        num_attention_heads=4,
        num_atoms=10,
        hidden_size=8,
        pad_token_id=0,
        num_in_degree=6,
        num_out_degree=6,
    ):
        self.num_attention_heads = num_attention_heads
        self.num_atoms = num_atoms
        self.hidden_size = hidden_size
        self.pad_token_id = pad_token_id
        self.num_in_degree = num_in_degree
        self.num_out_degree = num_out_degree
from transformers.models.deprecated.graphormer.modeling_graphormer import \
    GraphormerGraphNodeFeature

# unit tests

# ----------- BASIC TEST CASES -----------

def test_forward_basic_single_graph_single_node():
    """Test with a single graph with a single node."""
    config = DummyConfig()
    model = GraphormerGraphNodeFeature(config)
    input_nodes = torch.LongTensor([[1]])  # shape [1,1]
    in_degree = torch.LongTensor([[2]])
    out_degree = torch.LongTensor([[3]])
    codeflash_output = model.forward(input_nodes, in_degree, out_degree); output = codeflash_output # 89.1μs -> 78.0μs (14.2% faster)
    # Check graph token is present at index 0
    torch.testing.assert_close(output[:,0,:], model.graph_token.weight)

def test_forward_basic_batch_graphs():
    """Test with a batch of graphs, each with several nodes."""
    config = DummyConfig()
    model = GraphormerGraphNodeFeature(config)
    input_nodes = torch.LongTensor([[1,2,3],[4,5,6]])  # shape [2,3]
    in_degree = torch.LongTensor([[1,2,3],[2,3,4]])
    out_degree = torch.LongTensor([[3,2,1],[4,3,2]])
    codeflash_output = model.forward(input_nodes, in_degree, out_degree); output = codeflash_output
    # Check graph token is present at index 0 for both graphs
    torch.testing.assert_close(output[0,0,:], model.graph_token.weight[0])
    torch.testing.assert_close(output[1,0,:], model.graph_token.weight[0])

def test_forward_basic_padding():
    """Test with padding indices present."""
    config = DummyConfig()
    model = GraphormerGraphNodeFeature(config)
    pad = config.pad_token_id
    input_nodes = torch.LongTensor([[pad, 1, 2]])
    in_degree = torch.LongTensor([[pad, 2, 3]])
    out_degree = torch.LongTensor([[pad, 3, 4]])
    codeflash_output = model.forward(input_nodes, in_degree, out_degree); output = codeflash_output # 94.9μs -> 84.2μs (12.7% faster)
    # Check graph token is present
    torch.testing.assert_close(output[0,0,:], model.graph_token.weight[0])
    # Check that the padding index maps to the embedding's all-zero padding vector
    node_emb = model.atom_encoder(torch.LongTensor([pad]))
    torch.testing.assert_close(node_emb, torch.zeros_like(node_emb))

# ----------- EDGE TEST CASES -----------

def test_forward_edge_empty_nodes():
    """Test with zero nodes per graph (edge case)."""
    config = DummyConfig()
    model = GraphormerGraphNodeFeature(config)
    input_nodes = torch.empty((2,0), dtype=torch.long)
    in_degree = torch.empty((2,0), dtype=torch.long)
    out_degree = torch.empty((2,0), dtype=torch.long)
    codeflash_output = model.forward(input_nodes, in_degree, out_degree); output = codeflash_output
    # Should only contain the graph token, shared across both graphs
    torch.testing.assert_close(output[:, 0, :], model.graph_token.weight.expand(2, -1))

def test_forward_edge_max_indices():
    """Test with maximum valid indices for atom/in/out_degree."""
    config = DummyConfig()
    model = GraphormerGraphNodeFeature(config)
    max_atom = config.num_atoms
    max_in = config.num_in_degree - 1
    max_out = config.num_out_degree - 1
    input_nodes = torch.LongTensor([[max_atom, max_atom]])
    in_degree = torch.LongTensor([[max_in, max_in]])
    out_degree = torch.LongTensor([[max_out, max_out]])
    codeflash_output = model.forward(input_nodes, in_degree, out_degree); output = codeflash_output # 94.1μs -> 85.4μs (10.1% faster)
    # Maximum valid indices should embed without error and keep the expected shape
    assert output.shape == (1, input_nodes.size(1) + 1, config.hidden_size)

def test_forward_edge_invalid_indices():
    """Test with out-of-range indices should raise an error."""
    config = DummyConfig()
    model = GraphormerGraphNodeFeature(config)
    input_nodes = torch.LongTensor([[config.num_atoms+1]])
    in_degree = torch.LongTensor([[config.num_in_degree]])
    out_degree = torch.LongTensor([[config.num_out_degree]])
    with pytest.raises(IndexError):
        model.forward(input_nodes, in_degree, out_degree) # 82.1μs -> 82.2μs (0.142% slower)

def test_forward_edge_negative_indices():
    """Test with negative indices should raise an error."""
    config = DummyConfig()
    model = GraphormerGraphNodeFeature(config)
    input_nodes = torch.LongTensor([[-1]])
    in_degree = torch.LongTensor([[-1]])
    out_degree = torch.LongTensor([[-1]])
    with pytest.raises(IndexError):
        model.forward(input_nodes, in_degree, out_degree) # 77.0μs -> 78.2μs (1.53% slower)

def test_forward_edge_mismatched_shapes():
    """Test with mismatched shapes should raise an error."""
    config = DummyConfig()
    model = GraphormerGraphNodeFeature(config)
    input_nodes = torch.LongTensor([[1,2],[3,4]])
    in_degree = torch.LongTensor([[1],[2]])
    out_degree = torch.LongTensor([[3,4],[5,6]])
    with pytest.raises(RuntimeError):
        model.forward(input_nodes, in_degree, out_degree)

# ----------- LARGE SCALE TEST CASES -----------

def test_forward_large_batch_and_nodes():
    """Test with a large batch and many nodes."""
    config = DummyConfig(hidden_size=16)
    model = GraphormerGraphNodeFeature(config)
    n_graph = 64
    n_node = 128
    # All indices in range
    input_nodes = torch.randint(0, config.num_atoms+1, (n_graph, n_node), dtype=torch.long)
    in_degree = torch.randint(0, config.num_in_degree, (n_graph, n_node), dtype=torch.long)
    out_degree = torch.randint(0, config.num_out_degree, (n_graph, n_node), dtype=torch.long)
    codeflash_output = model.forward(input_nodes, in_degree, out_degree); output = codeflash_output
    # Check that graph token is correct for all graphs
    for i in range(n_graph):
        torch.testing.assert_close(output[i,0,:], model.graph_token.weight[0])

def test_forward_large_hidden_size():
    """Test with a large hidden size."""
    config = DummyConfig(hidden_size=64)
    model = GraphormerGraphNodeFeature(config)
    input_nodes = torch.randint(0, config.num_atoms+1, (8,16), dtype=torch.long)
    in_degree = torch.randint(0, config.num_in_degree, (8,16), dtype=torch.long)
    out_degree = torch.randint(0, config.num_out_degree, (8,16), dtype=torch.long)
    codeflash_output = model.forward(input_nodes, in_degree, out_degree); output = codeflash_output
    # Check that graph token is correct
    torch.testing.assert_close(output[0,0,:], model.graph_token.weight[0])

def test_forward_large_all_padding():
    """Test with all padding indices in a large batch."""
    config = DummyConfig()
    model = GraphormerGraphNodeFeature(config)
    n_graph = 32
    n_node = 16
    pad = config.pad_token_id
    input_nodes = torch.full((n_graph, n_node), pad, dtype=torch.long)
    in_degree = torch.full((n_graph, n_node), pad, dtype=torch.long)
    out_degree = torch.full((n_graph, n_node), pad, dtype=torch.long)
    codeflash_output = model.forward(input_nodes, in_degree, out_degree); output = codeflash_output
    # Graph token should be present for every graph in the batch
    torch.testing.assert_close(output[:, 0, :], model.graph_token.weight.expand(n_graph, -1))

def test_forward_large_randomized():
    """Test with randomized inputs and check determinism."""
    config = DummyConfig()
    model = GraphormerGraphNodeFeature(config)
    torch.manual_seed(42)
    input_nodes = torch.randint(0, config.num_atoms+1, (10,10), dtype=torch.long)
    in_degree = torch.randint(0, config.num_in_degree, (10,10), dtype=torch.long)
    out_degree = torch.randint(0, config.num_out_degree, (10,10), dtype=torch.long)
    codeflash_output = model.forward(input_nodes, in_degree, out_degree); output1 = codeflash_output # 101μs -> 95.5μs (6.07% faster)
    codeflash_output = model.forward(input_nodes, in_degree, out_degree); output2 = codeflash_output # 44.2μs -> 41.4μs (6.69% faster)
    # Outputs should be the same for same inputs
    torch.testing.assert_close(output1, output2)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import pytest  # used for our unit tests
import torch
from transformers.models.deprecated.graphormer.modeling_graphormer import \
    GraphormerGraphNodeFeature


# function to test
class DummyConfig:
    # Minimal config for testing purposes
    def __init__(
        self,
        num_attention_heads=2,
        num_atoms=10,
        num_in_degree=5,
        num_out_degree=5,
        hidden_size=8,
        pad_token_id=0,
    ):
        self.num_attention_heads = num_attention_heads
        self.num_atoms = num_atoms
        self.num_in_degree = num_in_degree
        self.num_out_degree = num_out_degree
        self.hidden_size = hidden_size
        self.pad_token_id = pad_token_id
from transformers.models.deprecated.graphormer.modeling_graphormer import \
    GraphormerGraphNodeFeature

# unit tests

# --------- Basic Test Cases ---------
def test_forward_basic_single_graph_single_node():
    """
    Test with a single graph containing a single node.
    Checks output shape and basic values.
    """
    config = DummyConfig()
    model = GraphormerGraphNodeFeature(config)
    input_nodes = torch.LongTensor([[1]])  # shape: (1, 1)
    in_degree = torch.LongTensor([[2]])
    out_degree = torch.LongTensor([[3]])
    codeflash_output = model.forward(input_nodes, in_degree, out_degree); out = codeflash_output # 88.8μs -> 78.7μs (12.8% faster)
    # Check that the node feature is present at position 1
    expected_node_feature = (
        model.atom_encoder(input_nodes)[0,0]
        + model.in_degree_encoder(in_degree)[0,0]
        + model.out_degree_encoder(out_degree)[0,0]
    )
    torch.testing.assert_close(out[0, 1], expected_node_feature)

def test_forward_basic_multiple_graphs_nodes():
    """
    Test with multiple graphs and multiple nodes.
    Checks output shape and graph token position.
    """
    config = DummyConfig()
    model = GraphormerGraphNodeFeature(config)
    input_nodes = torch.LongTensor([[1,2],[3,4]])  # shape: (2, 2)
    in_degree = torch.LongTensor([[0,1],[2,3]])
    out_degree = torch.LongTensor([[1,2],[3,4]])
    codeflash_output = model.forward(input_nodes, in_degree, out_degree); out = codeflash_output # 89.4μs -> 81.8μs (9.33% faster)
    assert out.shape == (2, input_nodes.size(1) + 1, config.hidden_size)
    torch.testing.assert_close(out[:, 0, :], model.graph_token.weight.expand(2, -1))

# --------- Edge Test Cases ---------
def test_forward_with_padding():
    """
    Test with padded input (pad_token_id).
    Checks that padding does not cause errors and is handled by embedding layers.
    """
    config = DummyConfig()
    model = GraphormerGraphNodeFeature(config)
    pad = config.pad_token_id
    input_nodes = torch.LongTensor([[pad, 1], [2, pad]])
    in_degree = torch.LongTensor([[pad, 1], [2, pad]])
    out_degree = torch.LongTensor([[pad, 1], [2, pad]])
    codeflash_output = model.forward(input_nodes, in_degree, out_degree); out = codeflash_output # 89.0μs -> 83.5μs (6.57% faster)
    # Padding positions should correspond to the embedding's all-zero padding vector
    pad_emb = model.atom_encoder(torch.LongTensor([pad]))[0]
    torch.testing.assert_close(pad_emb, torch.zeros_like(pad_emb))

def test_forward_empty_graph():
    """
    Test with zero nodes (empty graph).
    Should return only the graph token feature.
    """
    config = DummyConfig()
    model = GraphormerGraphNodeFeature(config)
    input_nodes = torch.LongTensor([[]]).reshape(1,0)  # shape: (1, 0)
    in_degree = torch.LongTensor([[]]).reshape(1,0)
    out_degree = torch.LongTensor([[]]).reshape(1,0)
    codeflash_output = model.forward(input_nodes, in_degree, out_degree); out = codeflash_output # 77.2μs -> 67.0μs (15.2% faster)
    # Output should contain only the graph token
    assert out.shape == (1, 1, config.hidden_size)

def test_forward_max_atom_index():
    """
    Test with the highest possible atom index.
    """
    config = DummyConfig()
    model = GraphormerGraphNodeFeature(config)
    max_atom = config.num_atoms
    input_nodes = torch.LongTensor([[max_atom]])
    in_degree = torch.LongTensor([[1]])
    out_degree = torch.LongTensor([[2]])
    codeflash_output = model.forward(input_nodes, in_degree, out_degree); out = codeflash_output # 86.5μs -> 78.0μs (10.9% faster)
    # Should not raise IndexError
    expected_node_feature = (
        model.atom_encoder(input_nodes)[0,0]
        + model.in_degree_encoder(in_degree)[0,0]
        + model.out_degree_encoder(out_degree)[0,0]
    )
    torch.testing.assert_close(out[0, 1], expected_node_feature)

def test_forward_invalid_atom_index_raises():
    """
    Test with an atom index out of bounds.
    Should raise an error from torch.nn.Embedding.
    """
    config = DummyConfig()
    model = GraphormerGraphNodeFeature(config)
    invalid_atom = config.num_atoms + 1
    input_nodes = torch.LongTensor([[invalid_atom]])
    in_degree = torch.LongTensor([[1]])
    out_degree = torch.LongTensor([[2]])
    with pytest.raises(IndexError):
        model.forward(input_nodes, in_degree, out_degree) # 81.4μs -> 82.6μs (1.40% slower)

def test_forward_invalid_degree_index_raises():
    """
    Test with in_degree and out_degree indices out of bounds.
    Should raise IndexError.
    """
    config = DummyConfig()
    model = GraphormerGraphNodeFeature(config)
    input_nodes = torch.LongTensor([[1]])
    invalid_in_degree = config.num_in_degree
    invalid_out_degree = config.num_out_degree
    in_degree = torch.LongTensor([[invalid_in_degree]])
    out_degree = torch.LongTensor([[invalid_out_degree]])
    with pytest.raises(IndexError):
        model.forward(input_nodes, in_degree, out_degree) # 100μs -> 100μs (0.853% faster)

def test_forward_mismatched_shapes_raises():
    """
    Test with mismatched shapes for input_nodes, in_degree, out_degree.
    Should raise a RuntimeError due to broadcasting issues.
    """
    config = DummyConfig()
    model = GraphormerGraphNodeFeature(config)
    input_nodes = torch.LongTensor([[1,2]])
    in_degree = torch.LongTensor([[1]])
    out_degree = torch.LongTensor([[2,3]])
    with pytest.raises(RuntimeError):
        model.forward(input_nodes, in_degree, out_degree)

# --------- Large Scale Test Cases ---------
def test_forward_large_graphs_and_nodes():
    """
    Test with the largest allowed graph and node sizes.
    Ensures no memory issues and correct output shape.
    """
    config = DummyConfig(num_atoms=20, num_in_degree=10, num_out_degree=10, hidden_size=16)
    model = GraphormerGraphNodeFeature(config)
    n_graph = 100
    n_node = 900  # total elements: 100*900*16*4 bytes = ~5.7MB
    input_nodes = torch.randint(0, config.num_atoms+1, (n_graph, n_node), dtype=torch.long)
    in_degree = torch.randint(0, config.num_in_degree, (n_graph, n_node), dtype=torch.long)
    out_degree = torch.randint(0, config.num_out_degree, (n_graph, n_node), dtype=torch.long)
    codeflash_output = model.forward(input_nodes, in_degree, out_degree); out = codeflash_output
    # Spot check: graph token feature for a few graphs
    for idx in [0, n_graph // 2, n_graph - 1]:
        torch.testing.assert_close(out[idx, 0, :], model.graph_token.weight[0])

def test_forward_large_hidden_size():
    """
    Test with a large hidden size.
    """
    config = DummyConfig(hidden_size=128)
    model = GraphormerGraphNodeFeature(config)
    input_nodes = torch.LongTensor([[1,2,3]])
    in_degree = torch.LongTensor([[1,2,3]])
    out_degree = torch.LongTensor([[1,2,3]])
    codeflash_output = model.forward(input_nodes, in_degree, out_degree); out = codeflash_output # 97.9μs -> 87.0μs (12.6% faster)
    assert out.shape == (1, input_nodes.size(1) + 1, config.hidden_size)

def test_forward_batch_consistency():
    """
    Test that batching multiple graphs produces consistent outputs compared to processing individually.
    """
    config = DummyConfig()
    model = GraphormerGraphNodeFeature(config)
    # Prepare two graphs
    input_nodes = torch.LongTensor([[1,2],[3,4]])
    in_degree = torch.LongTensor([[1,2],[3,4]])
    out_degree = torch.LongTensor([[1,2],[3,4]])
    codeflash_output = model.forward(input_nodes, in_degree, out_degree); out_batch = codeflash_output # 91.4μs -> 85.9μs (6.30% faster)
    # Process graphs individually
    codeflash_output = model.forward(input_nodes[0:1], in_degree[0:1], out_degree[0:1]); out_g0 = codeflash_output # 41.6μs -> 37.0μs (12.4% faster)
    codeflash_output = model.forward(input_nodes[1:2], in_degree[1:2], out_degree[1:2]); out_g1 = codeflash_output # 34.6μs -> 31.6μs (9.30% faster)
    # The shared graph token should be identical across batched and individual runs
    torch.testing.assert_close(out_batch[0, 0, :], out_g0[0, 0, :])
    torch.testing.assert_close(out_batch[1, 0, :], out_g1[0, 0, :])
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run git checkout codeflash/optimize-GraphormerGraphNodeFeature.forward-mhha9fyj and push.

@codeflash-ai codeflash-ai bot requested a review from mashraf-222 November 2, 2025 05:41
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash (Optimization PR opened by Codeflash AI) and 🎯 Quality: High (Optimization Quality according to Codeflash) labels Nov 2, 2025