diff --git a/changelog.md b/changelog.md index 241515e97..89e08e38f 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,10 @@ # Changelog +## Unreleased + +### Added +- New parameter `pruning_params` to `edsnlp.tune` in order to control pruning during tuning. + ## v0.19.0 (2025-10-04) 📢 EDS-NLP will drop support for Python 3.7, 3.8 and 3.9 support in the next major release (v0.20.0), in October 2025. Please upgrade to Python 3.10 or later. diff --git a/edsnlp/tune.py b/edsnlp/tune.py index 0f56e0194..381ba693a 100644 --- a/edsnlp/tune.py +++ b/edsnlp/tune.py @@ -271,7 +271,7 @@ def update_config( return config -def objective_with_param(config, tuned_parameters, trial, metric): +def objective_with_param(config, tuned_parameters, trial, metric, pruning_params): kwargs, _ = update_config(config, tuned_parameters, trial=trial) seed = random.randint(0, 2**32 - 1) set_seed(seed) @@ -282,8 +282,9 @@ def on_validation_callback(all_metrics): for key in metric: score = score[key] trial.report(score, step) - if trial.should_prune(): - raise optuna.TrialPruned() + if pruning_params: + if trial.should_prune(): + raise optuna.TrialPruned() try: nlp = train(**kwargs, on_validation_callback=on_validation_callback) @@ -299,15 +300,30 @@ def on_validation_callback(all_metrics): def optimize( - config_path, tuned_parameters, n_trials, metric, checkpoint_dir, study=None + config_path, + tuned_parameters, + n_trials, + metric, + checkpoint_dir, + pruning_params, + study=None, ): def objective(trial): - return objective_with_param(config_path, tuned_parameters, trial, metric) + return objective_with_param( + config_path, tuned_parameters, trial, metric, pruning_params + ) if not study: + pruner = None + if pruning_params: + n_startup_trials = pruning_params.get("n_startup_trials", 5) + n_warmup_steps = pruning_params.get("n_warmup_steps", 5) + pruner = MedianPruner( + n_startup_trials=n_startup_trials, n_warmup_steps=n_warmup_steps + ) study = optuna.create_study( 
direction="maximize", - pruner=MedianPruner(n_startup_trials=5, n_warmup_steps=2), + pruner=pruner, sampler=TPESampler(seed=random.randint(0, 2**32 - 1)), ) study.optimize( @@ -444,6 +460,7 @@ def tune_two_phase( is_fixed_n_trials: bool = False, gpu_hours: float = 1.0, skip_phase_1: bool = False, + pruning_params: Optional[Dict[str, int]] = None, ) -> None: """ Perform two-phase hyperparameter tuning using Optuna. @@ -505,6 +522,7 @@ def tune_two_phase( n_trials_1, metric, checkpoint_dir, + pruning_params, study, ) best_params_phase_1, importances = process_results( @@ -551,6 +569,7 @@ def tune_two_phase( n_trials_2, metric, checkpoint_dir, + pruning_params, study, ) @@ -612,6 +631,7 @@ def tune( seed: int = 42, metric="ner.micro.f", keep_checkpoint: bool = False, + pruning_params: Optional[Dict[str, int]] = None, ): """ Perform hyperparameter tuning for a model using Optuna. @@ -652,6 +672,11 @@ def tune( Metric used to evaluate trials. Default is "ner.micro.f". keep_checkpoint : bool, optional If True, keeps the checkpoint file after tuning. Default is False. + pruning_params : dict, optional + A dictionary specifying pruning parameters: + - "n_startup_trials": Number of startup trials before pruning starts. + - "n_warmup_steps": Number of warmup steps before pruning starts. + Default is None, meaning no pruning. 
""" setup_logging() viz = is_plotly_install() @@ -679,6 +704,7 @@ def tune( n_trials=1, metric=metric, checkpoint_dir=checkpoint_dir, + pruning_params=pruning_params, ) n_trials = compute_n_trials(gpu_hours, compute_time_per_trial(study)) - 1 else: @@ -708,6 +734,7 @@ def tune( is_fixed_n_trials=is_fixed_n_trials, gpu_hours=gpu_hours, skip_phase_1=skip_phase_1, + pruning_params=pruning_params, ) else: logger.info("Starting single-phase tuning.") @@ -717,6 +744,7 @@ def tune( n_trials, metric, checkpoint_dir, + pruning_params, study, ) if not is_fixed_n_trials: @@ -732,6 +760,7 @@ def tune( n_trials, metric, checkpoint_dir, + pruning_params, study, ) process_results(study, output_dir, viz, config, config_path, hyperparameters)