diff --git a/docs/tutorials/tuning.md b/docs/tutorials/tuning.md
index 5079043be..d69c09997 100644
--- a/docs/tutorials/tuning.md
+++ b/docs/tutorials/tuning.md
@@ -36,7 +36,11 @@
 requires-python = ">3.7.1,<4.0"
 dependencies = [
     "edsnlp[ml]>=0.15.0",
-    "sentencepiece>=0.1.96"
+    "sentencepiece>=0.1.96",
+    "optuna>=4.0.0",
+    "plotly>=5.18.0",
+    "ruamel.yaml>=0.18.0",
+    "configobj",
 ]
 
 [project.optional-dependencies]
diff --git a/edsnlp/tune.py b/edsnlp/tune.py
index d926d1b49..48d305a24 100644
--- a/edsnlp/tune.py
+++ b/edsnlp/tune.py
@@ -9,12 +9,14 @@
 
 import optuna
 import optuna.visualization as vis
+from configobj import ConfigObj
 from confit import Cli, Config
 from confit.utils.collections import split_path
 from confit.utils.random import set_seed
 from optuna.importance import FanovaImportanceEvaluator, get_param_importances
 from optuna.pruners import MedianPruner
 from pydantic import BaseModel, confloat, conint
+from ruamel.yaml import YAML
 
 from edsnlp.training.trainer import GenericScorer, registry, train
@@ -174,6 +176,7 @@ def update_config(
     tuned_parameters: Dict[str, Dict],
     values: Optional[Dict[str, any]] = None,
     trial: Optional[optuna.trial.Trial] = None,
+    resolve: bool = True,
 ) -> Tuple[Dict, Dict]:
     """
     Update a configuration dictionary with tuned hyperparameter values.
@@ -248,8 +251,10 @@ def update_config(
             current_config = current_config[key]
         current_config[p_path[-1]] = value
 
-    kwargs = Config.resolve(config["train"], registry=registry, root=config)
-    return kwargs, config
+    if resolve:
+        kwargs = Config.resolve(config["train"], registry=registry, root=config)
+        return kwargs, config
+    return config
 
 
 def objective_with_param(config, tuned_parameters, trial, metric):
@@ -297,6 +302,7 @@ def process_results(
    study,
     output_dir,
     viz,
     config,
+    config_path,
     tuned_parameters,
     best_params_phase_1=None,
 ):
@@ -326,14 +332,7 @@
        for key, value in importances.items():
            f.write(f" {key}: {value}\n")

-    config_path = os.path.join(output_dir, "config.yml")
-    _, updated_config = update_config(
-        config.copy(),
-        tuned_parameters,
-        values=best_params,
-    )
-    updated_config.pop("tuning", None)
-    Config(updated_config).to_disk(config_path)
+    write_final_config(output_dir, config_path, tuned_parameters, best_params)

     if viz:
        vis.plot_optimization_history(study).write_html(
@@ -349,8 +348,37 @@
     return best_params, importances
 
 
+def write_final_config(output_dir, config_path, tuned_parameters, best_params):
+    path_str = str(config_path)
+    if path_str.endswith(".yaml") or path_str.endswith(".yml"):
+        yaml = YAML()
+        yaml.preserve_quotes = True
+        yaml.representer.add_representer(
+            type(None),
+            lambda self, _: self.represent_scalar("tag:yaml.org,2002:null", "null"),
+        )
+        with open(config_path, "r", encoding="utf-8") as file:
+            original_config = yaml.load(file)
+        updated_config = update_config(
+            original_config, tuned_parameters, values=best_params, resolve=False
+        )
+        with open(
+            os.path.join(output_dir, "config.yml"), "w", encoding="utf-8"
+        ) as file:
+            yaml.dump(updated_config, file)
+    else:
+        config = ConfigObj(config_path, encoding="utf-8")
+        updated_config = update_config(
+            dict(config), tuned_parameters, values=best_params, resolve=False
+        )
+        config.update(updated_config)
+        config.filename = os.path.join(output_dir, "config.cfg")
+        config.write()
+
+
 def tune_two_phase(
     config: Dict,
+    config_path: str,
     hyperparameters: Dict[str, Dict],
     output_dir: str,
     n_trials: int,
@@ -398,11 +426,13 @@
     """
     n_trials_2 = n_trials // 2
     n_trials_1 = n_trials - n_trials_2
+    output_dir_phase_1 = os.path.join(output_dir, "phase_1")
+    output_dir_phase_2 = os.path.join(output_dir, "phase_2")
 
     logger.info(f"Phase 1: Tuning all hyperparameters ({n_trials_1} trials).")
     study = optimize(config, hyperparameters, n_trials_1, metric, study=study)
     best_params_phase_1, importances = process_results(
-        study, f"{output_dir}/phase_1", viz, config, hyperparameters
+        study, output_dir_phase_1, viz, config, config_path, hyperparameters
     )
 
     hyperparameters_to_keep = list(importances.keys())[
@@ -436,13 +466,18 @@
     study = optimize(
         updated_config, hyperparameters_phase_2, n_trials_2, metric, study=study
     )
+    if str(config_path).endswith("yaml") or str(config_path).endswith("yml"):
+        config_path_phase_2 = os.path.join(output_dir_phase_1, "config.yml")
+    else:
+        config_path_phase_2 = os.path.join(output_dir_phase_1, "config.cfg")
     process_results(
         study,
-        f"{output_dir}/phase_2",
+        output_dir_phase_2,
         viz,
         config,
-        hyperparameters,
-        best_params_phase_1,
+        config_path=config_path_phase_2,
+        tuned_parameters=hyperparameters,
+        best_params_phase_1=best_params_phase_1,
     )
 
 
@@ -528,7 +563,8 @@
     """
     setup_logging()
     viz = is_plotly_install()
-    config = load_config(config_meta["config_path"][0])
+    config_path = config_meta["config_path"][0]
+    config = load_config(config_path)
     hyperparameters = {key: value.to_dict() for key, value in hyperparameters.items()}
     set_seed(seed)
     metric = split_path(metric)
@@ -546,6 +582,7 @@
         logger.info("Starting two-phase tuning.")
         tune_two_phase(
             config,
+            config_path,
             hyperparameters,
             output_dir,
             n_trials,
@@ -566,7 +603,7 @@
                 "more trials to fully use GPU time budget."
             )
         study = optimize(config, hyperparameters, n_trials, metric, study=study)
-        process_results(study, output_dir, viz, config, hyperparameters)
+        process_results(study, output_dir, viz, config, config_path, hyperparameters)
 
 
 if __name__ == "__main__":
diff --git a/pyproject.toml b/pyproject.toml
index 311780ac9..217c38527 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -89,6 +89,9 @@ dev = [
     "edsnlp[ml]",
     "optuna>=4.0.0",
     "plotly>=5.18.0", # required by optuna viz
+    "ruamel.yaml>=0.18.0",
+    "configobj>=5.0.9",
+
 ]
 setup = [
     "typer"
diff --git a/tests/tuning/config.cfg b/tests/tuning/config.cfg
new file mode 100644
index 000000000..176f23deb
--- /dev/null
+++ b/tests/tuning/config.cfg
@@ -0,0 +1,2 @@
+[train]
+param1 = 1
diff --git a/tests/tuning/test_tuning.py b/tests/tuning/test_tuning.py
index 109b49749..b3d8a38aa 100644
--- a/tests/tuning/test_tuning.py
+++ b/tests/tuning/test_tuning.py
@@ -125,7 +128,10 @@ def test_compute_importances(study):
 
 
 @pytest.mark.parametrize("viz", [True, False])
-def test_process_results(study, tmpdir, viz):
+@pytest.mark.parametrize(
+    "config_path", ["tests/tuning/config.yml", "tests/tuning/config.cfg"]
+)
+def test_process_results(study, tmpdir, viz, config_path):
     output_dir = tmpdir.mkdir("output")
     config = {
         "train": {
@@ -144,9 +147,8 @@
             "step": 2,
         },
     }
-
     best_params, importances = process_results(
-        study, output_dir, viz, config, hyperparameters
+        study, output_dir, viz, config, config_path, hyperparameters
     )
 
     assert isinstance(best_params, dict)
@@ -163,7 +165,10 @@
         assert "Params" in content
         assert "Importances" in content
 
-    config_file = os.path.join(output_dir, "config.yml")
+    if config_path.endswith("yml") or config_path.endswith("yaml"):
+        config_file = os.path.join(output_dir, "config.yml")
+    else:
+        config_file = os.path.join(output_dir, "config.cfg")
     assert os.path.exists(config_file), f"Expected file {config_file} not found"
 
     if viz: