5 changes: 5 additions & 0 deletions changelog.md
@@ -1,5 +1,10 @@
# Changelog

## Unreleased

### Added
- New parameter `pruning_params` in `edsnlp.tune` to control trial pruning during tuning.

## v0.19.0 (2025-10-04)

📢 EDS-NLP will drop support for Python 3.7, 3.8 and 3.9 in the next major release (v0.20.0), in October 2025. Please upgrade to Python 3.10 or later.
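
To make the `pruning_params` entry under Unreleased concrete, here is a minimal sketch of the expected dictionary. The two keys and the fallback value of 5 come from the `optimize` changes in this PR; everything else is illustrative.

```python
# Illustrative only: shape of the new ``pruning_params`` argument.
# When a key is missing, ``optimize`` falls back to 5 for both values;
# passing None (the default) disables pruning altogether.
pruning_params = {
    "n_startup_trials": 5,  # completed trials before any pruning happens
    "n_warmup_steps": 5,    # validation steps per trial before it can be pruned
}
```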
41 changes: 35 additions & 6 deletions edsnlp/tune.py
@@ -271,7 +271,7 @@ def update_config(
return config


def objective_with_param(config, tuned_parameters, trial, metric):
def objective_with_param(config, tuned_parameters, trial, metric, pruning_params):
kwargs, _ = update_config(config, tuned_parameters, trial=trial)
seed = random.randint(0, 2**32 - 1)
set_seed(seed)
@@ -282,8 +282,9 @@ def on_validation_callback(all_metrics):
for key in metric:
score = score[key]
trial.report(score, step)
if trial.should_prune():
raise optuna.TrialPruned()
if pruning_params:
if trial.should_prune():
raise optuna.TrialPruned()

try:
nlp = train(**kwargs, on_validation_callback=on_validation_callback)
@@ -299,15 +300,30 @@ def on_validation_callback(all_metrics):


def optimize(
config_path, tuned_parameters, n_trials, metric, checkpoint_dir, study=None
config_path,
tuned_parameters,
n_trials,
metric,
checkpoint_dir,
pruning_params,
study=None,
):
def objective(trial):
return objective_with_param(config_path, tuned_parameters, trial, metric)
return objective_with_param(
config_path, tuned_parameters, trial, metric, pruning_params
)

if not study:
pruner = None
if pruning_params:
n_startup_trials = pruning_params.get("n_startup_trials", 5)
n_warmup_steps = pruning_params.get("n_warmup_steps", 5)
pruner = MedianPruner(
n_startup_trials=n_startup_trials, n_warmup_steps=n_warmup_steps
)
study = optuna.create_study(
direction="maximize",
pruner=MedianPruner(n_startup_trials=5, n_warmup_steps=2),
pruner=pruner,
sampler=TPESampler(seed=random.randint(0, 2**32 - 1)),
)
study.optimize(
@@ -444,6 +460,7 @@ def tune_two_phase(
is_fixed_n_trials: bool = False,
gpu_hours: float = 1.0,
skip_phase_1: bool = False,
pruning_params: Optional[Dict[str, int]] = None,
) -> None:
"""
Perform two-phase hyperparameter tuning using Optuna.
@@ -505,6 +522,7 @@ def tune_two_phase(
n_trials_1,
metric,
checkpoint_dir,
pruning_params,
study,
)
best_params_phase_1, importances = process_results(
@@ -551,6 +569,7 @@ def tune_two_phase(
n_trials_2,
metric,
checkpoint_dir,
pruning_params,
study,
)

@@ -612,6 +631,7 @@ def tune(
seed: int = 42,
metric="ner.micro.f",
keep_checkpoint: bool = False,
pruning_params: Optional[Dict[str, int]] = None,
):
"""
Perform hyperparameter tuning for a model using Optuna.
@@ -652,6 +672,11 @@ def tune(
Metric used to evaluate trials. Default is "ner.micro.f".
keep_checkpoint : bool, optional
If True, keeps the checkpoint file after tuning. Default is False.
pruning_params : dict, optional
A dictionary configuring Optuna's MedianPruner:
- "n_startup_trials": Number of trials to complete before pruning is considered.
- "n_warmup_steps": Number of validation steps within a trial before it can be pruned.
Default is None, which disables pruning.
"""
setup_logging()
viz = is_plotly_install()
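
The docstring above describes the dictionary; below is a self-contained sketch of the pruning flow this PR wires up. A toy objective stands in for edsnlp's `train`, and the search space and `n_trials=20` are illustrative; only the `pruning_params` handling, the MedianPruner fallbacks of 5, and the seeded TPESampler mirror the diff.

```python
# Self-contained sketch of the pruning flow added in this PR.
import random

import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler

pruning_params = {"n_startup_trials": 5, "n_warmup_steps": 5}  # or None to disable


def objective(trial):
    x = trial.suggest_float("x", -10, 10)
    score = 0.0
    for step in range(20):  # pretend each step is a validation callback
        score = -((x - 2) ** 2) - 1.0 / (step + 1)  # improves as steps go on
        trial.report(score, step)
        # As in the updated ``on_validation_callback``: only ask Optuna to
        # prune when pruning was requested.
        if pruning_params and trial.should_prune():
            raise optuna.TrialPruned()
    return score


# Mirrors the new pruner construction in ``optimize``.
pruner = None
if pruning_params:
    pruner = MedianPruner(
        n_startup_trials=pruning_params.get("n_startup_trials", 5),
        n_warmup_steps=pruning_params.get("n_warmup_steps", 5),
    )

study = optuna.create_study(
    direction="maximize",
    pruner=pruner,
    sampler=TPESampler(seed=random.randint(0, 2**32 - 1)),
)
study.optimize(objective, n_trials=20)
```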
@@ -679,6 +704,7 @@ def tune(
n_trials=1,
metric=metric,
checkpoint_dir=checkpoint_dir,
pruning_params=pruning_params,
)
n_trials = compute_n_trials(gpu_hours, compute_time_per_trial(study)) - 1
else:
@@ -708,6 +734,7 @@ def tune(
is_fixed_n_trials=is_fixed_n_trials,
gpu_hours=gpu_hours,
skip_phase_1=skip_phase_1,
pruning_params=pruning_params,
)
else:
logger.info("Starting single-phase tuning.")
@@ -717,6 +744,7 @@ def tune(
n_trials,
metric,
checkpoint_dir,
pruning_params,
study,
)
if not is_fixed_n_trials:
@@ -732,6 +760,7 @@ def tune(
n_trials,
metric,
checkpoint_dir,
pruning_params,
study,
)
process_results(study, output_dir, viz, config, config_path, hyperparameters)