Skip to content

Commit eed471a

Browse files
committed
Add pruner
1 parent 0345941 commit eed471a

File tree

1 file changed

+195
-0
lines changed

1 file changed

+195
-0
lines changed

Diff for: pruner.py

+195
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
import numpy as np
2+
from typing import List, Dict, Tuple, Optional
3+
import bisect
4+
from dataclasses import dataclass, field
5+
from util import predict_final_loss
6+
7+
8+
# ┌──────────────────────────────────────────────────────────┐
9+
# Trial Dataclass
10+
# └──────────────────────────────────────────────────────────┘
11+
@dataclass
12+
class Trial:
13+
"""Trial class to hold intermediate state."""
14+
15+
trial_id: int
16+
current_epoch: int = 0
17+
seed_values: Dict[int, List[float]] = field(default_factory=dict)
18+
19+
def add_value(self, seed: int, value: float) -> None:
20+
"""Add a new intermediate value for a given seed."""
21+
if seed not in self.seed_values:
22+
self.seed_values[seed] = []
23+
self.seed_values[seed].append(value)
24+
self.current_epoch = len(self.seed_values[seed])
25+
26+
27+
# ┌──────────────────────────────────────────────────────────┐
28+
# Base Pruner Class
29+
# └──────────────────────────────────────────────────────────┘
30+
class BasePruner:
31+
"""
32+
Pruner base class with Optuna-like interface.
33+
"""
34+
35+
def __init__(self):
36+
self._trials: Dict[int, Trial] = {}
37+
self._current_trial: Optional[Trial] = None
38+
39+
def register_trial(self, trial_id: int) -> None:
40+
"""Register a new trial."""
41+
self._trials[trial_id] = Trial(trial_id=trial_id)
42+
43+
def complete_trial(self, trial_id: int) -> None:
44+
"""Mark a trial as finished and clean up."""
45+
if trial_id in self._trials:
46+
if self._current_trial and self._current_trial.trial_id == trial_id:
47+
self._current_trial = None
48+
del self._trials[trial_id]
49+
50+
def report(self, trial_id: int, seed: int, epoch: int, value: float) -> None:
51+
"""Report an intermediate value for a given trial.
52+
53+
Args:
54+
trial_id: Trial identifier
55+
seed: Random seed being used
56+
epoch: Current epoch number
57+
value: Intermediate value to report (typically validation loss)
58+
"""
59+
if trial_id not in self._trials:
60+
self.register_trial(trial_id)
61+
62+
trial = self._trials[trial_id]
63+
trial.add_value(seed, value)
64+
self._current_trial = trial
65+
66+
def should_prune(self) -> bool:
67+
"""Decide whether the current trial should be pruned at the current step.
68+
69+
Returns:
70+
bool: True if the trial should be pruned
71+
"""
72+
if not self._current_trial:
73+
return False
74+
return self._should_prune_trial(self._current_trial)
75+
76+
def _should_prune_trial(self, trial: Trial) -> bool:
77+
"""Implementation specific pruning logic."""
78+
raise NotImplementedError
79+
80+
81+
# ┌──────────────────────────────────────────────────────────┐
82+
# Predicted Final Loss (PFL) Pruner
83+
# └──────────────────────────────────────────────────────────┘
84+
class PFLPruner(BasePruner):
85+
"""Predicted Final Loss (PFL) based pruner with Optuna-like interface.
86+
87+
This pruner maintains top k trials based on validation loss and prunes trials
88+
if their predicted final loss is worse than the worst saved PFL.
89+
"""
90+
91+
def __init__(
92+
self,
93+
n_startup_trials: int = 10,
94+
n_warmup_epochs: int = 10,
95+
top_k: int = 10,
96+
target_epoch: int = 50,
97+
):
98+
super().__init__()
99+
self.n_startup_trials = n_startup_trials
100+
self.n_warmup_epochs = n_warmup_epochs
101+
self.top_k = top_k
102+
self.target_epoch = target_epoch
103+
104+
self.top_pairs: List[Tuple[float, float]] = [] # List of (val_loss, pfl) pairs
105+
self.completed_trials = 0
106+
107+
def complete_trial(self, trial_id: int) -> None:
108+
"""Mark a trial as finished and check for inclusion in top-k."""
109+
if trial_id in self._trials:
110+
self.completed_trials += 1
111+
self._check_and_insert(self._trials[trial_id])
112+
super().complete_trial(trial_id)
113+
114+
def _check_and_insert(self, trial: Trial) -> None:
115+
"""Check if a trial should be inserted into top k and insert if needed."""
116+
val_loss, pfl = self._compute_trial_metrics(trial)
117+
if self._should_insert_pair(val_loss):
118+
self._insert_pair(val_loss, pfl)
119+
120+
def _compute_trial_metrics(self, trial: Trial) -> Tuple[float, float]:
121+
"""Compute average val_loss and PFL for a trial across all seeds."""
122+
if not trial.seed_values:
123+
return float("inf"), -float("inf")
124+
125+
# Average the last val_loss and PFL across seeds
126+
avg_val_loss = 0.0
127+
avg_pfl = 0.0
128+
n_seeds = len(trial.seed_values)
129+
130+
for loss_vec in trial.seed_values.values():
131+
if loss_vec: # Check if there are any losses for this seed
132+
avg_val_loss += loss_vec[-1] # Last validation loss
133+
avg_pfl += self._predict_final_loss(loss_vec)
134+
135+
avg_val_loss /= n_seeds
136+
avg_pfl /= n_seeds
137+
return avg_val_loss, avg_pfl
138+
139+
def _predict_final_loss(self, losses: List[float]) -> float:
140+
"""Predict final loss value using the loss history."""
141+
if len(losses) < 2:
142+
return -float("inf")
143+
144+
try:
145+
return (
146+
-np.log10(losses[-1])
147+
if len(losses) < 10
148+
else predict_final_loss(losses, self.target_epoch)
149+
)
150+
except:
151+
return -float("inf")
152+
153+
def _should_insert_pair(self, val_loss: float) -> bool:
154+
"""Check if a new pair should be inserted based on validation loss."""
155+
if len(self.top_pairs) < self.top_k:
156+
return True
157+
return val_loss < self.top_pairs[-1][0]
158+
159+
def _insert_pair(self, val_loss: float, pfl: float) -> None:
160+
"""Insert a new (val_loss, pfl) pair maintaining sorted order."""
161+
pair = (val_loss, pfl)
162+
163+
# Find insertion point using binary search
164+
idx = bisect.bisect_left(self.top_pairs, pair)
165+
166+
# Insert the pair
167+
if len(self.top_pairs) < self.top_k:
168+
self.top_pairs.insert(idx, pair)
169+
elif idx < self.top_k:
170+
self.top_pairs.insert(idx, pair)
171+
self.top_pairs.pop() # Remove worst pair if we exceed top_k
172+
173+
def _should_prune_trial(self, trial: Trial) -> bool:
174+
"""Implementation of trial pruning logic."""
175+
# Check if any seed has invalid loss
176+
for losses in trial.seed_values.values():
177+
if not losses or not np.isfinite(losses[-1]):
178+
return True
179+
180+
# Don't prune during warmup period
181+
if (
182+
self.completed_trials < self.n_startup_trials
183+
or trial.current_epoch <= self.n_warmup_epochs
184+
):
185+
return False
186+
187+
# Compute current metrics
188+
_, curr_pfl = self._compute_trial_metrics(trial)
189+
190+
# Prune if PFL is worse than all saved PFLs
191+
if self.top_pairs: # Only if we have recorded pairs
192+
worst_pfl = min(pair[1] for pair in self.top_pairs)
193+
return curr_pfl < worst_pfl
194+
195+
return False

0 commit comments

Comments
 (0)