Skip to content

Conversation

@codeflash-ai
Copy link

@codeflash-ai codeflash-ai bot commented Oct 28, 2025

📄 21% (0.21x) speedup for JournalStorageReplayResult.get_all_trials in optuna/storages/journal/_storage.py

⏱️ Runtime : 713 microseconds 589 microseconds (best of 393 runs)

📝 Explanation and details

The optimization achieves a 20% speedup through three key improvements:

1. Short-circuit for states is None:
The optimized code adds a fast path when no state filtering is needed, using a simple list comprehension that bypasses all conditional logic. This is particularly effective for large datasets - the test results show 105-108% speedup for large-scale cases with 1000 trials when states=None.

2. Convert states to set for O(1) lookups:
When states filtering is needed, the code converts the states container to a set if it isn't already one. This changes the trial.state in states operation from potentially O(n) to O(1), providing significant benefits when filtering. Test results show 15-25% improvements for filtered queries on large datasets.

3. Local variable caching and list comprehension:
The code caches self._trials and self._study_id_to_trial_ids[study_id] as local variables, reducing attribute lookup overhead. It also replaces the explicit loop with list comprehension, which is more efficient in Python's bytecode execution.

Performance characteristics by test case:

  • Large datasets with no filtering: 105-108% faster due to the short-circuit path
  • Large datasets with filtering: 15-25% faster from set conversion and local caching
  • Small datasets: Mixed results (some 20-60% slower) due to optimization overhead, but these represent microsecond differences that are negligible in practice

The optimization is most beneficial for the common use cases: either returning all trials or filtering large numbers of trials by state.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 170 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 2 Passed
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
from collections.abc import Container
from enum import Enum, auto

# imports
import pytest
from optuna.storages.journal._storage import JournalStorageReplayResult

# --- Minimal stubs for optuna classes and enums for testability ---

class TrialState(Enum):
    RUNNING = auto()
    COMPLETE = auto()
    PRUNED = auto()
    FAIL = auto()
    WAITING = auto()

class FrozenTrial:
    def __init__(self, trial_id, state):
        self._trial_id = trial_id
        self.state = state

class FrozenStudy:
    def __init__(self, study_id):
        self._study_id = study_id

# --- The function/class under test (from optuna/storages/journal/_storage.py) ---

NOT_FOUND_MSG = "Record does not exist."
from optuna.storages.journal._storage import JournalStorageReplayResult

# --- Unit tests for get_all_trials ---

# Helper function to set up a JournalStorageReplayResult with studies and trials
def setup_storage(study_trials_map):
    """
    study_trials_map: dict mapping study_id to list of (trial_id, state)
    Returns: JournalStorageReplayResult instance
    """
    storage = JournalStorageReplayResult("worker")
    for study_id, trials in study_trials_map.items():
        storage._studies[study_id] = FrozenStudy(study_id)
        storage._study_id_to_trial_ids[study_id] = []
        for trial_id, state in trials:
            storage._trials[trial_id] = FrozenTrial(trial_id, state)
            storage._study_id_to_trial_ids[study_id].append(trial_id)
            storage._trial_id_to_study_id[trial_id] = study_id
    return storage

# --- Basic Test Cases ---

def test_basic_no_trials_returns_empty():
    # Study exists but no trials
    storage = setup_storage({1: []})
    codeflash_output = storage.get_all_trials(1, None); result = codeflash_output # 662ns -> 1.17μs (43.4% slower)

def test_basic_single_trial_no_state_filter():
    # Study with one trial, no state filter
    storage = setup_storage({1: [(10, TrialState.RUNNING)]})
    codeflash_output = storage.get_all_trials(1, None); result = codeflash_output # 858ns -> 1.24μs (30.8% slower)

def test_basic_multiple_trials_no_state_filter():
    # Study with multiple trials, no state filter
    storage = setup_storage({1: [(10, TrialState.RUNNING), (11, TrialState.COMPLETE)]})
    codeflash_output = storage.get_all_trials(1, None); result = codeflash_output # 942ns -> 1.23μs (23.4% slower)
    trial_ids = {t._trial_id for t in result}

def test_basic_multiple_trials_with_state_filter():
    # Study with multiple trials, filter for COMPLETE only
    storage = setup_storage({1: [
        (10, TrialState.RUNNING),
        (11, TrialState.COMPLETE),
        (12, TrialState.PRUNED),
        (13, TrialState.COMPLETE),
    ]})
    codeflash_output = storage.get_all_trials(1, {TrialState.COMPLETE}); result = codeflash_output # 2.07μs -> 2.67μs (22.3% slower)
    trial_ids = {t._trial_id for t in result}
    for t in result:
        pass

def test_basic_multiple_trials_with_multiple_state_filter():
    # Study with multiple trials, filter for COMPLETE and PRUNED
    storage = setup_storage({1: [
        (10, TrialState.RUNNING),
        (11, TrialState.COMPLETE),
        (12, TrialState.PRUNED),
        (13, TrialState.FAIL),
        (14, TrialState.COMPLETE),
    ]})
    codeflash_output = storage.get_all_trials(1, {TrialState.COMPLETE, TrialState.PRUNED}); result = codeflash_output # 2.07μs -> 2.62μs (21.1% slower)
    trial_ids = {t._trial_id for t in result}
    for t in result:
        pass

# --- Edge Test Cases ---

def test_edge_study_id_not_found_raises_keyerror():
    # No studies at all
    storage = setup_storage({})
    with pytest.raises(KeyError) as excinfo:
        storage.get_all_trials(99, None) # 990ns -> 917ns (7.96% faster)

def test_edge_study_id_exists_but_no_trials():
    # Study exists but has no trials
    storage = setup_storage({42: []})
    codeflash_output = storage.get_all_trials(42, {TrialState.COMPLETE}); result = codeflash_output # 692ns -> 1.43μs (51.4% slower)

def test_edge_state_filter_empty_container():
    # State filter is an empty set: should return no trials
    storage = setup_storage({1: [
        (1, TrialState.COMPLETE),
        (2, TrialState.PRUNED),
    ]})
    codeflash_output = storage.get_all_trials(1, set()); result = codeflash_output # 1.56μs -> 2.17μs (28.3% slower)

def test_edge_state_filter_is_none_and_trials_have_various_states():
    # State filter is None, all trials returned regardless of state
    storage = setup_storage({1: [
        (1, TrialState.COMPLETE),
        (2, TrialState.PRUNED),
        (3, TrialState.FAIL),
        (4, TrialState.RUNNING),
        (5, TrialState.WAITING),
    ]})
    codeflash_output = storage.get_all_trials(1, None); result = codeflash_output # 1.19μs -> 1.45μs (17.6% slower)
    trial_ids = {t._trial_id for t in result}

def test_edge_state_filter_is_tuple():
    # State filter is a tuple (Container, not just set)
    storage = setup_storage({1: [
        (1, TrialState.COMPLETE),
        (2, TrialState.PRUNED),
        (3, TrialState.FAIL),
    ]})
    codeflash_output = storage.get_all_trials(1, (TrialState.PRUNED, TrialState.FAIL)); result = codeflash_output # 1.42μs -> 3.42μs (58.3% slower)
    trial_ids = {t._trial_id for t in result}

def test_edge_all_trials_filtered_out():
    # All trials filtered out by state
    storage = setup_storage({1: [
        (1, TrialState.COMPLETE),
        (2, TrialState.COMPLETE),
    ]})
    codeflash_output = storage.get_all_trials(1, {TrialState.PRUNED}); result = codeflash_output # 1.34μs -> 1.98μs (32.4% slower)

def test_edge_duplicate_trial_states():
    # Multiple trials with the same state
    storage = setup_storage({1: [
        (1, TrialState.COMPLETE),
        (2, TrialState.COMPLETE),
        (3, TrialState.PRUNED),
    ]})
    codeflash_output = storage.get_all_trials(1, {TrialState.COMPLETE}); result = codeflash_output # 1.65μs -> 2.20μs (24.8% slower)
    trial_ids = {t._trial_id for t in result}

def test_edge_multiple_studies_isolation():
    # Multiple studies, ensure only trials from correct study are returned
    storage = setup_storage({
        1: [(1, TrialState.COMPLETE), (2, TrialState.PRUNED)],
        2: [(3, TrialState.COMPLETE), (4, TrialState.RUNNING)],
    })
    codeflash_output = storage.get_all_trials(2, None); result = codeflash_output # 950ns -> 1.27μs (25.4% slower)
    trial_ids = {t._trial_id for t in result}
    for t in result:
        pass

# --- Large Scale Test Cases ---

def test_large_scale_many_trials_various_states():
    # Study with 1000 trials, various states, filter for COMPLETE
    num_trials = 1000
    trials = [(i, TrialState.COMPLETE if i % 3 == 0 else TrialState.PRUNED if i % 3 == 1 else TrialState.RUNNING) for i in range(num_trials)]
    storage = setup_storage({1: trials})
    codeflash_output = storage.get_all_trials(1, {TrialState.COMPLETE}); result = codeflash_output # 113μs -> 96.6μs (17.9% faster)
    expected_ids = {i for i in range(num_trials) if i % 3 == 0}
    result_ids = {t._trial_id for t in result}

def test_large_scale_many_studies_and_trials():
    # 10 studies, each with 100 trials, filter for PRUNED
    num_studies = 10
    num_trials_per_study = 100
    study_trials_map = {}
    for s in range(num_studies):
        trials = []
        for t in range(num_trials_per_study):
            # Alternate states
            state = [TrialState.COMPLETE, TrialState.PRUNED, TrialState.RUNNING, TrialState.FAIL][t % 4]
            trials.append((s * 1000 + t, state))
        study_trials_map[s] = trials
    storage = setup_storage(study_trials_map)
    for s in range(num_studies):
        codeflash_output = storage.get_all_trials(s, {TrialState.PRUNED}); result = codeflash_output # 127μs -> 114μs (11.4% faster)
        expected_ids = {s * 1000 + t for t in range(num_trials_per_study) if t % 4 == 1}
        result_ids = {t._trial_id for t in result}

def test_large_scale_no_trials_in_any_study():
    # 100 studies, no trials in any
    num_studies = 100
    study_trials_map = {i: [] for i in range(num_studies)}
    storage = setup_storage(study_trials_map)
    for i in range(num_studies):
        codeflash_output = storage.get_all_trials(i, None); result = codeflash_output # 21.9μs -> 36.1μs (39.3% slower)

def test_large_scale_all_trials_filtered_out():
    # 500 trials, all in state RUNNING, filter for COMPLETE
    num_trials = 500
    trials = [(i, TrialState.RUNNING) for i in range(num_trials)]
    storage = setup_storage({1: trials})
    codeflash_output = storage.get_all_trials(1, {TrialState.COMPLETE}); result = codeflash_output # 54.4μs -> 47.2μs (15.1% faster)

def test_large_scale_all_trials_returned_with_none_filter():
    # 1000 trials, mixed states, filter None (should return all)
    num_trials = 1000
    states = [TrialState.COMPLETE, TrialState.PRUNED, TrialState.RUNNING, TrialState.FAIL, TrialState.WAITING]
    trials = [(i, states[i % len(states)]) for i in range(num_trials)]
    storage = setup_storage({1: trials})
    codeflash_output = storage.get_all_trials(1, None); result = codeflash_output # 45.3μs -> 21.7μs (108% faster)
    result_ids = {t._trial_id for t in result}
    # Check that the states are as expected
    for t in result:
        pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from collections.abc import Container
from enum import Enum

# imports
import pytest
from optuna.storages.journal._storage import JournalStorageReplayResult


# Minimal stubs for optuna classes to allow testing.
class TrialState(Enum):
    RUNNING = 0
    COMPLETE = 1
    FAIL = 2
    WAITING = 3

class FrozenTrial:
    def __init__(self, trial_id, state):
        self.number = trial_id
        self.state = state

class FrozenStudy:
    def __init__(self, study_id):
        self.study_id = study_id

NOT_FOUND_MSG = "Record does not exist."
from optuna.storages.journal._storage import JournalStorageReplayResult

# ----------- UNIT TESTS ------------

# ---------------- BASIC TEST CASES ----------------

def test_get_all_trials_returns_all_when_states_none():
    """Test that all trials are returned when 'states' is None."""
    jsr = JournalStorageReplayResult("worker")
    study_id = 1
    jsr._studies[study_id] = FrozenStudy(study_id)
    jsr._trials[10] = FrozenTrial(10, TrialState.RUNNING)
    jsr._trials[11] = FrozenTrial(11, TrialState.COMPLETE)
    jsr._study_id_to_trial_ids[study_id] = [10, 11]
    codeflash_output = jsr.get_all_trials(study_id, None); result = codeflash_output # 1.09μs -> 1.40μs (22.4% slower)

def test_get_all_trials_returns_only_matching_states():
    """Test that only trials matching the given states are returned."""
    jsr = JournalStorageReplayResult("worker")
    study_id = 2
    jsr._studies[study_id] = FrozenStudy(study_id)
    jsr._trials[20] = FrozenTrial(20, TrialState.RUNNING)
    jsr._trials[21] = FrozenTrial(21, TrialState.COMPLETE)
    jsr._trials[22] = FrozenTrial(22, TrialState.FAIL)
    jsr._study_id_to_trial_ids[study_id] = [20, 21, 22]
    codeflash_output = jsr.get_all_trials(study_id, {TrialState.COMPLETE, TrialState.FAIL}); result = codeflash_output # 1.91μs -> 2.55μs (25.2% slower)

def test_get_all_trials_empty_trials_list():
    """Test that an empty list is returned if the study has no trials."""
    jsr = JournalStorageReplayResult("worker")
    study_id = 3
    jsr._studies[study_id] = FrozenStudy(study_id)
    jsr._study_id_to_trial_ids[study_id] = []
    codeflash_output = jsr.get_all_trials(study_id, None); result = codeflash_output # 660ns -> 1.13μs (41.6% slower)

# ---------------- EDGE TEST CASES ----------------

def test_get_all_trials_nonexistent_study_id_raises():
    """Test that KeyError is raised if study_id does not exist."""
    jsr = JournalStorageReplayResult("worker")
    with pytest.raises(KeyError) as excinfo:
        jsr.get_all_trials(999, None) # 1.04μs -> 971ns (6.69% faster)

def test_get_all_trials_states_empty_container():
    """Test with an empty container for states (should return empty list)."""
    jsr = JournalStorageReplayResult("worker")
    study_id = 4
    jsr._studies[study_id] = FrozenStudy(study_id)
    jsr._trials[40] = FrozenTrial(40, TrialState.RUNNING)
    jsr._study_id_to_trial_ids[study_id] = [40]
    codeflash_output = jsr.get_all_trials(study_id, set()); result = codeflash_output # 1.44μs -> 2.31μs (37.4% slower)

def test_get_all_trials_trial_with_unexpected_state():
    """Test that a trial with a state not in the states container is excluded."""
    jsr = JournalStorageReplayResult("worker")
    study_id = 5
    jsr._studies[study_id] = FrozenStudy(study_id)
    jsr._trials[50] = FrozenTrial(50, TrialState.WAITING)
    jsr._study_id_to_trial_ids[study_id] = [50]
    codeflash_output = jsr.get_all_trials(study_id, {TrialState.RUNNING, TrialState.COMPLETE}); result = codeflash_output # 1.20μs -> 1.85μs (35.4% slower)

def test_get_all_trials_none_states_with_one_trial():
    """Test with one trial and states=None returns that trial."""
    jsr = JournalStorageReplayResult("worker")
    study_id = 6
    jsr._studies[study_id] = FrozenStudy(study_id)
    jsr._trials[60] = FrozenTrial(60, TrialState.COMPLETE)
    jsr._study_id_to_trial_ids[study_id] = [60]
    codeflash_output = jsr.get_all_trials(study_id, None); result = codeflash_output # 891ns -> 1.26μs (29.3% slower)

def test_get_all_trials_states_is_tuple():
    """Test that function works with tuple as states container."""
    jsr = JournalStorageReplayResult("worker")
    study_id = 7
    jsr._studies[study_id] = FrozenStudy(study_id)
    jsr._trials[70] = FrozenTrial(70, TrialState.COMPLETE)
    jsr._trials[71] = FrozenTrial(71, TrialState.FAIL)
    jsr._study_id_to_trial_ids[study_id] = [70, 71]
    codeflash_output = jsr.get_all_trials(study_id, (TrialState.COMPLETE,)); result = codeflash_output # 1.25μs -> 3.23μs (61.3% slower)

def test_get_all_trials_trial_ids_not_in_trials_dict():
    """Test that missing trial_id in _trials raises KeyError."""
    jsr = JournalStorageReplayResult("worker")
    study_id = 8
    jsr._studies[study_id] = FrozenStudy(study_id)
    jsr._study_id_to_trial_ids[study_id] = [80]  # trial 80 not in _trials
    with pytest.raises(KeyError):
        jsr.get_all_trials(study_id, None) # 1.17μs -> 1.63μs (28.3% slower)

# ---------------- LARGE SCALE TEST CASES ----------------

def test_get_all_trials_large_number_of_trials_all_returned():
    """Test with 1000 trials, states=None, all trials returned."""
    jsr = JournalStorageReplayResult("worker")
    study_id = 9
    jsr._studies[study_id] = FrozenStudy(study_id)
    trial_ids = list(range(1000))
    jsr._study_id_to_trial_ids[study_id] = trial_ids
    for tid in trial_ids:
        # Alternate states for variety
        state = TrialState.COMPLETE if tid % 2 == 0 else TrialState.RUNNING
        jsr._trials[tid] = FrozenTrial(tid, state)
    codeflash_output = jsr.get_all_trials(study_id, None); result = codeflash_output # 45.2μs -> 22.1μs (105% faster)

def test_get_all_trials_large_number_of_trials_some_states():
    """Test with 1000 trials, filter only COMPLETE trials."""
    jsr = JournalStorageReplayResult("worker")
    study_id = 10
    jsr._studies[study_id] = FrozenStudy(study_id)
    trial_ids = list(range(1000))
    jsr._study_id_to_trial_ids[study_id] = trial_ids
    for tid in trial_ids:
        state = TrialState.COMPLETE if tid % 2 == 0 else TrialState.RUNNING
        jsr._trials[tid] = FrozenTrial(tid, state)
    codeflash_output = jsr.get_all_trials(study_id, {TrialState.COMPLETE}); result = codeflash_output # 121μs -> 97.1μs (25.5% faster)

def test_get_all_trials_large_number_of_trials_no_matching_state():
    """Test with 1000 trials, filter for state not present."""
    jsr = JournalStorageReplayResult("worker")
    study_id = 11
    jsr._studies[study_id] = FrozenStudy(study_id)
    trial_ids = list(range(1000))
    jsr._study_id_to_trial_ids[study_id] = trial_ids
    for tid in trial_ids:
        jsr._trials[tid] = FrozenTrial(tid, TrialState.COMPLETE)
    # Use a state that is not present
    codeflash_output = jsr.get_all_trials(study_id, {TrialState.FAIL}); result = codeflash_output # 109μs -> 90.6μs (20.6% faster)

def test_get_all_trials_performance_large_scale():
    """Performance: Ensure function completes quickly for 1000 trials."""
    import time
    jsr = JournalStorageReplayResult("worker")
    study_id = 12
    jsr._studies[study_id] = FrozenStudy(study_id)
    trial_ids = list(range(1000))
    jsr._study_id_to_trial_ids[study_id] = trial_ids
    for tid in trial_ids:
        jsr._trials[tid] = FrozenTrial(tid, TrialState.COMPLETE)
    start = time.time()
    codeflash_output = jsr.get_all_trials(study_id, None); result = codeflash_output # 45.1μs -> 21.8μs (107% faster)
    end = time.time()
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from optuna.storages.journal._storage import JournalStorageReplayResult
import pytest

def test_JournalStorageReplayResult_get_all_trials():
    with pytest.raises(KeyError, match="'Record\\ does\\ not\\ exist\\.'"):
        JournalStorageReplayResult.get_all_trials(JournalStorageReplayResult(''), 0, ())
🔎 Concolic Coverage Tests and Runtime
Test File::Test Function Original ⏱️ Optimized ⏱️ Speedup
codeflash_concolic_wou29s7s/tmp_sybtsfy/test_concolic_coverage.py::test_JournalStorageReplayResult_get_all_trials 1.14μs 1.17μs -2.23%⚠️

To edit these changes git checkout codeflash/optimize-JournalStorageReplayResult.get_all_trials-mhax13s8 and push.

Codeflash

The optimization achieves a 20% speedup through three key improvements:

**1. Short-circuit for `states is None`:**
The optimized code adds a fast path when no state filtering is needed, using a simple list comprehension that bypasses all conditional logic. This is particularly effective for large datasets - the test results show 105-108% speedup for large-scale cases with 1000 trials when `states=None`.

**2. Convert states to set for O(1) lookups:**
When states filtering is needed, the code converts the states container to a set if it isn't already one. This changes the `trial.state in states` operation from potentially O(n) to O(1), providing significant benefits when filtering. Test results show 15-25% improvements for filtered queries on large datasets.

**3. Local variable caching and list comprehension:**
The code caches `self._trials` and `self._study_id_to_trial_ids[study_id]` as local variables, reducing attribute lookup overhead. It also replaces the explicit loop with list comprehension, which is more efficient in Python's bytecode execution.

**Performance characteristics by test case:**
- **Large datasets with no filtering:** 105-108% faster due to the short-circuit path
- **Large datasets with filtering:** 15-25% faster from set conversion and local caching
- **Small datasets:** Mixed results (some 20-60% slower) due to optimization overhead, but these represent microsecond differences that are negligible in practice

The optimization is most beneficial for the common use cases: either returning all trials or filtering large numbers of trials by state.
@codeflash-ai codeflash-ai bot requested a review from mashraf-222 October 28, 2025 18:44
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: Medium Optimization Quality according to Codeflash labels Oct 28, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: Medium Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant