@codeflash-ai codeflash-ai bot commented Oct 28, 2025

📄 57% (0.57x) speedup for ThresholdPruner.prune in optuna/pruners/_threshold.py

⏱️ Runtime : 402 microseconds → 255 microseconds (best of 296 runs)

📝 Explanation and details

The optimization replaces functools.reduce with a simple for-loop in the _is_first_in_interval_step function to find the second-largest step value.

Key Change:

  • Original: Used functools.reduce with a lambda function to iterate through intermediate_steps
  • Optimized: Used a direct for-loop with explicit variable assignment

Why This is Faster:
The functools.reduce approach creates significant overhead because:

  1. Lambda function calls: Each iteration invokes a lambda function, adding Python function call overhead
  2. Conditional expression: The ternary operator s if s > second_last_step and s != step else second_last_step is evaluated as an expression rather than a simple if-statement
  3. Functional programming overhead: reduce has additional layers of abstraction compared to a direct loop

The for-loop eliminates this overhead by using:

  • Direct variable assignment instead of function calls
  • Simple if-statement branching instead of conditional expressions
  • Minimal Python interpreter overhead
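The patched function itself is not reproduced in this PR body, so the following is a minimal sketch of the before/after shapes of the second-largest-step search (function and variable names are illustrative, not Optuna's actual identifiers):

```python
import functools

def second_largest_reduce(intermediate_steps, step):
    # Original style: reduce with a lambda and a conditional expression.
    # Each element costs one Python-level lambda call.
    return functools.reduce(
        lambda second_last_step, s: (
            s if s > second_last_step and s != step else second_last_step
        ),
        intermediate_steps,
        -1,
    )

def second_largest_loop(intermediate_steps, step):
    # Optimized style: plain for-loop with a simple if-statement,
    # no per-element function call.
    second_last_step = -1
    for s in intermediate_steps:
        if s > second_last_step and s != step:
            second_last_step = s
    return second_last_step
```

Both variants return the same result for any input; the loop simply drops one lambda invocation per element, which is where the per-iteration savings come from.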

Performance Impact:
The optimization shows 57% speedup overall, with particularly strong gains in test cases with larger intermediate step collections (up to 111% faster in some large-scale tests). The improvement is most pronounced when intermediate_steps contains many elements, as the per-iteration overhead reduction compounds across all iterations.

Best Use Cases:
This optimization performs especially well for trials with many intermediate steps (large-scale tests show 58-62% improvements), making it ideal for long-running optimization tasks with frequent pruning checks.
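A rough way to reproduce this effect locally is a timeit micro-benchmark over a simulated step collection (a sketch, not the Codeflash harness; absolute timings will vary by machine):

```python
import functools
import timeit

steps = list(range(1000))  # simulated intermediate step ids
step = 999                 # current step, excluded from the search

def via_reduce():
    # reduce + lambda: one Python function call per element
    return functools.reduce(
        lambda acc, s: s if s > acc and s != step else acc, steps, -1
    )

def via_loop():
    # plain loop: no per-element call overhead
    acc = -1
    for s in steps:
        if s > acc and s != step:
            acc = s
    return acc

t_reduce = timeit.timeit(via_reduce, number=2_000)
t_loop = timeit.timeit(via_loop, number=2_000)
print(f"reduce: {t_reduce:.4f}s  loop: {t_loop:.4f}s")
```

On a typical CPython build the loop variant comes out well ahead, consistent with the larger gains reported for the large-scale tests above.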

Correctness verification report:

Test                            Status
⚙️ Existing Unit Tests          🔘 None Found
🌀 Generated Regression Tests   81 Passed
⏪ Replay Tests                 🔘 None Found
🔎 Concolic Coverage Tests      🔘 None Found
📊 Tests Coverage               100.0%
🌀 Generated Regression Tests and Runtime
# imports
import pytest
from optuna.pruners._threshold import ThresholdPruner

# --- Function to test ---
# Minimal stub/mock classes to simulate optuna's FrozenTrial and Study for testing.
# We do NOT use any external libraries except pytest and standard library.

class DummyTrial:
    """A minimal mock of optuna.trial.FrozenTrial for testing ThresholdPruner.prune."""
    def __init__(self, intermediate_values, last_step):
        # intermediate_values: dict mapping step -> value
        self.intermediate_values = intermediate_values
        self.last_step = last_step

class DummyStudy:
    """A minimal mock of optuna.study.Study for testing ThresholdPruner.prune."""
    pass


# --- Unit tests for ThresholdPruner.prune ---

# 1. Basic Test Cases
def test_prune_basic_upper_threshold():
    # Prune if value above upper
    pruner = ThresholdPruner(upper=1.0)
    trial = DummyTrial({0: 0.5, 1: 1.2}, last_step=1)
    study = DummyStudy()
    codeflash_output = pruner.prune(study, trial) # 2.85μs -> 2.24μs (27.0% faster)

def test_prune_basic_lower_threshold():
    # Prune if value below lower
    pruner = ThresholdPruner(lower=0.0)
    trial = DummyTrial({0: 0.5, 1: -0.5}, last_step=1)
    study = DummyStudy()
    codeflash_output = pruner.prune(study, trial) # 2.89μs -> 2.14μs (35.0% faster)

def test_prune_basic_within_threshold():
    # Do not prune if value within [lower, upper]
    pruner = ThresholdPruner(lower=0.0, upper=1.0)
    trial = DummyTrial({0: 0.5, 1: 0.7}, last_step=1)
    study = DummyStudy()
    codeflash_output = pruner.prune(study, trial) # 2.87μs -> 2.10μs (36.4% faster)

def test_prune_basic_nan_value():
    # Prune if value is nan
    pruner = ThresholdPruner(lower=0.0, upper=1.0)
    trial = DummyTrial({0: 0.5, 1: float('nan')}, last_step=1)
    study = DummyStudy()
    codeflash_output = pruner.prune(study, trial) # 2.78μs -> 2.10μs (32.2% faster)

def test_prune_basic_no_last_step():
    # Do not prune if last_step is None
    pruner = ThresholdPruner(lower=0.0, upper=1.0)
    trial = DummyTrial({0: 0.5}, last_step=None)
    study = DummyStudy()
    codeflash_output = pruner.prune(study, trial) # 445ns -> 480ns (7.29% slower)

# 2. Edge Test Cases
def test_prune_edge_lower_equals_upper():
    # When lower == upper, any value deviating from that point triggers pruning
    pruner = ThresholdPruner(lower=0.5, upper=0.5)
    trial = DummyTrial({0: 0.5, 1: 0.6}, last_step=1)
    study = DummyStudy()
    codeflash_output = pruner.prune(study, trial) # 2.94μs -> 2.27μs (29.7% faster)

def test_prune_edge_value_exactly_on_lower():
    pruner = ThresholdPruner(lower=0.5)
    trial = DummyTrial({0: 0.5}, last_step=0)
    study = DummyStudy()
    codeflash_output = pruner.prune(study, trial) # 2.94μs -> 2.34μs (25.8% faster)

def test_prune_edge_value_exactly_on_upper():
    pruner = ThresholdPruner(upper=1.0)
    trial = DummyTrial({0: 1.0}, last_step=0)
    study = DummyStudy()
    codeflash_output = pruner.prune(study, trial) # 2.90μs -> 2.21μs (31.6% faster)

def test_prune_edge_warmup_steps():
    # Pruning disabled for steps < n_warmup_steps
    pruner = ThresholdPruner(lower=0.0, n_warmup_steps=2)
    trial = DummyTrial({0: -1.0, 1: -2.0, 2: -3.0}, last_step=1)
    study = DummyStudy()
    codeflash_output = pruner.prune(study, trial) # 556ns -> 530ns (4.91% faster)

def test_prune_edge_interval_steps():
    # Pruning only at interval steps
    pruner = ThresholdPruner(lower=0.0, interval_steps=2)
    # Only prune at step 0, 2, 4, ...
    trial = DummyTrial({0: -1.0, 1: -2.0, 2: -3.0}, last_step=1)
    study = DummyStudy()
    codeflash_output = pruner.prune(study, trial) # 2.78μs -> 1.99μs (39.4% faster)

    trial2 = DummyTrial({0: -1.0, 1: -2.0, 2: -3.0}, last_step=2)
    codeflash_output = pruner.prune(study, trial2) # 1.83μs -> 1.44μs (27.0% faster)

def test_prune_edge_unsorted_intermediate_steps():
    # Intermediate steps are unsorted, pruning logic should not depend on order
    pruner = ThresholdPruner(lower=0.0, interval_steps=2)
    trial = DummyTrial({2: -3.0, 0: -1.0, 1: -2.0}, last_step=2)
    study = DummyStudy()
    codeflash_output = pruner.prune(study, trial) # 2.80μs -> 2.18μs (28.7% faster)

def test_prune_edge_missing_intermediate_value_at_step():
    # If last_step is not in intermediate_values, should raise KeyError
    pruner = ThresholdPruner(lower=0.0)
    trial = DummyTrial({0: 1.0}, last_step=1)
    study = DummyStudy()
    with pytest.raises(KeyError):
        pruner.prune(study, trial) # 2.88μs -> 2.21μs (29.9% faster)

def test_prune_edge_only_upper_specified():
    pruner = ThresholdPruner(upper=1.0)
    trial = DummyTrial({0: 0.8, 1: 1.2}, last_step=1)
    study = DummyStudy()
    codeflash_output = pruner.prune(study, trial) # 3.13μs -> 2.34μs (34.0% faster)

def test_prune_edge_only_lower_specified():
    pruner = ThresholdPruner(lower=0.0)
    trial = DummyTrial({0: 0.8, 1: -0.1}, last_step=1)
    study = DummyStudy()
    codeflash_output = pruner.prune(study, trial) # 2.87μs -> 2.08μs (37.8% faster)

def test_prune_edge_invalid_init():
    # Both lower and upper None
    with pytest.raises(TypeError):
        ThresholdPruner()

    # lower > upper
    with pytest.raises(ValueError):
        ThresholdPruner(lower=2.0, upper=1.0)

    # Negative warmup steps
    with pytest.raises(ValueError):
        ThresholdPruner(lower=0.0, n_warmup_steps=-1)

    # interval_steps < 1
    with pytest.raises(ValueError):
        ThresholdPruner(lower=0.0, interval_steps=0)

def test_prune_edge_non_float_value():
    # Should accept values that can be cast to float
    pruner = ThresholdPruner(lower='0.0', upper='1.0')
    trial = DummyTrial({0: 0.5}, last_step=0)
    study = DummyStudy()
    codeflash_output = pruner.prune(study, trial) # 3.27μs -> 2.40μs (36.2% faster)

    # Should raise TypeError if cannot cast
    with pytest.raises(TypeError):
        ThresholdPruner(lower='foo')

# 3. Large Scale Test Cases
def test_prune_large_many_steps():
    # Test with 1000 intermediate steps, pruning at every 10th step
    pruner = ThresholdPruner(lower=0.0, upper=100.0, n_warmup_steps=0, interval_steps=10)
    # All values within threshold except last
    values = {i: 50.0 for i in range(1000)}
    values[990] = 150.0  # Only prune at step 990
    trial = DummyTrial(values, last_step=990)
    study = DummyStudy()
    codeflash_output = pruner.prune(study, trial) # 46.1μs -> 28.5μs (61.7% faster)

def test_prune_large_no_prune():
    # All values within threshold, should never prune
    pruner = ThresholdPruner(lower=-1000.0, upper=1000.0, n_warmup_steps=0, interval_steps=5)
    values = {i: 0.0 for i in range(1000)}
    trial = DummyTrial(values, last_step=995)
    study = DummyStudy()
    codeflash_output = pruner.prune(study, trial) # 45.3μs -> 28.4μs (59.7% faster)

def test_prune_large_nan_in_middle():
    # Prune if value is nan at a large step
    pruner = ThresholdPruner(lower=0.0, upper=1000.0, interval_steps=100)
    values = {i: 10.0 for i in range(1000)}
    values[900] = float('nan')
    trial = DummyTrial(values, last_step=900)
    study = DummyStudy()
    codeflash_output = pruner.prune(study, trial) # 45.6μs -> 28.6μs (59.1% faster)

def test_prune_large_warmup_and_interval():
    # Prune only after warmup and at correct intervals
    pruner = ThresholdPruner(lower=0.0, upper=100.0, n_warmup_steps=50, interval_steps=100)
    values = {i: 10.0 for i in range(1000)}
    values[150] = -10.0  # Should prune at step 150
    trial = DummyTrial(values, last_step=150)
    study = DummyStudy()
    codeflash_output = pruner.prune(study, trial) # 45.7μs -> 28.3μs (61.4% faster)

def test_prune_large_unsorted_keys():
    # Unsorted intermediate_values keys, should still prune correctly
    pruner = ThresholdPruner(lower=0.0, upper=100.0, interval_steps=50)
    values = {i: 10.0 for i in range(1000)}
    values[950] = 150.0  # Should prune at step 950
    # Shuffle keys
    import random
    keys = list(values.keys())
    random.shuffle(keys)
    shuffled_values = {k: values[k] for k in keys}
    trial = DummyTrial(shuffled_values, last_step=950)
    study = DummyStudy()
    codeflash_output = pruner.prune(study, trial) # 38.7μs -> 18.3μs (111% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
# imports
import pytest  # used for our unit tests
from optuna.pruners._threshold import ThresholdPruner


# Minimal mocks for optuna objects
class DummyStudy:
    pass

class DummyFrozenTrial:
    def __init__(self, intermediate_values=None, last_step=None):
        # intermediate_values: dict of step -> value
        self.intermediate_values = intermediate_values if intermediate_values is not None else {}
        self.last_step = last_step

# --------------------------
# Unit tests for prune
# --------------------------

# Basic Test Cases

def test_prune_basic_lower_threshold():
    """Prune if intermediate value falls below lower threshold."""
    pruner = ThresholdPruner(lower=0.0)
    trial = DummyFrozenTrial({0: 1.0, 1: -1.0}, last_step=1)
    codeflash_output = pruner.prune(DummyStudy(), trial) # 3.56μs -> 2.78μs (28.1% faster)

def test_prune_basic_upper_threshold():
    """Prune if intermediate value exceeds upper threshold."""
    pruner = ThresholdPruner(upper=2.0)
    trial = DummyFrozenTrial({0: 1.0, 1: 3.0}, last_step=1)
    codeflash_output = pruner.prune(DummyStudy(), trial) # 3.27μs -> 2.32μs (41.0% faster)

def test_prune_basic_within_threshold():
    """Do not prune if value is within thresholds."""
    pruner = ThresholdPruner(lower=0.0, upper=2.0)
    trial = DummyFrozenTrial({0: 1.0, 1: 1.5}, last_step=1)
    codeflash_output = pruner.prune(DummyStudy(), trial) # 3.12μs -> 2.20μs (41.8% faster)

def test_prune_basic_nan_value():
    """Prune if intermediate value is NaN."""
    pruner = ThresholdPruner(lower=0.0, upper=2.0)
    trial = DummyFrozenTrial({0: 1.0, 1: float('nan')}, last_step=1)
    codeflash_output = pruner.prune(DummyStudy(), trial) # 2.70μs -> 2.13μs (26.7% faster)

def test_prune_basic_last_step_none():
    """Do not prune if last_step is None."""
    pruner = ThresholdPruner(lower=0.0, upper=2.0)
    trial = DummyFrozenTrial({0: 1.0}, last_step=None)
    codeflash_output = pruner.prune(DummyStudy(), trial) # 475ns -> 466ns (1.93% faster)

# Edge Test Cases

def test_prune_edge_lower_greater_than_upper():
    """Do not allow lower > upper."""
    with pytest.raises(ValueError):
        ThresholdPruner(lower=2.0, upper=1.0)

def test_prune_edge_lower_none_upper_none():
    """Require at least one threshold."""
    with pytest.raises(TypeError):
        ThresholdPruner()

def test_prune_edge_warmup_steps():
    """Do not prune during warmup steps."""
    pruner = ThresholdPruner(lower=0.0, n_warmup_steps=2)
    trial = DummyFrozenTrial({0: -1.0, 1: -2.0}, last_step=1)
    codeflash_output = pruner.prune(DummyStudy(), trial) # 577ns -> 631ns (8.56% slower)
    # After warmup, pruning should work
    trial = DummyFrozenTrial({2: -3.0}, last_step=2)
    codeflash_output = pruner.prune(DummyStudy(), trial) # 3.15μs -> 2.14μs (47.2% faster)

def test_prune_edge_interval_steps():
    """Prune only at correct interval steps."""
    pruner = ThresholdPruner(lower=0.0, interval_steps=2)
    # Step 1 (not interval): should not prune
    trial = DummyFrozenTrial({0: 1.0, 1: -1.0}, last_step=1)
    codeflash_output = pruner.prune(DummyStudy(), trial) # 2.60μs -> 1.83μs (42.3% faster)
    # Step 2 (interval): should prune
    trial = DummyFrozenTrial({0: 1.0, 2: -1.0}, last_step=2)
    codeflash_output = pruner.prune(DummyStudy(), trial) # 1.77μs -> 1.38μs (28.3% faster)

def test_prune_edge_no_intermediate_value_for_step():
    """Raise KeyError if no value is reported for the current step."""
    pruner = ThresholdPruner(lower=0.0)
    trial = DummyFrozenTrial({0: 1.0}, last_step=1)
    # No value for last_step=1
    with pytest.raises(KeyError):
        pruner.prune(DummyStudy(), trial) # 2.99μs -> 2.26μs (32.5% faster)

def test_prune_edge_negative_warmup():
    """Negative warmup steps not allowed."""
    with pytest.raises(ValueError):
        ThresholdPruner(lower=0.0, n_warmup_steps=-1)

def test_prune_edge_zero_interval():
    """Interval steps < 1 not allowed."""
    with pytest.raises(ValueError):
        ThresholdPruner(lower=0.0, interval_steps=0)


def test_prune_edge_float_casting():
    """Thresholds should accept values castable to float."""
    pruner = ThresholdPruner(lower="0.0", upper="1.0")
    trial = DummyFrozenTrial({0: 0.5}, last_step=0)
    codeflash_output = pruner.prune(DummyStudy(), trial) # 4.11μs -> 2.80μs (46.5% faster)


def test_prune_large_scale_many_steps():
    """Test pruning behavior with many steps and interval."""
    pruner = ThresholdPruner(lower=0.0, upper=100.0, n_warmup_steps=10, interval_steps=10)
    # Only steps 10, 20, ... are pruning checks
    intermediate_values = {i: 50.0 for i in range(50)}  # All values within threshold
    trial = DummyFrozenTrial(intermediate_values, last_step=49)
    # Should not prune at step 49 (not interval)
    codeflash_output = pruner.prune(DummyStudy(), trial) # 5.44μs -> 3.52μs (54.6% faster)
    # Should prune at step 40 if value is out of bounds
    intermediate_values[40] = 200.0
    trial = DummyFrozenTrial(intermediate_values, last_step=40)
    codeflash_output = pruner.prune(DummyStudy(), trial) # 3.27μs -> 2.13μs (53.0% faster)

def test_prune_large_scale_all_nan():
    """Test pruning when all intermediate values are NaN."""
    pruner = ThresholdPruner(lower=0.0, upper=100.0)
    intermediate_values = {i: float('nan') for i in range(100)}
    trial = DummyFrozenTrial(intermediate_values, last_step=99)
    codeflash_output = pruner.prune(DummyStudy(), trial) # 7.59μs -> 5.13μs (47.8% faster)

def test_prune_large_scale_sparse_intermediate_values():
    """Test pruning with sparse intermediate values."""
    pruner = ThresholdPruner(lower=0.0, upper=100.0, interval_steps=10)
    # Only report values at steps 0, 10, 20, ..., 90
    intermediate_values = {i: 50.0 for i in range(0, 100, 10)}
    trial = DummyFrozenTrial(intermediate_values, last_step=90)
    codeflash_output = pruner.prune(DummyStudy(), trial) # 3.38μs -> 2.54μs (32.8% faster)
    # Out-of-bounds value at step 90
    intermediate_values[90] = -10.0
    trial = DummyFrozenTrial(intermediate_values, last_step=90)
    codeflash_output = pruner.prune(DummyStudy(), trial) # 1.64μs -> 1.12μs (45.9% faster)

def test_prune_large_scale_max_elements():
    """Test with maximum allowed elements (999 steps)."""
    pruner = ThresholdPruner(lower=0.0, upper=100.0)
    intermediate_values = {i: 50.0 for i in range(999)}
    trial = DummyFrozenTrial(intermediate_values, last_step=998)
    codeflash_output = pruner.prune(DummyStudy(), trial) # 46.1μs -> 29.1μs (58.5% faster)
    # Set last value out of bounds
    intermediate_values[998] = 101.0
    trial = DummyFrozenTrial(intermediate_values, last_step=998)
    codeflash_output = pruner.prune(DummyStudy(), trial) # 44.0μs -> 27.5μs (59.9% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run `git checkout codeflash/optimize-ThresholdPruner.prune-mhaze4cm` and push.

@codeflash-ai codeflash-ai bot requested a review from mashraf-222 October 28, 2025 19:50
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Oct 28, 2025