Skip to content

Conversation

@codeflash-ai
Copy link

@codeflash-ai codeflash-ai bot commented Nov 2, 2025

📄 5% (0.05x) speedup for GraphormerGraphAttnBias.forward in src/transformers/models/deprecated/graphormer/modeling_graphormer.py

⏱️ Runtime : 10.4 milliseconds 9.84 milliseconds (best of 60 runs)

📝 Explanation and details

The optimized code achieves a 5% speedup through several key memory and computation optimizations:

Primary Optimization - Memory-Efficient Tensor Creation:

  • Replaced attn_bias.clone().unsqueeze(1).repeat(1, self.num_heads, 1, 1) with attn_bias.unsqueeze(1).expand(-1, self.num_heads, -1, -1).clone()
  • This avoids creating large intermediate tensors during the repeat operation, using expand's lazy view semantics instead

In-Place Operations for Large Tensors:

  • Used masked_fill_ instead of tensor indexing assignment for setting padding values
  • Applied clamp_ for in-place clamping of spatial positions
  • Used div_ for in-place division in the multi-hop path
  • Replaced addition assignments with += operators throughout

Memory Layout Optimizations:

  • Added .contiguous() calls strategically to ensure optimal memory stride patterns for subsequent operations
  • This is particularly beneficial for the multi-hop edge processing where tensors undergo multiple reshaping operations

Fused Operations:

  • Combined sum and division operations using torch.sum().div_() instead of separate operations
  • Used direct += for graph attention bias updates instead of creating intermediate tensors

These optimizations are most effective for larger graphs and batch sizes (as shown in test cases with 50+ nodes achieving 10%+ speedups), where memory allocation overhead and tensor copying costs become more significant. The improvements target the most computationally expensive parts: spatial position encoding, multi-hop edge processing, and attention bias construction.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 69 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
import math

# imports
import pytest
import torch
from transformers.models.deprecated.graphormer.modeling_graphormer import \
    GraphormerGraphAttnBias


# function to test
class DummyConfig:
    def __init__(
        self,
        num_attention_heads=4,
        multi_hop_max_dist=3,
        num_edges=5,
        edge_type="multi_hop",
        num_edge_dis=3,
        num_spatial=6,
    ):
        self.num_attention_heads = num_attention_heads
        self.multi_hop_max_dist = multi_hop_max_dist
        self.num_edges = num_edges
        self.edge_type = edge_type
        self.num_edge_dis = num_edge_dis
        self.num_spatial = num_spatial
from transformers.models.deprecated.graphormer.modeling_graphormer import \
    GraphormerGraphAttnBias

# unit tests

# Helper to create dummy inputs for the model
def generate_multi_hop_inputs(
    n_graph=1,
    n_node=3,
    num_heads=4,
    multi_hop_max_dist=3,
    num_edges=5,
    num_spatial=6
):
    # input_nodes: [n_graph, n_node]
    input_nodes = torch.randint(0, 10, (n_graph, n_node))
    # attn_bias: [n_graph, n_node+1, n_node+1]
    attn_bias = torch.randn(n_graph, n_node + 1, n_node + 1)
    # spatial_pos: [n_graph, n_node, n_node]
    spatial_pos = torch.randint(0, num_spatial, (n_graph, n_node, n_node))
    # input_edges: [n_graph, n_node, n_node, multi_hop_max_dist, num_heads]
    input_edges = torch.randint(0, num_edges + 1, (n_graph, n_node, n_node, multi_hop_max_dist, num_heads))
    # attn_edge_type: [n_graph, n_node, n_node, 1]
    attn_edge_type = torch.randint(0, num_edges + 1, (n_graph, n_node, n_node, 1))
    return input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type

def generate_single_hop_inputs(
    n_graph=1,
    n_node=3,
    num_heads=4,
    num_edges=5,
    num_spatial=6
):
    # input_nodes: [n_graph, n_node]
    input_nodes = torch.randint(0, 10, (n_graph, n_node))
    # attn_bias: [n_graph, n_node+1, n_node+1]
    attn_bias = torch.randn(n_graph, n_node + 1, n_node + 1)
    # spatial_pos: [n_graph, n_node, n_node]
    spatial_pos = torch.randint(0, num_spatial, (n_graph, n_node, n_node))
    # input_edges: Not used for single-hop
    input_edges = torch.randint(0, num_edges + 1, (n_graph, n_node, n_node, 1, num_heads))
    # attn_edge_type: [n_graph, n_node, n_node, 1]
    attn_edge_type = torch.randint(0, num_edges + 1, (n_graph, n_node, n_node, 1))
    return input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type

# -------- BASIC TEST CASES --------
def test_forward_basic_multi_hop_shape_and_type():
    """Basic: Check output shape and dtype for multi_hop edge_type."""
    config = DummyConfig(edge_type="multi_hop")
    model = GraphormerGraphAttnBias(config)
    input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type = generate_multi_hop_inputs()
    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); out = codeflash_output # 256μs -> 246μs (3.84% faster)

def test_forward_basic_single_hop_shape_and_type():
    """Basic: Check output shape and dtype for single-hop edge_type."""
    config = DummyConfig(edge_type="single")
    model = GraphormerGraphAttnBias(config)
    input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type = generate_single_hop_inputs()
    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); out = codeflash_output # 162μs -> 155μs (4.47% faster)

def test_forward_basic_multi_hop_deterministic():
    """Basic: Check deterministic output for the same input."""
    config = DummyConfig(edge_type="multi_hop")
    torch.manual_seed(42)
    model = GraphormerGraphAttnBias(config)
    input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type = generate_multi_hop_inputs()
    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); out1 = codeflash_output # 253μs -> 243μs (4.17% faster)
    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); out2 = codeflash_output # 141μs -> 135μs (4.67% faster)

def test_forward_basic_grad():
    """Basic: Check that output is differentiable wrt attn_bias."""
    config = DummyConfig(edge_type="multi_hop")
    model = GraphormerGraphAttnBias(config)
    input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type = generate_multi_hop_inputs()
    attn_bias.requires_grad_()
    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); out = codeflash_output # 258μs -> 247μs (4.30% faster)
    s = out.sum()
    s.backward()

def test_forward_basic_virtual_distance_effect():
    """Basic: Changing graph_token_virtual_distance changes output."""
    config = DummyConfig(edge_type="multi_hop")
    model = GraphormerGraphAttnBias(config)
    input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type = generate_multi_hop_inputs()
    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); out1 = codeflash_output # 255μs -> 246μs (3.45% faster)
    with torch.no_grad():
        model.graph_token_virtual_distance.weight.add_(1.0)
    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); out2 = codeflash_output # 144μs -> 138μs (4.85% faster)

# -------- EDGE TEST CASES --------
def test_forward_edge_zero_nodes():
    """Edge: n_node = 0 (empty graph). Should handle gracefully."""
    config = DummyConfig(edge_type="multi_hop")
    model = GraphormerGraphAttnBias(config)
    # input_nodes: [1, 0], attn_bias: [1, 1, 1], spatial_pos: [1, 0, 0], input_edges: [1, 0, 0, 3, 4]
    input_nodes = torch.empty(1, 0, dtype=torch.long)
    attn_bias = torch.randn(1, 1, 1)
    spatial_pos = torch.empty(1, 0, 0, dtype=torch.long)
    input_edges = torch.empty(1, 0, 0, config.multi_hop_max_dist, config.num_attention_heads, dtype=torch.long)
    attn_edge_type = torch.empty(1, 0, 0, 1, dtype=torch.long)
    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); out = codeflash_output # 226μs -> 221μs (2.35% faster)

def test_forward_edge_all_zero_spatial_pos():
    """Edge: All spatial_pos are zero (padding). Should not error."""
    config = DummyConfig(edge_type="multi_hop")
    model = GraphormerGraphAttnBias(config)
    n_graph, n_node = 1, 4
    input_nodes = torch.randint(0, 10, (n_graph, n_node))
    attn_bias = torch.randn(n_graph, n_node + 1, n_node + 1)
    spatial_pos = torch.zeros(n_graph, n_node, n_node, dtype=torch.long)
    input_edges = torch.randint(0, config.num_edges + 1, (n_graph, n_node, n_node, config.multi_hop_max_dist, config.num_attention_heads))
    attn_edge_type = torch.randint(0, config.num_edges + 1, (n_graph, n_node, n_node, 1))
    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); out = codeflash_output # 251μs -> 244μs (2.77% faster)

def test_forward_edge_maximum_spatial_pos():
    """Edge: All spatial_pos are max value."""
    config = DummyConfig(edge_type="multi_hop", num_spatial=6)
    model = GraphormerGraphAttnBias(config)
    n_graph, n_node = 1, 4
    input_nodes = torch.randint(0, 10, (n_graph, n_node))
    attn_bias = torch.randn(n_graph, n_node + 1, n_node + 1)
    spatial_pos = torch.full((n_graph, n_node, n_node), config.num_spatial - 1, dtype=torch.long)
    input_edges = torch.randint(0, config.num_edges + 1, (n_graph, n_node, n_node, config.multi_hop_max_dist, config.num_attention_heads))
    attn_edge_type = torch.randint(0, config.num_edges + 1, (n_graph, n_node, n_node, 1))
    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); out = codeflash_output # 250μs -> 241μs (3.59% faster)

def test_forward_edge_minimum_and_maximum_edges():
    """Edge: input_edges with min (0) and max (num_edges) values."""
    config = DummyConfig(edge_type="multi_hop")
    model = GraphormerGraphAttnBias(config)
    n_graph, n_node = 1, 3
    input_nodes = torch.randint(0, 10, (n_graph, n_node))
    attn_bias = torch.randn(n_graph, n_node + 1, n_node + 1)
    spatial_pos = torch.randint(0, config.num_spatial, (n_graph, n_node, n_node))
    # All zeros (padding)
    input_edges_zero = torch.zeros(n_graph, n_node, n_node, config.multi_hop_max_dist, config.num_attention_heads, dtype=torch.long)
    # All max
    input_edges_max = torch.full((n_graph, n_node, n_node, config.multi_hop_max_dist, config.num_attention_heads), config.num_edges, dtype=torch.long)
    attn_edge_type = torch.randint(0, config.num_edges + 1, (n_graph, n_node, n_node, 1))
    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges_zero, attn_edge_type); out_zero = codeflash_output # 247μs -> 240μs (2.74% faster)
    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges_max, attn_edge_type); out_max = codeflash_output # 140μs -> 135μs (3.63% faster)

def test_forward_edge_multi_hop_max_dist_zero():
    """Edge: multi_hop_max_dist=0 disables multi-hop, should not error."""
    config = DummyConfig(edge_type="multi_hop", multi_hop_max_dist=0)
    model = GraphormerGraphAttnBias(config)
    n_graph, n_node = 1, 3
    input_nodes = torch.randint(0, 10, (n_graph, n_node))
    attn_bias = torch.randn(n_graph, n_node + 1, n_node + 1)
    spatial_pos = torch.randint(0, config.num_spatial, (n_graph, n_node, n_node))
    input_edges = torch.randint(0, config.num_edges + 1, (n_graph, n_node, n_node, 1, config.num_attention_heads))
    attn_edge_type = torch.randint(0, config.num_edges + 1, (n_graph, n_node, n_node, 1))
    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); out = codeflash_output # 239μs -> 229μs (4.44% faster)

def test_forward_edge_single_hop_vs_multi_hop():
    """Edge: single-hop and multi-hop with same inputs yield different outputs."""
    config_multi = DummyConfig(edge_type="multi_hop")
    config_single = DummyConfig(edge_type="single")
    model_multi = GraphormerGraphAttnBias(config_multi)
    model_single = GraphormerGraphAttnBias(config_single)
    n_graph, n_node = 1, 3
    input_nodes = torch.randint(0, 10, (n_graph, n_node))
    attn_bias = torch.randn(n_graph, n_node + 1, n_node + 1)
    spatial_pos = torch.randint(0, config_multi.num_spatial, (n_graph, n_node, n_node))
    input_edges = torch.randint(0, config_multi.num_edges + 1, (n_graph, n_node, n_node, config_multi.multi_hop_max_dist, config_multi.num_attention_heads))
    attn_edge_type = torch.randint(0, config_multi.num_edges + 1, (n_graph, n_node, n_node, 1))
    codeflash_output = model_multi.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); out_multi = codeflash_output # 250μs -> 243μs (2.86% faster)
    codeflash_output = model_single.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); out_single = codeflash_output # 91.5μs -> 86.5μs (5.69% faster)

def test_forward_edge_different_num_heads():
    """Edge: Changing num_attention_heads changes output shape."""
    config4 = DummyConfig(edge_type="multi_hop", num_attention_heads=4)
    config2 = DummyConfig(edge_type="multi_hop", num_attention_heads=2)
    model4 = GraphormerGraphAttnBias(config4)
    model2 = GraphormerGraphAttnBias(config2)
    input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type = generate_multi_hop_inputs(num_heads=4)
    codeflash_output = model4.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); out4 = codeflash_output # 250μs -> 244μs (2.74% faster)
    input_nodes2, attn_bias2, spatial_pos2, input_edges2, attn_edge_type2 = generate_multi_hop_inputs(num_heads=2)
    codeflash_output = model2.forward(input_nodes2, attn_bias2, spatial_pos2, input_edges2, attn_edge_type2); out2 = codeflash_output # 144μs -> 141μs (2.58% faster)

# -------- LARGE SCALE TEST CASES --------
@pytest.mark.parametrize("n_graph,n_node,num_heads", [
    (8, 16, 8),
    (2, 64, 4),
    (1, 128, 2),
])
def test_forward_large_scale_multi_hop(n_graph, n_node, num_heads):
    """Large: Test with large graphs and heads, but <100MB."""
    # Estimate memory: n_graph * num_heads * (n_node+1)^2 * 4 bytes < 100MB
    config = DummyConfig(edge_type="multi_hop", num_attention_heads=num_heads, multi_hop_max_dist=3)
    model = GraphormerGraphAttnBias(config)
    input_nodes = torch.randint(0, 10, (n_graph, n_node))
    attn_bias = torch.randn(n_graph, n_node + 1, n_node + 1)
    spatial_pos = torch.randint(0, config.num_spatial, (n_graph, n_node, n_node))
    input_edges = torch.randint(0, config.num_edges + 1, (n_graph, n_node, n_node, config.multi_hop_max_dist, num_heads))
    attn_edge_type = torch.randint(0, config.num_edges + 1, (n_graph, n_node, n_node, 1))
    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); out = codeflash_output # 3.57ms -> 3.36ms (6.26% faster)

def test_forward_large_scale_single_hop():
    """Large: Test with single-hop edge_type and large input."""
    n_graph, n_node, num_heads = 4, 32, 6
    config = DummyConfig(edge_type="single", num_attention_heads=num_heads)
    model = GraphormerGraphAttnBias(config)
    input_nodes = torch.randint(0, 10, (n_graph, n_node))
    attn_bias = torch.randn(n_graph, n_node + 1, n_node + 1)
    spatial_pos = torch.randint(0, config.num_spatial, (n_graph, n_node, n_node))
    input_edges = torch.randint(0, config.num_edges + 1, (n_graph, n_node, n_node, 1, num_heads))
    attn_edge_type = torch.randint(0, config.num_edges + 1, (n_graph, n_node, n_node, 1))
    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); out = codeflash_output # 253μs -> 233μs (8.85% faster)

def test_forward_large_scale_batch_consistency():
    """Large: Batched inputs produce consistent per-graph outputs."""
    config = DummyConfig(edge_type="multi_hop", num_attention_heads=3)
    model = GraphormerGraphAttnBias(config)
    # Create two graphs
    input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type = generate_multi_hop_inputs(n_graph=2, n_node=10, num_heads=3)
    # Run separately
    codeflash_output = model.forward(input_nodes[0:1], attn_bias[0:1], spatial_pos[0:1], input_edges[0:1], attn_edge_type[0:1]); out0 = codeflash_output # 269μs -> 260μs (3.72% faster)
    codeflash_output = model.forward(input_nodes[1:2], attn_bias[1:2], spatial_pos[1:2], input_edges[1:2], attn_edge_type[1:2]); out1 = codeflash_output # 160μs -> 154μs (3.73% faster)
    # Run batched
    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); out_batched = codeflash_output # 158μs -> 151μs (4.89% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import pytest  # used for our unit tests
import torch
# function to test (from the provided code)
import torch.nn as nn
from transformers.models.deprecated.graphormer.modeling_graphormer import \
    GraphormerGraphAttnBias


class DummyConfig:
    # Minimal config class for tests
    def __init__(
        self,
        num_attention_heads=2,
        multi_hop_max_dist=0,
        num_edges=4,
        edge_type="default",
        num_edge_dis=2,
        num_spatial=4,
    ):
        self.num_attention_heads = num_attention_heads
        self.multi_hop_max_dist = multi_hop_max_dist
        self.num_edges = num_edges
        self.edge_type = edge_type
        self.num_edge_dis = num_edge_dis
        self.num_spatial = num_spatial
from transformers.models.deprecated.graphormer.modeling_graphormer import \
    GraphormerGraphAttnBias

# --------------------------
# Unit Tests for forward()
# --------------------------

# 1. Basic Test Cases

def test_forward_basic_shape_and_type():
    """
    Basic test: verifies output shape and type for default edge_type.
    """
    config = DummyConfig()
    model = GraphormerGraphAttnBias(config)
    n_graph, n_node = 2, 3
    input_nodes = torch.ones((n_graph, n_node), dtype=torch.long)
    attn_bias = torch.zeros((n_graph, n_node+1, n_node+1))
    spatial_pos = torch.zeros((n_graph, n_node, n_node), dtype=torch.long)
    input_edges = torch.zeros((n_graph, n_node, n_node, 1, config.num_attention_heads), dtype=torch.long)
    attn_edge_type = torch.zeros((n_graph, n_node, n_node, 1), dtype=torch.long)

    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); output = codeflash_output # 165μs -> 159μs (4.17% faster)

def test_forward_basic_nonzero_input():
    """
    Basic test: verifies that non-zero attn_bias and spatial_pos affect output.
    """
    config = DummyConfig()
    model = GraphormerGraphAttnBias(config)
    n_graph, n_node = 1, 2
    input_nodes = torch.ones((n_graph, n_node), dtype=torch.long)
    attn_bias = torch.ones((n_graph, n_node+1, n_node+1))
    spatial_pos = torch.ones((n_graph, n_node, n_node), dtype=torch.long)
    input_edges = torch.ones((n_graph, n_node, n_node, 1, config.num_attention_heads), dtype=torch.long)
    attn_edge_type = torch.ones((n_graph, n_node, n_node, 1), dtype=torch.long)

    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); output = codeflash_output # 163μs -> 155μs (5.28% faster)

def test_forward_basic_edge_type_multi_hop():
    """
    Basic test: verifies output shape and type for multi_hop edge_type.
    """
    config = DummyConfig(edge_type="multi_hop", multi_hop_max_dist=2)
    model = GraphormerGraphAttnBias(config)
    n_graph, n_node = 2, 3
    input_nodes = torch.ones((n_graph, n_node), dtype=torch.long)
    attn_bias = torch.zeros((n_graph, n_node+1, n_node+1))
    spatial_pos = torch.ones((n_graph, n_node, n_node), dtype=torch.long)
    input_edges = torch.ones((n_graph, n_node, n_node, config.multi_hop_max_dist, config.num_attention_heads), dtype=torch.long)
    attn_edge_type = torch.ones((n_graph, n_node, n_node, 1), dtype=torch.long)

    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); output = codeflash_output # 250μs -> 243μs (3.04% faster)

# 2. Edge Test Cases

def test_forward_edge_zero_nodes():
    """
    Edge case: n_node=0 (empty graph).
    """
    config = DummyConfig()
    model = GraphormerGraphAttnBias(config)
    n_graph, n_node = 1, 0
    input_nodes = torch.ones((n_graph, n_node), dtype=torch.long)
    attn_bias = torch.zeros((n_graph, n_node+1, n_node+1))
    spatial_pos = torch.zeros((n_graph, n_node, n_node), dtype=torch.long)
    input_edges = torch.zeros((n_graph, n_node, n_node, 1, config.num_attention_heads), dtype=torch.long)
    attn_edge_type = torch.zeros((n_graph, n_node, n_node, 1), dtype=torch.long)

    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); output = codeflash_output # 151μs -> 147μs (3.01% faster)

def test_forward_edge_max_spatial_pos():
    """
    Edge case: spatial_pos at maximum value.
    """
    config = DummyConfig(num_spatial=4)
    model = GraphormerGraphAttnBias(config)
    n_graph, n_node = 1, 2
    input_nodes = torch.ones((n_graph, n_node), dtype=torch.long)
    attn_bias = torch.zeros((n_graph, n_node+1, n_node+1))
    spatial_pos = torch.full((n_graph, n_node, n_node), config.num_spatial-1, dtype=torch.long)
    input_edges = torch.zeros((n_graph, n_node, n_node, 1, config.num_attention_heads), dtype=torch.long)
    attn_edge_type = torch.zeros((n_graph, n_node, n_node, 1), dtype=torch.long)

    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); output = codeflash_output # 165μs -> 158μs (4.63% faster)

def test_forward_edge_invalid_spatial_pos():
    """
    Edge case: spatial_pos contains out-of-range index (should not crash, but embedding will use padding_idx).
    """
    config = DummyConfig(num_spatial=4)
    model = GraphormerGraphAttnBias(config)
    n_graph, n_node = 1, 2
    input_nodes = torch.ones((n_graph, n_node), dtype=torch.long)
    attn_bias = torch.zeros((n_graph, n_node+1, n_node+1))
    spatial_pos = torch.full((n_graph, n_node, n_node), config.num_spatial+5, dtype=torch.long)  # out of range
    input_edges = torch.zeros((n_graph, n_node, n_node, 1, config.num_attention_heads), dtype=torch.long)
    attn_edge_type = torch.zeros((n_graph, n_node, n_node, 1), dtype=torch.long)

    # Should not raise, but output may be all zeros due to padding
    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); output = codeflash_output

def test_forward_edge_multi_hop_max_dist_zero():
    """
    Edge case: multi_hop edge_type with multi_hop_max_dist=0.
    """
    config = DummyConfig(edge_type="multi_hop", multi_hop_max_dist=0)
    model = GraphormerGraphAttnBias(config)
    n_graph, n_node = 1, 2
    input_nodes = torch.ones((n_graph, n_node), dtype=torch.long)
    attn_bias = torch.zeros((n_graph, n_node+1, n_node+1))
    spatial_pos = torch.ones((n_graph, n_node, n_node), dtype=torch.long)
    # input_edges shape: [n_graph, n_node, n_node, 1, num_heads]
    input_edges = torch.ones((n_graph, n_node, n_node, 1, config.num_attention_heads), dtype=torch.long)
    attn_edge_type = torch.ones((n_graph, n_node, n_node, 1), dtype=torch.long)

    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); output = codeflash_output # 245μs -> 238μs (2.82% faster)

def test_forward_edge_all_padding():
    """
    Edge case: all input_nodes are padding (zero).
    """
    config = DummyConfig()
    model = GraphormerGraphAttnBias(config)
    n_graph, n_node = 1, 3
    input_nodes = torch.zeros((n_graph, n_node), dtype=torch.long)
    attn_bias = torch.zeros((n_graph, n_node+1, n_node+1))
    spatial_pos = torch.zeros((n_graph, n_node, n_node), dtype=torch.long)
    input_edges = torch.zeros((n_graph, n_node, n_node, 1, config.num_attention_heads), dtype=torch.long)
    attn_edge_type = torch.zeros((n_graph, n_node, n_node, 1), dtype=torch.long)

    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); output = codeflash_output # 163μs -> 157μs (3.58% faster)

def test_forward_edge_large_num_heads():
    """
    Edge case: large number of attention heads.
    """
    config = DummyConfig(num_attention_heads=32)
    model = GraphormerGraphAttnBias(config)
    n_graph, n_node = 1, 2
    input_nodes = torch.ones((n_graph, n_node), dtype=torch.long)
    attn_bias = torch.zeros((n_graph, n_node+1, n_node+1))
    spatial_pos = torch.ones((n_graph, n_node, n_node), dtype=torch.long)
    input_edges = torch.ones((n_graph, n_node, n_node, 1, config.num_attention_heads), dtype=torch.long)
    attn_edge_type = torch.ones((n_graph, n_node, n_node, 1), dtype=torch.long)

    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); output = codeflash_output # 165μs -> 158μs (4.62% faster)

# 3. Large Scale Test Cases

def test_forward_large_graph():
    """
    Large scale: test with a graph of 50 nodes, 2 graphs, 8 heads.
    """
    config = DummyConfig(num_attention_heads=8)
    model = GraphormerGraphAttnBias(config)
    n_graph, n_node = 2, 50
    input_nodes = torch.ones((n_graph, n_node), dtype=torch.long)
    attn_bias = torch.zeros((n_graph, n_node+1, n_node+1))
    spatial_pos = torch.randint(0, config.num_spatial, (n_graph, n_node, n_node), dtype=torch.long)
    input_edges = torch.randint(0, config.num_edges+1, (n_graph, n_node, n_node, 1, config.num_attention_heads), dtype=torch.long)
    attn_edge_type = torch.randint(0, config.num_edges+1, (n_graph, n_node, n_node, 1), dtype=torch.long)

    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); output = codeflash_output # 296μs -> 269μs (10.0% faster)

def test_forward_large_multi_hop():
    """
    Large scale: test with multi_hop edge_type, 20 nodes, 3 graphs, 4 heads, max_dist=3.
    """
    config = DummyConfig(edge_type="multi_hop", multi_hop_max_dist=3, num_attention_heads=4)
    model = GraphormerGraphAttnBias(config)
    n_graph, n_node = 3, 20
    input_nodes = torch.ones((n_graph, n_node), dtype=torch.long)
    attn_bias = torch.zeros((n_graph, n_node+1, n_node+1))
    spatial_pos = torch.randint(0, config.num_spatial, (n_graph, n_node, n_node), dtype=torch.long)
    input_edges = torch.randint(0, config.num_edges+1, (n_graph, n_node, n_node, config.multi_hop_max_dist, config.num_attention_heads), dtype=torch.long)
    attn_edge_type = torch.randint(0, config.num_edges+1, (n_graph, n_node, n_node, 1), dtype=torch.long)

    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); output = codeflash_output

def test_forward_large_batch():
    """
    Large scale: test with batch size of 16 graphs, each with 10 nodes and 6 heads.
    """
    config = DummyConfig(num_attention_heads=6)
    model = GraphormerGraphAttnBias(config)
    n_graph, n_node = 16, 10
    input_nodes = torch.ones((n_graph, n_node), dtype=torch.long)
    attn_bias = torch.zeros((n_graph, n_node+1, n_node+1))
    spatial_pos = torch.randint(0, config.num_spatial, (n_graph, n_node, n_node), dtype=torch.long)
    input_edges = torch.randint(0, config.num_edges+1, (n_graph, n_node, n_node, 1, config.num_attention_heads), dtype=torch.long)
    attn_edge_type = torch.randint(0, config.num_edges+1, (n_graph, n_node, n_node, 1), dtype=torch.long)

    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); output = codeflash_output # 221μs -> 199μs (10.8% faster)

def test_forward_large_maximum_allowed_tensor():
    """
    Large scale: test with the largest tensor allowed (under 100MB).
    """
    config = DummyConfig(num_attention_heads=8)
    model = GraphormerGraphAttnBias(config)
    n_graph, n_node = 8, 32  # 8*32*33*33*8*4 bytes = ~8MB per tensor, safe
    input_nodes = torch.ones((n_graph, n_node), dtype=torch.long)
    attn_bias = torch.zeros((n_graph, n_node+1, n_node+1))
    spatial_pos = torch.randint(0, config.num_spatial, (n_graph, n_node, n_node), dtype=torch.long)
    input_edges = torch.randint(0, config.num_edges+1, (n_graph, n_node, n_node, 1, config.num_attention_heads), dtype=torch.long)
    attn_edge_type = torch.randint(0, config.num_edges+1, (n_graph, n_node, n_node, 1), dtype=torch.long)

    codeflash_output = model.forward(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type); output = codeflash_output # 391μs -> 350μs (11.8% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-GraphormerGraphAttnBias.forward-mhhaeuc9 and push.

Codeflash Static Badge

The optimized code achieves a **5% speedup** through several key memory and computation optimizations:

**Primary Optimization - Memory-Efficient Tensor Creation:**
- Replaced `attn_bias.clone().unsqueeze(1).repeat(1, self.num_heads, 1, 1)` with `attn_bias.unsqueeze(1).expand(-1, self.num_heads, -1, -1).clone()`
- This avoids creating large intermediate tensors during the repeat operation, using expand's lazy view semantics instead

**In-Place Operations for Large Tensors:**
- Used `masked_fill_` instead of tensor indexing assignment for setting padding values
- Applied `clamp_` for in-place clamping of spatial positions 
- Used `div_` for in-place division in the multi-hop path
- Replaced addition assignments with `+=` operators throughout

**Memory Layout Optimizations:**
- Added `.contiguous()` calls strategically to ensure optimal memory stride patterns for subsequent operations
- This is particularly beneficial for the multi-hop edge processing where tensors undergo multiple reshaping operations

**Fused Operations:**
- Combined sum and division operations using `torch.sum().div_()` instead of separate operations
- Used direct `+=` for graph attention bias updates instead of creating intermediate tensors

These optimizations are most effective for **larger graphs and batch sizes** (as shown in test cases with 50+ nodes achieving 10%+ speedups), where memory allocation overhead and tensor copying costs become more significant. The improvements target the most computationally expensive parts: spatial position encoding, multi-hop edge processing, and attention bias construction.
@codeflash-ai codeflash-ai bot requested a review from mashraf-222 November 2, 2025 05:45
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Nov 2, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant