|
| 1 | +import numpy as np |
| 2 | +from typing import List, Dict, Tuple, Optional |
| 3 | +import bisect |
| 4 | +from dataclasses import dataclass, field |
| 5 | +from util import predict_final_loss |
| 6 | + |
| 7 | + |
| 8 | +# ┌──────────────────────────────────────────────────────────┐ |
| 9 | +# Trial Dataclass |
| 10 | +# └──────────────────────────────────────────────────────────┘ |
@dataclass
class Trial:
    """Mutable record of the intermediate values reported for one trial."""

    # Identifier supplied by whoever creates the trial.
    trial_id: int
    # Number of values stored for the seed most recently passed to add_value().
    current_epoch: int = 0
    # Maps each seed to the values reported for it, in report order.
    seed_values: Dict[int, List[float]] = field(default_factory=dict)

    def add_value(self, seed: int, value: float) -> None:
        """Append ``value`` to the history of ``seed``.

        A history list is created on first use of a seed. Afterwards
        ``current_epoch`` is set to the length of that seed's history.
        """
        history = self.seed_values.setdefault(seed, [])
        history.append(value)
        self.current_epoch = len(history)
| 25 | + |
| 26 | + |
| 27 | +# ┌──────────────────────────────────────────────────────────┐ |
| 28 | +# Base Pruner Class |
| 29 | +# └──────────────────────────────────────────────────────────┘ |
class BasePruner:
    """Shared bookkeeping for pruners exposing an Optuna-like interface.

    The base class tracks live trials and remembers which trial reported
    last. The actual pruning rule is supplied by subclasses through
    ``_should_prune_trial``.
    """

    def __init__(self):
        # Live trials keyed by their identifier.
        self._trials: Dict[int, Trial] = {}
        # Trial that most recently reported a value, if any.
        self._current_trial: Optional[Trial] = None

    def register_trial(self, trial_id: int) -> None:
        """Start tracking ``trial_id`` with a fresh, empty ``Trial``.

        An existing entry under the same identifier is replaced.
        """
        self._trials[trial_id] = Trial(trial_id=trial_id)

    def complete_trial(self, trial_id: int) -> None:
        """Stop tracking ``trial_id``; unknown identifiers are ignored.

        If the finished trial is also the current one, the current-trial
        reference is cleared.
        """
        if trial_id not in self._trials:
            return

        active = self._current_trial
        if active and active.trial_id == trial_id:
            self._current_trial = None
        del self._trials[trial_id]

    def report(self, trial_id: int, seed: int, epoch: int, value: float) -> None:
        """Record an intermediate value and make the trial the current one.

        Args:
            trial_id: Trial identifier; registered on the fly when unknown.
            seed: Random seed the value belongs to.
            epoch: Current epoch number (accepted but not used here; the
                trial derives its epoch from the number of stored values).
            value: Intermediate value to store (typically validation loss).
        """
        if trial_id not in self._trials:
            self.register_trial(trial_id)

        reporting = self._trials[trial_id]
        reporting.add_value(seed, value)
        self._current_trial = reporting

    def should_prune(self) -> bool:
        """Ask the pruning rule about the trial that reported last.

        Returns:
            bool: False when no trial is current, otherwise whatever
            ``_should_prune_trial`` decides for the current trial.
        """
        current = self._current_trial
        return self._should_prune_trial(current) if current else False

    def _should_prune_trial(self, trial: Trial) -> bool:
        """Pruning rule hook; subclasses must override it."""
        raise NotImplementedError
| 79 | + |
| 80 | + |
| 81 | +# ┌──────────────────────────────────────────────────────────┐ |
| 82 | +# Predicted Final Loss (PFL) Pruner |
| 83 | +# └──────────────────────────────────────────────────────────┘ |
class PFLPruner(BasePruner):
    """Predicted Final Loss (PFL) based pruner with Optuna-like interface.

    The pruner keeps the ``top_k`` finished trials with the lowest average
    validation loss, each stored together with its average PFL. A running
    trial is pruned when its own PFL is lower than every stored PFL (a
    higher PFL is treated as better throughout this class).
    """

    def __init__(
        self,
        n_startup_trials: int = 10,
        n_warmup_epochs: int = 10,
        top_k: int = 10,
        target_epoch: int = 50,
    ):
        """Configure the pruner.

        Args:
            n_startup_trials: No PFL-based pruning happens until this many
                trials have been completed.
            n_warmup_epochs: No PFL-based pruning happens while the trial's
                ``current_epoch`` is at or below this value.
            top_k: Maximum number of (val_loss, pfl) pairs that are kept.
            target_epoch: Forwarded to ``predict_final_loss``.
        """
        super().__init__()
        self.n_startup_trials = n_startup_trials
        self.n_warmup_epochs = n_warmup_epochs
        self.top_k = top_k
        self.target_epoch = target_epoch

        # (val_loss, pfl) pairs kept sorted in ascending order.
        self.top_pairs: List[Tuple[float, float]] = []
        # Counts every tracked trial passed to complete_trial().
        self.completed_trials = 0

    def complete_trial(self, trial_id: int) -> None:
        """Mark a trial as finished and consider it for the top-k list.

        Unknown identifiers are ignored and do not count as completed.
        """
        if trial_id in self._trials:
            self.completed_trials += 1
            self._check_and_insert(self._trials[trial_id])
        super().complete_trial(trial_id)

    def _check_and_insert(self, trial: Trial) -> None:
        """Insert the trial's (val_loss, pfl) pair into top-k if it qualifies."""
        val_loss, pfl = self._compute_trial_metrics(trial)

        # A NaN/inf loss (diverged or never-reported trial) cannot be ranked:
        # NaN breaks the sorted order bisect relies on, and such a trial
        # must never serve as a reference for pruning others.
        if not np.isfinite(val_loss):
            return

        if self._should_insert_pair(val_loss):
            self._insert_pair(val_loss, pfl)

    def _compute_trial_metrics(self, trial: Trial) -> Tuple[float, float]:
        """Average the latest loss and the PFL of a trial over its seeds.

        Returns:
            ``(avg_val_loss, avg_pfl)``; ``(inf, -inf)`` when the trial has
            no seeds at all.
        """
        if not trial.seed_values:
            return float("inf"), -float("inf")

        avg_val_loss = 0.0
        avg_pfl = 0.0
        n_seeds = len(trial.seed_values)

        for loss_vec in trial.seed_values.values():
            if loss_vec:  # Seeds without any reported loss contribute 0
                avg_val_loss += loss_vec[-1]  # Last validation loss
                avg_pfl += self._predict_final_loss(loss_vec)

        avg_val_loss /= n_seeds
        avg_pfl /= n_seeds
        return avg_val_loss, avg_pfl

    def _predict_final_loss(self, losses: List[float]) -> float:
        """Predict the final loss score from a loss history.

        With fewer than 10 values the score is ``-log10`` of the latest
        loss; otherwise ``predict_final_loss`` is used. ``-inf`` is returned
        when the score cannot be determined: fewer than two values, a
        failing predictor, or a NaN result.
        """
        unknown = -float("inf")
        if len(losses) < 2:
            return unknown

        try:
            if len(losses) < 10:
                pfl = -np.log10(losses[-1])
            else:
                pfl = predict_final_loss(losses, self.target_epoch)
            # log10 of a negative or NaN loss yields NaN without raising;
            # NaN would make every later PFL comparison meaningless.
            return unknown if np.isnan(pfl) else pfl
        except Exception:
            # Deliberately not a bare except, so KeyboardInterrupt and
            # SystemExit still propagate.
            return unknown

    def _should_insert_pair(self, val_loss: float) -> bool:
        """Tell whether ``val_loss`` earns a place among the top-k pairs."""
        if len(self.top_pairs) < self.top_k:
            return True
        if not self.top_pairs:
            # Only reachable with top_k <= 0: there is no slot to fill.
            return False
        return val_loss < self.top_pairs[-1][0]

    def _insert_pair(self, val_loss: float, pfl: float) -> None:
        """Insert a new (val_loss, pfl) pair maintaining sorted order."""
        pair = (val_loss, pfl)

        # Find insertion point using binary search
        idx = bisect.bisect_left(self.top_pairs, pair)

        if len(self.top_pairs) < self.top_k:
            self.top_pairs.insert(idx, pair)
        elif idx < self.top_k:
            self.top_pairs.insert(idx, pair)
            self.top_pairs.pop()  # Drop the worst pair to stay at top_k

    def _should_prune_trial(self, trial: Trial) -> bool:
        """Decide whether ``trial`` should be pruned.

        A trial whose latest loss for any seed is missing or non-finite is
        always pruned. Otherwise nothing is pruned during the startup and
        warmup phases; after that the trial is pruned when its PFL is below
        the lowest PFL among the stored top-k pairs.
        """
        # Check if any seed has invalid loss
        for losses in trial.seed_values.values():
            if not losses or not np.isfinite(losses[-1]):
                return True

        # Don't prune during the startup / warmup period
        if (
            self.completed_trials < self.n_startup_trials
            or trial.current_epoch <= self.n_warmup_epochs
        ):
            return False

        # Compute current metrics
        _, curr_pfl = self._compute_trial_metrics(trial)

        # Prune if PFL is worse than all saved PFLs
        if self.top_pairs:  # Only if we have recorded pairs
            worst_pfl = min(pair[1] for pair in self.top_pairs)
            return curr_pfl < worst_pfl

        return False
0 commit comments