diff --git a/README.md b/README.md index 78e8661..d80ca67 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ Available commands: * `finetune` – Fine-tune an existing pre-trained model * `evaluate` – Evaluate model performance on a dataset * `predict` – Run inference and save predictions +* `iterate` – Run hyperparameter optimization (HPO) and repeated experiments on multiple datasets/tasks --- @@ -144,3 +145,28 @@ gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model. | `--output_path` | `str` | Directory where predictions are saved. | `data` | --- + +--- + +## Running Iterate + +```bash +gridfm_graphkit iterate --config path/to/config.yaml +``` + +### Arguments + +| Argument | Type | Description | Default | +| --------------- | ----- | --------------------------------------------- | --------- | +| `--config` | `str` | Path to `iterate` config file. | `None` | +| `--seed` | `int` | Seed for reproducibility. | `None` | +| `--hpo_spec` | `namespace` | Parameters for HPO/repeated experiments. | `None` | +| `--tasks` | `namespace` | List of tasks (datasets and metrics) to run. | `None` | +| `--model` | `namespace` | Model configuration. | `None` | +| `--optimizer` | `namespace` | Optimizer and LR scheduler configuration. | `None` | +| `--training` | `namespace` | Trainer configuration. | `None` | +| `--callbacks` | `namespace` | Training callbacks configuration. | `None` | + +--- +**Note:** Namespace inputs can be provided in the config or as command line arguments. If provided on the command line, namespace inputs can be given with `.` notation, e.g. `--hpo_spec.experiment_name my_exp`. Run `gridfm_graphkit iterate -h` for the full list of inputs allowed in each namespace. + diff --git a/examples/config/case118_ieee_base.yaml b/examples/config/case118_ieee_base.yaml index 41de315..dec4efa 100644 --- a/examples/config/case118_ieee_base.yaml +++ b/examples/config/case118_ieee_base.yaml @@ -31,11 +31,15 @@ training: devices: auto strategy: auto optimizer: + type: Adam learning_rate: 0.0001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 callbacks: patience: 100 tol: 0 diff --git a/examples/config/case240_pserc_base.yaml b/examples/config/case240_pserc_base.yaml index 6758061..4ee4b82 100644 --- a/examples/config/case240_pserc_base.yaml +++ b/examples/config/case240_pserc_base.yaml @@ -31,11 +31,15 @@ training: devices: auto strategy: auto optimizer: + type: Adam learning_rate: 0.0001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 callbacks: patience: 100 tol: 0 diff --git a/examples/config/case24_ieee_rts_base.yaml b/examples/config/case24_ieee_rts_base.yaml index b912b46..e2540db 100644 --- a/examples/config/case24_ieee_rts_base.yaml +++ b/examples/config/case24_ieee_rts_base.yaml @@ -31,11 +31,15 @@ training: devices: auto strategy: auto optimizer: + type: Adam learning_rate: 0.0001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 callbacks: patience: 100 tol: 0 diff --git a/examples/config/case300_ieee_base.yaml b/examples/config/case300_ieee_base.yaml index d2fe4c4..1717bae 100644 --- a/examples/config/case300_ieee_base.yaml +++ 
b/examples/config/case300_ieee_base.yaml @@ -31,11 +31,15 @@ training: devices: auto strategy: auto optimizer: + type: Adam learning_rate: 0.0001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 callbacks: patience: 100 tol: 0 diff --git a/examples/config/case30_ieee_base.yaml b/examples/config/case30_ieee_base.yaml index a884933..cece70e 100644 --- a/examples/config/case30_ieee_base.yaml +++ b/examples/config/case30_ieee_base.yaml @@ -31,11 +31,15 @@ training: devices: auto strategy: auto optimizer: + type: Adam learning_rate: 0.0001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 callbacks: patience: 100 tol: 0 diff --git a/examples/config/case30_ieee_base_hpo.yaml b/examples/config/case30_ieee_base_hpo.yaml new file mode 100644 index 0000000..9169b5b --- /dev/null +++ b/examples/config/case30_ieee_base_hpo.yaml @@ -0,0 +1,86 @@ +seed: 42 +# data: +# networks: ["case30_ieee"] +# scenarios: [1023] +# normalization: "baseMVAnorm" +# baseMVA: 100 +# mask_type: "rnd" +# mask_value: 0.0 +# mask_ratio: 0.5 +# mask_dim: 6 +# learn_mask: False +# val_ratio: 0.1 +# test_ratio: 0.1 +# workers: 4 +model: + attention_head: 8 + dropout: 0.1 + edge_dim: 2 + hidden_size: 256 + input_dim: 9 + num_layers: 8 + output_dim: 6 + pe_dim: 20 + type: GPSTransformer + model_path: "/dccstor/sentinel1/nsimumba/neso_gridfm/gridfm-graphkit/examples/models/GridFM_v0_2.pth" +training: + batch_size: 1 + epochs: 2 + losses: ["MaskedMSE", "PBE"] + loss_weights: [0.01, 0.99] + accelerator: auto + devices: auto + strategy: auto +optimizer: + type: Adam + learning_rate: 0.0001 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 +callbacks: + patience: 100 + tol: 0 + optuna_early_prune: True + +hpo_spec: + experiment_name: GPSTransformer_8 + run_name: top_run + optimization_space: + batch_size: [8, 16, 32] + learning_rate: + min: 0.000006 + max: 0.001 + type: real + log: true + n_trials: 2 + bayesian_search: True + results_folder: "/dccstor/sentinel1/nsimumba/neso_gridfm/results/iterate" + save_models: False + num_repetitions: 2 + repeat_on_best: True + report_on_best_val: True + continue_existing_experiment: True + +tasks: + - name: feature_reconstruction_base + type: feature_reconstruction + metric: "Validation loss" + direction: min + data: + data_path: "/dccstor/gridfm/PowerGraph_TP" + networks: ["case30_ieee"] + scenarios: [1023] + normalization: "baseMVAnorm" + baseMVA: 100 + mask_type: "rnd" + mask_value: 0.0 + mask_ratio: 0.5 + mask_dim: 6 + learn_mask: False + val_ratio: 0.1 + test_ratio: 0.1 + workers: 4 \ No newline at end of file diff --git a/examples/config/case39_epri_base.yaml b/examples/config/case39_epri_base.yaml index 076a25a..3c3747c 100644 --- a/examples/config/case39_epri_base.yaml +++ b/examples/config/case39_epri_base.yaml @@ -31,11 +31,15 @@ training: devices: auto strategy: auto optimizer: + type: Adam learning_rate: 0.0001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 callbacks: patience: 100 tol: 0 diff --git a/examples/config/case57_ieee_base.yaml 
b/examples/config/case57_ieee_base.yaml index 96e7c3c..172c3f5 100644 --- a/examples/config/case57_ieee_base.yaml +++ b/examples/config/case57_ieee_base.yaml @@ -31,11 +31,15 @@ training: devices: auto strategy: auto optimizer: + type: Adam learning_rate: 0.0001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 callbacks: patience: 100 tol: 0 diff --git a/examples/config/case89_pegase_base.yaml b/examples/config/case89_pegase_base.yaml index 0eef554..8175c7d 100644 --- a/examples/config/case89_pegase_base.yaml +++ b/examples/config/case89_pegase_base.yaml @@ -31,11 +31,15 @@ training: devices: auto strategy: auto optimizer: + type: Adam learning_rate: 0.0001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 callbacks: patience: 100 tol: 0 diff --git a/examples/config/gridFMv0.1_pretraining.yaml b/examples/config/gridFMv0.1_pretraining.yaml index c489e95..51ec4f6 100644 --- a/examples/config/gridFMv0.1_pretraining.yaml +++ b/examples/config/gridFMv0.1_pretraining.yaml @@ -37,11 +37,15 @@ model: pe_dim: 20 type: GNN_TransformerConv optimizer: - beta1: 0.9 - beta2: 0.999 + type: Adam learning_rate: 1.0e-05 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 seed: 200 training: batch_size: 64 diff --git a/examples/config/gridFMv0.2_pretraining.yaml b/examples/config/gridFMv0.2_pretraining.yaml index 8a804db..9fc5265 100644 --- a/examples/config/gridFMv0.2_pretraining.yaml +++ b/examples/config/gridFMv0.2_pretraining.yaml @@ -37,11 +37,15 @@ model: pe_dim: 20 type: GPSTransformer optimizer: - beta1: 0.9 - beta2: 0.999 + type: Adam learning_rate: 0.0001 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 seed: 0 training: batch_size: 64 diff --git a/gridfm_graphkit/__main__.py b/gridfm_graphkit/__main__.py index f2e3d62..6a23fd0 100644 --- a/gridfm_graphkit/__main__.py +++ b/gridfm_graphkit/__main__.py @@ -1,18 +1,25 @@ -import argparse from datetime import datetime -from gridfm_graphkit.cli import main_cli +from gridfm_graphkit.cli import main_cli, iterate_cli +from jsonargparse import ArgumentParser, Namespace -def main(): - parser = argparse.ArgumentParser( - prog="gridfm_graphkit", - description="gridfm-graphkit CLI", +from gridfm_graphkit.utils.types import ( + HyperParameterOptmizerSpec, TaskSpec, CallbackSpec, + OptimizerSpec, ModelSpec, TrainingSpec, DataSpec ) - subparsers = parser.add_subparsers(dest="command", required=True) + + +def main(): + # parser = argparse.ArgumentParser( + # prog="gridfm_graphkit", + # description="gridfm-graphkit CLI", + # ) + # subparsers = parser.add_subparsers(dest="command", required=True) exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") # ---- TRAIN SUBCOMMAND ---- - train_parser = subparsers.add_parser("train", help="Run training") + # train_parser = subparsers.add_parser("train", help="Run training") + train_parser = ArgumentParser() train_parser.add_argument("--config", type=str, required=True) train_parser.add_argument("--exp_name", type=str, default=exp_name) train_parser.add_argument("--run_name", type=str, 
default="run") @@ -20,7 +27,8 @@ def main(): train_parser.add_argument("--data_path", type=str, default="data") # ---- FINETUNE SUBCOMMAND ---- - finetune_parser = subparsers.add_parser("finetune", help="Run fine-tuning") + # finetune_parser = subparsers.add_parser("finetune", help="Run fine-tuning") + finetune_parser = ArgumentParser() finetune_parser.add_argument("--config", type=str, required=True) finetune_parser.add_argument("--model_path", type=str, required=True) finetune_parser.add_argument("--exp_name", type=str, default=exp_name) @@ -29,19 +37,18 @@ def main(): finetune_parser.add_argument("--data_path", type=str, default="data") # ---- EVALUATE SUBCOMMAND ---- - evaluate_parser = subparsers.add_parser( - "evaluate", - help="Evaluate model performance", - ) - evaluate_parser.add_argument("--model_path", type=str, default=None) + # evaluate_parser = subparsers.add_parser("evaluate", help="Evaluate model performance") + evaluate_parser = ArgumentParser() evaluate_parser.add_argument("--config", type=str, required=True) + evaluate_parser.add_argument("--model_path", type=str, required=True) evaluate_parser.add_argument("--exp_name", type=str, default=exp_name) evaluate_parser.add_argument("--run_name", type=str, default="run") evaluate_parser.add_argument("--log_dir", type=str, default="mlruns") evaluate_parser.add_argument("--data_path", type=str, default="data") # ---- PREDICT SUBCOMMAND ---- - predict_parser = subparsers.add_parser("predict", help="Evaluate model performance") + # predict_parser = subparsers.add_parser("predict", help="Evaluate model performance") + predict_parser = ArgumentParser() predict_parser.add_argument("--model_path", type=str, required=None) predict_parser.add_argument("--config", type=str, required=True) predict_parser.add_argument("--exp_name", type=str, default=exp_name) @@ -50,8 +57,36 @@ def main(): predict_parser.add_argument("--data_path", type=str, default="data") predict_parser.add_argument("--output_path", type=str, default="data") + # ---- ITERATE SUBCOMMAND ---- + # iterate_parser = subparsers.add_parser("iterate", help="Run model benchmarking") + iterate_parser = ArgumentParser() + iterate_parser.add_argument("--config", action="config") + iterate_parser.add_argument("--seed", type=int) + iterate_parser.add_argument("--hpo_spec", type=HyperParameterOptmizerSpec) + iterate_parser.add_argument("--tasks", type=list[TaskSpec]) + iterate_parser.add_argument("--model", type=ModelSpec) + iterate_parser.add_argument("--optimizer", type=OptimizerSpec) + iterate_parser.add_argument("--training", type=TrainingSpec) + iterate_parser.add_argument("--callbacks", type=CallbackSpec) + + + parser = ArgumentParser( + prog="gridfm_graphkit", + description="gridfm-graphkit CLI", + ) + subcommands = parser.add_subcommands() + subcommands.add_subcommand('train', train_parser) + subcommands.add_subcommand('finetune', finetune_parser) + subcommands.add_subcommand('evaluate', finetune_parser) + subcommands.add_subcommand('predict', predict_parser) + subcommands.add_subcommand('iterate', iterate_parser) + args = parser.parse_args() - main_cli(args) + if args.subcommand == "iterate": + experiment_ids = iterate_cli(args.iterate) + return experiment_ids + else: + main_cli(args) if __name__ == "__main__": diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index a7507c1..9694fd2 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -1,43 +1,18 @@ from gridfm_graphkit.datasets.powergrid_datamodule import LitGridDataModule from 
gridfm_graphkit.io.param_handler import NestedNamespace -from gridfm_graphkit.training.callbacks import SaveBestModelStateDict +from gridfm_graphkit.training.callbacks import get_training_callbacks +from gridfm_graphkit.iterate import run_iterate_experiments import numpy as np import os import yaml import torch import random import pandas as pd - -from gridfm_graphkit.tasks.feature_reconstruction_task import FeatureReconstructionTask -from lightning.pytorch.callbacks.early_stopping import EarlyStopping -from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint +from gridfm_graphkit.tasks import FeatureReconstructionTask from lightning.pytorch.loggers import MLFlowLogger import lightning as L - -def get_training_callbacks(args): - early_stop_callback = EarlyStopping( - monitor="Validation loss", - min_delta=args.callbacks.tol, - patience=args.callbacks.patience, - verbose=False, - mode="min", - ) - - save_best_model_callback = SaveBestModelStateDict( - monitor="Validation loss", - mode="min", - filename="best_model_state_dict.pt", - ) - - checkpoint_callback = ModelCheckpoint( - monitor="Validation loss", # or whichever metric you track - mode="min", - save_last=True, - save_top_k=0, - ) - - return [early_stop_callback, save_best_model_callback, checkpoint_callback] +from jsonargparse import Namespace def main_cli(args): @@ -47,6 +22,9 @@ def main_cli(args): run_name=args.run_name, ) + subcommand = args.subcommand + args = args[subcommand] + with open(args.config, "r") as f: base_config = yaml.safe_load(f) @@ -62,7 +40,7 @@ def main_cli(args): litGrid.node_normalizers, litGrid.edge_normalizers, ) - if args.command != "train": + if subcommand != "train": print(f"Loading model weights from {args.model_path}") state_dict = torch.load(args.model_path) model.load_state_dict(state_dict) @@ -75,15 +53,15 @@ def main_cli(args): log_every_n_steps=1, default_root_dir=args.log_dir, max_epochs=config_args.training.epochs, - callbacks=get_training_callbacks(config_args), + callbacks=get_training_callbacks(config_args.callbacks), ) - if args.command == "train" or args.command == "finetune": + if subcommand == "train" or subcommand == "finetune": trainer.fit(model=model, datamodule=litGrid) - if args.command != "predict": + if subcommand != "predict": trainer.test(model=model, datamodule=litGrid) - if args.command == "predict": + if subcommand == "predict": predictions = trainer.predict(model=model, datamodule=litGrid) all_outputs = [] all_scenarios = [] @@ -120,3 +98,23 @@ def main_cli(args): df.to_csv(csv_path, index=False) print(f"Saved predictions to {csv_path}") + + +def iterate_cli(config_args): + # validate inputs + if config_args.seed is not None: + assert isinstance(config_args.seed, int), "seed must be an integer" + torch.manual_seed(config_args.seed) + random.seed(config_args.seed) + np.random.seed(config_args.seed) + + return run_iterate_experiments( + args=config_args, # TODO + model_spec=config_args.model, + training_spec=config_args.training, + optimizer_spec=config_args.optimizer, + callbacks_spec=config_args.callbacks, + hpo_spec=config_args.hpo_spec, + tasks=config_args.tasks, + seed=config_args.seed, + ) diff --git a/gridfm_graphkit/iterate/__init__.py b/gridfm_graphkit/iterate/__init__.py new file mode 100644 index 0000000..9599e06 --- /dev/null +++ b/gridfm_graphkit/iterate/__init__.py @@ -0,0 +1,9 @@ +from gridfm_graphkit.iterate.hpo import ( + run_hpo_experiments, + run_repeated_experiments, + run_iterate_experiments, +) +# from 
gridfm_graphkit.tasks.contingency_analysis import ContingencyAnalysisTask + + +__all__ = ("run_hpo_experiments", "run_repeated_experiments", "run_iterate_experiments") diff --git a/gridfm_graphkit/iterate/hpo.py b/gridfm_graphkit/iterate/hpo.py new file mode 100644 index 0000000..ffc6109 --- /dev/null +++ b/gridfm_graphkit/iterate/hpo.py @@ -0,0 +1,722 @@ +import os +from pathlib import Path +import logging +from jsonargparse import Namespace + + +import warnings +import time +from random import randint + +from functools import partial +from typing import Dict + +import mlflow +import optuna +import pandas as pd +import torch +from optuna.pruners import HyperbandPruner +from optuna.samplers import BaseSampler, RandomSampler +from ast import literal_eval + +from gridfm_graphkit.iterate.model_fitting import ( + fit_model, + fit_model_with_hparams, + inject_hparams, +) + +from gridfm_graphkit.utils.types import ( + HyperParameterOptmizerSpec, + TaskSpec, + CallbackSpec, + OptimizerSpec, + ModelSpec, + TrainingSpec, + direction_type_to_optuna, + optimization_space_type, +) + +from gridfm_graphkit.iterate.utils import ( + parse_optimization_space, + check_existing_task_parent_runs, + check_existing_experiments, + unflatten, + get_logger, + sync_mlflow_optuna, +) + + +def run_iterate_experiments( + args, # TODO: remove + model_spec: ModelSpec, + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + hpo_spec: HyperParameterOptmizerSpec, + tasks: list[TaskSpec], + seed: int = 42, +) -> dict[str, str, str]: + """Runs full benchmarking (hpo + repeated) for a model across multiple tasks. + + Args: + args: Experiment configuration. Expected fields include `training.batch_size`, `optimizer.*`, etc. + model_spec: ModelSpec, Model configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + training_spec: TrainingSpec, Trainer configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + optimizer_spec: OptimizerSpec, Optimizer configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + callbacks_spec: CallbackSpec, Callbacks configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + hpo_spec: HyperParameterOptmizerSpec, Model configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + tasks: list[TaskSpec], Model configuration. 
run "gridfm_graphkit iterate -h" to see allowed inputs + seed: int = 42, seed lightning + + Return: + Dict: + hpo_experiment_id: str, + hpo_finished_run_id: str + repeated_experiment_id: str + """ + # create folders and initialize logger + base = Path(args.hpo_spec.results_folder) + HPO_EXP_FOLDER = base / "hpo_mlflow_output" + REPEATED_EXP_FOLDER = base / "repeated_mlflow_output" + REPEATED_CSV_FOLDER = base / "repeated_csv_output" + LOG_FOLDER = base / "logs" + folders = [HPO_EXP_FOLDER, REPEATED_EXP_FOLDER, REPEATED_CSV_FOLDER, LOG_FOLDER] + for f in folders: + os.makedirs(str(f), exist_ok=True) + logger = get_logger(log_folder=str(LOG_FOLDER)) + experiment_ids = {} + + try: + # run hpo on model across multiple tasks + hpo_output = run_hpo_experiments( + args=args, # TODO: remove args + logger=logger, + model_spec=model_spec, + training_spec=training_spec, + optimizer_spec=optimizer_spec, + callbacks_spec=callbacks_spec, + hpo_spec=hpo_spec, + tasks=tasks, + seed=seed, + storage_uri=HPO_EXP_FOLDER, + ) + + + if args.hpo_spec.num_repetitions >= 1: + # run repeated experiments + repeated_output = run_repeated_experiments( + args=args, # TODO: remove args + logger=logger, + model_spec=model_spec, + training_spec=training_spec, + optimizer_spec=optimizer_spec, + callbacks_spec=callbacks_spec, + hpo_spec=hpo_spec, + tasks=tasks, + seed=seed, + repeated_storage_uri=REPEATED_EXP_FOLDER, + hpo_storage_uri=HPO_EXP_FOLDER, + csv_folder=REPEATED_CSV_FOLDER, + hpo_parent_run_id=hpo_output["hpo_finished_run_id"], + ) + hpo_output["repeated_experiment_id"] = repeated_output + return hpo_output + except Exception as e: + logger.info(f"Could not complete due to error {e}") + raise + + +def run_hpo_experiments( + args: Namespace, # TODO: remove + logger: logging.RootLogger, + model_spec: ModelSpec, + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + hpo_spec: HyperParameterOptmizerSpec, + tasks: list[TaskSpec], + seed: int, + storage_uri: Path, +) -> Dict[str, str]: + """Highest level function to run hpo only for a model across multiple tasks. + + Args: + args: Experiment configuration. Expected fields include `training.batch_size`, `optimizer.*`, etc. + logger: logging.RootLogger, Logger for experiment + model_spec: ModelSpec, Model configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + training_spec: TrainingSpec, Trainer configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + optimizer_spec: OptimizerSpec, Optimizer configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + callbacks_spec: CallbackSpec, Callbacks configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + hpo_spec: HyperParameterOptmizerSpec, Model configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + tasks: list[TaskSpec], Model configuration. 
run "gridfm_graphkit iterate -h" to see allowed inputs + seed: int = 42, seed lightning + storage_uri: Path, location to store mlflow output from HPO + + Return: + Dict: + hpo_experiment_id: str, + hpo_finished_run_id: str + """ + # https://mlflow.org/docs/latest/ml/tracking/system-metrics/#using-the-environment-variable-to-control-system-metrics-logging + if os.getenv("MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING") is None: + os.environ["MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING"] = "true" + + run_id: str = hpo_spec.run_id + experiment_name: str = hpo_spec.experiment_name + task_names = [task.name for task in tasks] + run_name = ( + f"top_run_{hpo_spec.experiment_name}" + if hpo_spec.run_name is None + else hpo_spec.run_name + ) + optimization_space = parse_optimization_space(hpo_spec.optimization_space) + completed_task_run_names = [] + optimize_hyperparams = True + task_run_to_id_match = {} + + storage_uri = str(storage_uri) + logger.info(f"Setting tracking URI: {storage_uri}") + mlflow.set_tracking_uri(storage_uri) + logger.info(f"Setting experiment name: {experiment_name}") + mlflow.set_experiment(experiment_name) + experiment_id = mlflow.get_experiment_by_name(experiment_name).experiment_id + + if hpo_spec.continue_existing_experiment: + # find status of existing runs, and delete incomplete runs except one with the most complete tasks + existing_experiments = check_existing_experiments( + logger=logger, + storage_uri=storage_uri, + experiment_name=experiment_name, + exp_parent_run_name=run_name, + task_names=task_names, + n_trials=hpo_spec.n_trials, + ) + if existing_experiments["no_existing_runs"]: + logger.info("\nStarting new experiment from scratch") + else: + if (existing_experiments["incomplete_run_to_finish"] is not None) and ( + run_id is None + ): + logger.info("Continuing previous experiment parent run") + run_id = existing_experiments["incomplete_run_to_finish"] + logger.debug(f"incomplete_run_to_finish: {run_id=}") + experiment_id = existing_experiments["experiment_id"] + optimize_hyperparams = True + + if existing_experiments["finished_run"] is not None: + optimize_hyperparams = False + finished_run_id = existing_experiments["finished_run"] + logger.debug(f"finished_run: {run_id=}") + run_id = existing_experiments["finished_run"] + + # get previously completed tasks + completed_task_run_names, _, task_run_to_id_match = ( + check_existing_task_parent_runs( + logger, run_id, storage_uri, experiment_name, hpo_spec.n_trials + ) + ) + else: + logger.info("Starting new experiment from scratch") + + # only run hyperparameter optimization (HPO) if there are no experiments with finished HPO + if optimize_hyperparams: + if hpo_spec.bayesian_search: + sampler = None # defaults to TPESampler + else: + sampler = RandomSampler() + experiment_id, finished_run_id = _run_hpo( + args=args, # TODO + model_spec=model_spec, + training_spec=training_spec, + optimizer_spec=optimizer_spec, + callbacks_spec=callbacks_spec, + run_name=run_name, + run_id=run_id, + tasks=tasks, + task_names=task_names, + completed_task_run_names=completed_task_run_names, + task_run_to_id_match=task_run_to_id_match, + storage_uri=storage_uri, + experiment_name=experiment_name, + n_trials=hpo_spec.n_trials, + save_models=hpo_spec.save_models, + sampler=sampler, + test_models=hpo_spec.test_models, + optimization_space=optimization_space, + logger=logger, + seed=seed, + ) + logger.info("HPO complete\n\n\nß") + return {"hpo_experiment_id": experiment_id, "hpo_finished_run_id": finished_run_id} + + +def _run_hpo( + args: Namespace, + 
model_spec: ModelSpec, + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + tasks: list, + task_names: list[str], + completed_task_run_names: list[str], + task_run_to_id_match: dict, + storage_uri: str, + experiment_name: str, + optimization_space: optimization_space_type, + n_trials: int, + logger: logging.RootLogger, + sampler: BaseSampler | RandomSampler, + description: str | None = None, + run_name: str | None = None, + run_id: str | None = None, + save_models: bool = False, + test_models: bool = False, + seed: int = 42, +) -> tuple[str, str]: + """Run HPO for multiple tasks under a single experiment. + + Args: + arg: Namespace, contains all parameters to be passed to model and datamodule. To be removed + model_spec: ModelSpec, contains all parameters to intiali model + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + tasks: list, + task_names: list[str], + completed_task_run_names: list[str], + task_run_to_id_match: dict, + storage_uri: str, + experiment_name: str, + optimization_space: optimization_space_type, + n_trials: int, + logger: logging.RootLogger, + sampler: BaseSampler | RandomSampler, + description: str | None = None, + run_name: str | None = None, + run_id: str | None = None, + save_models: bool = False, + test_models: bool = False, + seed: int = 42, + + + """ + logger.info(f"Running hyperparameter optimization: {run_name=} {run_id=}") + storage_uri = str(storage_uri) + + with mlflow.start_run( + run_name=run_name, run_id=run_id, description=description + ) as run: + for task in tasks: + # only run task if it was not completed before + task_run_name = task.name + if task_run_name in completed_task_run_names: + logger.info(f"{task_run_name} already completed") + continue + else: + logger.info(f"{task_run_name} not completed. starting now") + + task_run_id = ( + task_run_to_id_match[task_run_name] + if task_run_name in task_run_to_id_match + else None + ) + best_value, metric_name, hparams = _run_hpo_per_task( + args=args, # TODO + model_spec=model_spec, + training_spec=training_spec, + optimizer_spec=optimizer_spec, + callbacks_spec=callbacks_spec, + logger=logger, + task=task, + storage_uri=storage_uri, + experiment_name=experiment_name, + experiment_run_id=run.info.run_id, + task_run_id=task_run_id, + optimization_space=optimization_space, + n_trials=n_trials, + save_models=save_models, + sampler=sampler, + test_models=test_models, + seed=seed, + ) + experiment_id = run.info.experiment_id + + # check completion of HPO for all tasks before proceeding to next stage + existing_experiments = check_existing_experiments( + logger=logger, + storage_uri=storage_uri, + experiment_name=experiment_name, + exp_parent_run_name=run_name, + task_names=task_names, + n_trials=n_trials, + ) + if existing_experiments["finished_run"] is not None: + finished_run_id = existing_experiments["finished_run"] + else: + logger.info("HPO is not complete. 
Please re-run this experiment") + raise RuntimeError + + return experiment_id, finished_run_id + + +def _run_hpo_per_task( + args: Namespace, # TODO: remove args + model_spec: ModelSpec, + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + logger: logging.RootLogger, + task: TaskSpec, + storage_uri: str, + experiment_name: str, + experiment_run_id: str, + task_run_id: str | None = None, + optimization_space: optimization_space_type | None = None, + n_trials: int = 1, + save_models: bool = False, + sampler: BaseSampler | None = None, + test_models: bool = False, + seed: int = 42, +) -> tuple[str, float, dict]: + """ + Performs HPO on a single task + + Args: + args: Namespace, #TODO: remove args + model_spec: ModelSpec, + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + logger: logging.RootLogger, + task: TaskSpec, + storage_uri: str, + experiment_name: str, + experiment_run_id: str, + task_run_id: str | None = None, + optimization_space: optimization_space_type | None = None, + n_trials: int = 1, + save_models: bool = False, + sampler: BaseSampler | None = None, + test_models: bool = False, + seed: int = 42, + + """ + logger.info( + f"starting backbone benchmark on task {task.name} {task_run_id=} {experiment_name=}" + ) + if storage_uri.startswith("http"): + optuna_db_path = Path(".") / "optuna_db" + else: + optuna_db_path = Path(storage_uri).parents[0] / "optuna_db" + + if not os.path.exists(optuna_db_path): + os.makedirs(optuna_db_path) + optuna_db_path = optuna_db_path / f"{experiment_name}_{experiment_run_id}" + optuna_db_path = str(optuna_db_path) + + task_run_id = sync_mlflow_optuna( + optuna_db_path=optuna_db_path, + storage_uri=storage_uri, + experiment_name=experiment_name, + task_run_id=task_run_id, + task=task, + n_trials=n_trials, + logger=logger, + ) + if task_run_id is not None: + # run_name is used only when run_id is unspecified. 
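+ # When no existing run id is found, a new nested task run is created below and named after the task.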
+ run_name = None + else: + run_name = task.name + logger.info(f"start run: {run_name=} {task_run_id=}") + with mlflow.start_run(run_name=run_name, nested=True, run_id=task_run_id) as run: + logger.info(f"starting task run with id: {run.info.run_id}") + if training_spec.epochs is None: + raise Exception("Must specify epochs for training") + + # if no optimization params, just run it + if optimization_space is None: + return ( + *fit_model( + args=args, + model_spec=model_spec, + training_spec=training_spec, + optimizer_spec=optimizer_spec, + callbacks_spec=callbacks_spec, + task=task, + run_name=f"{run_name}_no_optim", + experiment_name=experiment_name, + storage_uri=storage_uri, + parent_run_id=run.info.run_id, + save_models=save_models, + test_models=test_models, + seed=seed, + logger=logger, + ), + ) + + # if optimization parameters specified, do hyperparameter tuning + study = optuna.create_study( + sampler=sampler, + direction=direction_type_to_optuna[ + task.direction + ], # in the future may want to allow user to specify this + pruner=HyperbandPruner(), + study_name=task.name, + storage="sqlite:///{}.db".format(optuna_db_path), + load_if_exists=True, + ) + + objective = partial( + fit_model_with_hparams, + args, + model_spec, + training_spec, + optimizer_spec, + callbacks_spec, + task, + optimization_space, + run_name, + experiment_name, + storage_uri, + run.info.run_id, + logger, + save_models, + test_models, + seed, + ) + + n_trials = n_trials - len(study.trials) + for trial in study.trials: + if (trial.state == optuna.trial.TrialState.FAIL) | ( + trial.state == optuna.trial.TrialState.RUNNING + ): + n_trials = n_trials + 1 + + study.optimize( + objective, + n_trials=n_trials, + catch=[torch.cuda.OutOfMemoryError], + ) + + tags = { + "seed": str(seed), + "n_trials": str(n_trials), + "model_spec": vars(model_spec), + "training_spec": vars(training_spec), + "optimizer_spec": vars(optimizer_spec), + "callbacks_spec": vars(callbacks_spec), + "data": vars(task.data), + } + mlflow.set_tags(tags) + + best_params = unflatten(study.best_trial.params) + mlflow.log_params(best_params) # unflatten + mlflow.log_metric(f"best_{task.metric}", study.best_value) + return study.best_value, task.metric, best_params + + +def run_repeated_experiments( + args: Namespace, # TODO: remove args + logger: logging.RootLogger, + model_spec: ModelSpec, + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + hpo_spec: HyperParameterOptmizerSpec, + tasks: list[TaskSpec], + seed: int, + repeated_storage_uri: Path, + hpo_storage_uri: Path, + csv_folder: Path, + hpo_parent_run_id: str, +)-> str: + """Repeat best experiments from a benchmark run across multiple tasks. + + Args: + args: Experiment configuration. Expected fields include `training.batch_size`, `optimizer.*`, etc. + logger: logging.RootLogger, Logger for experiment + model_spec: ModelSpec, Model configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + training_spec: TrainingSpec, Trainer configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + optimizer_spec: OptimizerSpec, Optimizer configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + callbacks_spec: CallbackSpec, Callbacks configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + hpo_spec: HyperParameterOptmizerSpec, Model configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + tasks: list[TaskSpec], Model configuration. 
run "gridfm_graphkit iterate -h" to see allowed inputs + seed: int = 42, seed lightning + repeated_storage_uri: Path, location to store mlflow output from repeated experiments + hpo_storage_uri: Path, location where mlflow output from completed HPO experiments is stored + csv_folder: Path, Location where csv files with repeated experimnets are stored + hpo_parent_run_id: str, run id of successful HPO experiment + + Return: + str: experiment_id of completed Repeated experiment + """ + + logger.info("Starting repeated experiments") + + experiment_name = hpo_spec.experiment_name + num_repetitions = hpo_spec.num_repetitions + # find completed HPO tasks + mlflow.set_tracking_uri(str(hpo_storage_uri)) + mlflow.set_experiment(experiment_name) + + runs: list[mlflow.entities.Run] = mlflow.search_runs( + filter_string=f"tags.mlflow.parentRunId='{hpo_parent_run_id}'", output_format="list" + ) # type: ignore + logger.info(f"hpo_parent_run_id {hpo_parent_run_id}") + logger.info(f"Found runs: {[run.info.run_name for run in runs]}") + + task_names = [task.name for task in tasks] + logger.info(f"Will only run the following: {task_names}") + + table_columns = [ + "Task", + "Metric", + "Score", + "mlflow_run_name", + "mlflow_run_id", + "mlflow_run_status", + ] + + mlflow.set_tracking_uri(repeated_storage_uri) + mlflow.set_experiment(experiment_name) + output_path = csv_folder / f"{experiment_name}_repeated_exp_mlflow.csv" + if not os.path.isabs(output_path): + raise Exception(f"output_path must be absolute. got: {output_path}") + + # backbone_name = defaults.terratorch_task["model_args"]["backbone"] + with mlflow.start_run(run_name=experiment_name, run_id=None) as run: + for task in tasks: + logger.info(f"\n\ntask: {task.name}") + matching_runs = [ + run for run in runs if run.info.run_name.endswith(task.name) + ] # type: ignore + if len(matching_runs) == 0: + msg = f"No runs found for task {task.name}. Skipping." 
+ warnings.warn(msg) + continue + if len(matching_runs) > 1: + msg = f"More than 1 run found for task {task.name}" + raise Exception(msg) + + # check if there are already results for this task and exp in the folder + if os.path.exists(output_path): + logger.info("there are previous results from repeated experiments") + existing_output = pd.read_csv(output_path, index_col=False) + existing_output = existing_output[table_columns] + existing_task_output = existing_output.loc[ + existing_output["Task"] == task.name + ].copy() + rows, cols = existing_task_output.shape + logger.info(f"rows: {rows} \t cols: {cols}") + if rows > num_repetitions: + logger.info("task has complete results, will not re-run") + continue + past_seeds = [ + int(item.split("_")[-1]) + for item in existing_task_output["mlflow_run_name"].tolist() + ] + else: + past_seeds = [] + logger.info(f"past_seeds for task: {past_seeds}") + + # get best parameters + best_params = matching_runs[0].data.params + best_params = {k: literal_eval(v) for k, v in best_params.items()} + + training_spec_with_best_hparams, optimizer_spec_with_best_hparams = ( + inject_hparams(training_spec, optimizer_spec, best_params) + ) + + experiment_info = mlflow.get_experiment_by_name(experiment_name) + seeds = [randint(1, 5000) for i in range(num_repetitions * 5)] + seeds = [seed for seed in seeds if seed not in past_seeds] + + for seed in seeds: + if len(past_seeds) >= num_repetitions: + break + + seed_run_name = f"{task.name}_{seed}" + logger.info(f"now trying: {seed_run_name}") + seed_run_data = mlflow.search_runs( + experiment_ids=[experiment_info.experiment_id], + filter_string=f'tags."mlflow.runName" LIKE "{seed_run_name}"', + output_format="list", + ) # type: ignore + if len(seed_run_data) > 0: + for item in seed_run_data: + logger.info(f"deleting existing run: {item}") + mlflow.delete_run(item.info.run_id) + + score = fit_model( + args=args, + model_spec=model_spec, + training_spec=training_spec, + optimizer_spec=optimizer_spec, + callbacks_spec=callbacks_spec, + task=task, + run_name=seed_run_name, + experiment_name=experiment_name, + storage_uri=repeated_storage_uri, + parent_run_id=run.info.run_id, + save_models=hpo_spec.save_models, + test_models=True, + seed=seed, + logger=logger, + # repeat_on_best=hpo_spec.repeat_on_best, + ) + + # check if run with name finished successfully + logger.info(f"score: {score}") + # TODO improve this sleep command - try to get a better estimate than this + time.sleep(60) + seed_run_data = mlflow.search_runs( + experiment_ids=[experiment_info.experiment_id], + filter_string=f'tags."mlflow.runName" LIKE "{seed_run_name}"', + output_format="list", + ) # type: ignore + + logger.info(f"run for task {task.name} seed {seed} complete") + if len(seed_run_data) > 0: + if seed_run_data[0].info.status != "FINISHED": + mlflow.delete_run(seed_run_data[0].info.run_id) + continue + past_seeds.append(seed) + new_data = pd.DataFrame( + { + "Task": [task.name], + "Metric": [task.metric.split("/")[-1]], + "Score": [score], + "mlflow_run_name": [seed_run_name], + "mlflow_run_id": [seed_run_data[0].info.run_id], + "mlflow_run_status": [seed_run_data[0].info.status], + } + ) + logger.info( + f"completed seeds so far for this task: {len(past_seeds)}" + ) + if os.path.exists(output_path): + logger.info( + "there are previous results from repeated experiments" + ) + + existing_output = pd.read_csv(output_path, index_col=False) + existing_output = existing_output[table_columns] + existing_output.reset_index(inplace=True) + 
existing_task_output = existing_output.loc[ + existing_output["Task"] == task.name + ].copy() + rows, cols = existing_task_output.shape + logger.info(f"rows: {rows} \t cols: {cols}") + if rows == 0: + logger.info("no past results for this task") + existing_output = pd.concat([existing_output, new_data], axis=0) + existing_output.reset_index(inplace=True) + existing_output = existing_output.drop( + columns=["index", "level_0"] + ) + existing_output.to_csv(output_path, index=False) + else: + new_data.to_csv(output_path, index=False) + logger.info("Repeated experiments complete \n\n\n") + return experiment_info.experiment_id diff --git a/gridfm_graphkit/iterate/model_fitting.py b/gridfm_graphkit/iterate/model_fitting.py new file mode 100644 index 0000000..68cbbeb --- /dev/null +++ b/gridfm_graphkit/iterate/model_fitting.py @@ -0,0 +1,392 @@ +""" +This module contains all the logic for fitting models +""" + +import abc +import copy +import os +import shutil +import torch +import warnings +from abc import abstractmethod +import lightning.pytorch as pl +import mlflow +import optuna +from lightning import Callback, Trainer +from lightning.pytorch.callbacks import ( + ModelCheckpoint, +) +from jsonargparse import Namespace + +import logging +from lightning.pytorch.loggers.mlflow import MLFlowLogger + +from optuna.integration import PyTorchLightningPruningCallback + + +from gridfm_graphkit.training.callbacks import get_training_callbacks +from gridfm_graphkit.datasets.powergrid_datamodule import LitGridDataModule + +from gridfm_graphkit.utils.types import ( + OptimizerSpec, + ModelSpec, + TrainingSpec, + TaskSpec, + CallbackSpec, + valid_task_types, + ParameterBounds, + ParameterTypeEnum, + optimization_space_type, + recursive_merge, +) + +from gridfm_graphkit.tasks import ( + FeatureReconstructionTask, +) + + +from gridfm_graphkit.iterate.utils import get_best_validation_metric, get_test_metric + + +os.environ["TUNE_DISABLE_AUTO_CALLBACK_LOGGERS"] = ( + "1" # disable tune loggers, will add csv and json manually. If this is not here, it will log to tensorboard automatically +) + + +class ParameterPicker(abc.ABC): + @abstractmethod + def pick_categorical(self, variable, choices): + pass + + @abstractmethod + def pick_int(self, variable, low, high): + pass + + @abstractmethod + def pick_float(self, variable, low, high, log=False): + pass + + +class OptunaParameterPicker(ParameterPicker): + def __init__(self, trial: optuna.Trial): + super().__init__() + self.trial = trial + + def pick_categorical(self, variable, choices): + return self.trial.suggest_categorical(variable, choices) + + def pick_int(self, variable, low, high): + return self.trial.suggest_int(variable, low, high) + + def pick_float(self, variable, low, high, log=False): + return self.trial.suggest_float(variable, low, high, log=log) + + +def inject_hparams( + training_spec: TrainingSpec, optimizer_spec: OptimizerSpec, config: dict +): + assert isinstance(config, dict), f"Error! 
Unexpected config type: {config}" + training_spec_with_hparams = copy.deepcopy(training_spec) + optimizer_spec_with_hparams = copy.deepcopy(optimizer_spec) + + recursive_merge(training_spec_with_hparams, config) + recursive_merge(optimizer_spec_with_hparams, config) + + return training_spec_with_hparams, optimizer_spec_with_hparams + + +def generate_parameters( + parameter_picker: ParameterPicker, + current_hparams: dict, + hparam_space: dict, + ignore_keys: set[str] | None = None, + dictionary_position: list[str] | None = None, +): + if ignore_keys is None: + ignore_keys = set() + if dictionary_position is None: + dictionary_position = [] + _generate_parameters( + parameter_picker, + current_hparams, + hparam_space, + ignore_keys, + dictionary_position, + ) + + +def _generate_parameters( + parameter_picker: ParameterPicker, + current_hparams: dict, + hparam_space: dict, + ignore_keys: set[str], + dictionary_position: list[str], +): + for parameter, space in hparam_space.items(): + if parameter in ignore_keys: + continue + # if its a dictionary, continue to recurse + if isinstance(space, dict): + if parameter not in current_hparams: + current_hparams[parameter] = {} + dictionary_position.append(parameter) + _generate_parameters( + parameter_picker, + current_hparams[parameter], + hparam_space[parameter], + ignore_keys, + dictionary_position, + ) + dictionary_position.pop() + # if not, get a value from the parameter_picker and insert it with the name prepended by the dictionary position + # this is important so that the full path of the parameter is used + # this will avoid confusion between parameters with the same name but from different components + else: + full_parameter_name = ".".join(dictionary_position + [parameter]) + if isinstance(space, list): + suggestion = parameter_picker.pick_categorical( + full_parameter_name, space + ) + current_hparams[parameter] = suggestion + elif isinstance(space, ParameterBounds): + match space.type: + case ParameterTypeEnum.integer: + current_hparams[parameter] = parameter_picker.pick_int( + full_parameter_name, + int(space.min), + int(space.max), + ) + case ParameterTypeEnum.real: + current_hparams[parameter] = parameter_picker.pick_float( + full_parameter_name, space.min, space.max, log=space.log + ) + case _: + raise Exception( + f"Type {space.type} not recognized. Suggest one of {[e.value for e in ParameterTypeEnum]}" + ) + else: + raise Exception( + "Leaves of optimization space must be lists or ParameterBounds" + ) + + +def launch_training( + trainer: Trainer, + model: FeatureReconstructionTask, # TODO: create basetask in tasks folder + optimizer_spec: OptimizerSpec, + datamodule: LitGridDataModule, + run_name: str, + experiment_name: str, + metric: str, + storage_uri: str, + parent_run_id: str, + direction: str, + test_models: bool, + delete_models_after_testing: bool, +) -> float: + with mlflow.start_run(run_name=run_name, nested=True) as run: + mlflow.set_tag("mlflow.parentRunId", parent_run_id) + # explicitly log batch_size. 
Since it is not a model param, it will not be logged + mlflow.log_param("batch_size", datamodule.batch_size) + + trainer.logger = MLFlowLogger( + experiment_name=experiment_name, + run_id=run.info.run_id, + save_dir=storage_uri, + log_model=not delete_models_after_testing, + ) + trainer.fit(model, datamodule=datamodule) + + if test_models: + trainer.test( + model=model, + # ckpt_path="best", + datamodule=datamodule + ) + metric = metric.replace("Validation", "Test") + output = get_test_metric( + storage_uri=str(storage_uri), + run=run, + metric=metric, + direction=direction, + ) + else: + output = get_best_validation_metric( + storage_uri=str(storage_uri), + run=run, + metric=metric, + direction=direction, + ) + + if delete_models_after_testing: + # delete the checkpoints folder in the run + ckpts_folder = os.path.join( + trainer.logger.save_dir, + str(trainer.logger.name), + trainer.logger.version, + "checkpoints", + ) + shutil.rmtree(ckpts_folder) + + return output + + +def fit_model( + args: Namespace, + model_spec: ModelSpec, + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + task: TaskSpec, + run_name: str, + experiment_name: str, + storage_uri: str, + parent_run_id: str, + logger: logging.RootLogger, + save_models: bool = False, + test_models: bool = False, + seed: int = 42, + trial: optuna.Trial | None = None, +) -> float: + pl.seed_everything(seed, workers=True) + training_spec_copy = copy.deepcopy(training_spec) + + # get callbacks + callbacks: list[Callback] = get_training_callbacks(callbacks_spec) + if callbacks_spec.optuna_early_prune and trial is not None: + callbacks.append( + PyTorchLightningPruningCallback(trial, monitor="Validation loss") + ) + if len(callbacks) > 0: + warnings.warn( + "Callbacks passed to trainer. Make sure these are stateless, as they will not be reinitialized for each task!" 
+ ) + + delete_models_after_testing = False + if test_models and not save_models: + # we need to save the models during training to be able to test but can be deleted afterwards + save_models = True + delete_models_after_testing = True + + if save_models: + callbacks = [ + cb + for cb in callbacks + if not (isinstance(cb, ModelCheckpoint) and cb.monitor==task.metric) + ] + callbacks.append( + ModelCheckpoint( + monitor=task.metric, + mode=task.direction, + save_top_k=1, + ) + ) + enable_checkpointing = True + else: + callbacks = [cb for cb in callbacks if not isinstance(cb, ModelCheckpoint)] + enable_checkpointing = False + + + # # initialize datamodule + args.data = task.data + datamodule = LitGridDataModule(args, task.data.data_path) + + # initialize model + lightning_task_class: valid_task_types = task.type.get_class_from_enum() + model = lightning_task_class( + args, # TODO: load model, training, optim separataly + datamodule.node_normalizers, + datamodule.edge_normalizers, + ) + logger.info(f"Loading model weights from {model_spec.model_path}") + state_dict = torch.load(model_spec.model_path) + model.load_state_dict(state_dict) + + # initialize trainer + trainer = Trainer( + accelerator=training_spec_copy.accelerator, + devices=training_spec_copy.devices, + strategy=training_spec_copy.strategy, + log_every_n_steps=training_spec_copy.log_every_n_steps, + # default_root_dir=args.log_dir, + max_epochs=training_spec_copy.epochs, + callbacks=callbacks, + enable_checkpointing=enable_checkpointing, + enable_progress_bar=training_spec_copy.enable_progress_bar, + # deterministic=True, + ) + + logger.info( + f"launch_training {trainer=} {lightning_task_class=} {datamodule=} \ + {run_name=} {experiment_name=} {task.metric=} {storage_uri=} {task.direction=}" + ) + return launch_training( + trainer=trainer, + model=model, + optimizer_spec=optimizer_spec, + datamodule=datamodule, + run_name=run_name, + experiment_name=experiment_name, + metric=task.metric, + storage_uri=storage_uri, + parent_run_id=parent_run_id, + direction=task.direction, + test_models=test_models, + delete_models_after_testing=delete_models_after_testing, + ) + + +def fit_model_with_hparams( + args: Namespace, + model_spec: ModelSpec, + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + task: TaskSpec, + optimization_space: optimization_space_type, + run_name: str, + experiment_name: str, + storage_uri: str, + parent_run_id: str, + logger: logging.RootLogger, + save_models: bool = False, + test_models: bool = False, + seed: int = 42, + trial: optuna.Trial | None = None, +) -> float: + """ + Generate parameters using the optuna trial from the given parameters. + Then inject these into the given task. + It is important to make sure to not overwrite the task passed in the arguments, or these updates may affect + subsequent trials. 
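+ Returns the metric value obtained by fit_model, which Optuna uses as the trial objective.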
+ """ + current_hparams: dict[str, int | float | str | bool] = {} + generate_parameters( + OptunaParameterPicker(trial), + current_hparams, + optimization_space, + ) + + training_spec_with_hparams, optimizer_spec_with_hparams = inject_hparams( + training_spec, optimizer_spec, current_hparams + ) + + output = fit_model( + args=args, + model_spec=model_spec, + training_spec=training_spec, + optimizer_spec=optimizer_spec, + callbacks_spec=callbacks_spec, + task=task, + run_name=f"{run_name}_{trial.number}", + experiment_name=experiment_name, + storage_uri=storage_uri, + parent_run_id=parent_run_id, + save_models=save_models, + test_models=test_models, + seed=seed, + logger=logger, + trial=trial, + ) # return only the metric value for optuna + + return output diff --git a/gridfm_graphkit/iterate/utils.py b/gridfm_graphkit/iterate/utils.py new file mode 100644 index 0000000..3ea0df8 --- /dev/null +++ b/gridfm_graphkit/iterate/utils.py @@ -0,0 +1,524 @@ +import os +from typing import Any, Dict + +import mlflow +import optuna +import logging +import datetime +import pandas as pd +from mlflow.entities.experiment import Experiment +from gridfm_graphkit.utils.types import ( + optimization_space_type, + TaskSpec, + ParameterBounds, +) + + +# Custom function to parse the optimization space argument +def parse_optimization_space(space: dict | None) -> optimization_space_type | None: + if space is None: + return None + parsed_space: optimization_space_type = {} + for key, value in space.items(): + if isinstance(value, dict): + try: + bounds = ParameterBounds(**value) + parsed_space[key] = bounds + except TypeError: + # Recursively parse nested optimization spaces + parsed_space[key] = parse_optimization_space(value) + elif isinstance(value, list): + # If it's a list, leave it as is + parsed_space[key] = value + else: + raise ValueError(f"Invalid type for {key}: {value}") + return parsed_space + + +def unflatten(dictionary: Dict[str, Any]): + resultDict: Dict = {} + for key, value in dictionary.items(): + parts = key.split(".") + d = resultDict + for part in parts[:-1]: + if part not in d: + d[part] = {} + d = d[part] + d[parts[-1]] = value + return resultDict + + +def get_logger(log_level="INFO", log_folder="./experiment_logs") -> logging.RootLogger: + # set up logging file + if not os.path.exists(log_folder): + os.makedirs(log_folder) + current_time = datetime.datetime.now() + current_time = ( + str(current_time).replace(" ", "_").replace(":", "-").replace(".", "-") + ) + log_file = f"{log_folder}/{current_time}" + logger = logging.getLogger() + logger.setLevel(log_level) + handler = logging.FileHandler(log_file) + handler.setLevel(log_level) + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + handler.setFormatter(formatter) + logger.addHandler(handler) + logging.basicConfig(level=logging.CRITICAL) + return logger + + +def get_best_validation_metric( + storage_uri: str, + run: mlflow.entities.Run, + metric: str, + direction: str, +): + client = mlflow.tracking.MlflowClient( + tracking_uri=storage_uri, + ) + + if not metric.lower().startswith("val"): + raise Exception( + f"Metric {metric} does not start with `val`. 
Please choose a validation metric" + ) + for_pd_collect = [] + val_metrics_names = [] + + for metric_name in client.get_run(run.info.run_id).data.metrics: + if metric_name.lower().startswith("val"): + val_metrics_names.append(metric_name) + val_metric_history = client.get_metric_history(run.info.run_id, metric_name) + pd_convertible_metric_history = [ + { + "metric_name": mm.key, + "step": mm.step, + "value": mm.value, + } + for mm in val_metric_history + ] + for_pd_collect += pd_convertible_metric_history + df_val_metrics = pd.DataFrame.from_records(for_pd_collect) + df_val_metrics = df_val_metrics.set_index( + ["metric_name", "step"], verify_integrity=True + ) + series_val_metrics = df_val_metrics["value"] + assert metric in series_val_metrics, ( + f"Error! {metric} is not in {series_val_metrics}" + ) + if direction == "max": + best_step = series_val_metrics[metric].idxmax() + elif direction == "min": + best_step = series_val_metrics[metric].idxmin() + else: + raise Exception( + f"Error! Direction must be either `max` or `min` but got {direction}" + ) + + for val_metric_name in val_metrics_names: + mlflow.log_metric( + f"best_step_{val_metric_name}", + series_val_metrics[(val_metric_name, best_step)], + ) + return series_val_metrics[(metric, best_step)] + + + + +def get_test_metric( + storage_uri: str, + run: mlflow.entities.Run, + metric: str, + direction: str, +): + client = mlflow.tracking.MlflowClient( + tracking_uri=storage_uri, + ) + + if not metric.lower().startswith("test"): + raise Exception( + f"Metric {metric} does not start with `test`. Please choose a test metric" + ) + for_pd_collect = [] + test_metrics_names = [] + + print(f'{client.get_run(run.info.run_id)=}') + metric_value_across_datasets = [] + + for metric_log_name in client.get_run(run.info.run_id).data.metrics: + if "Test - " not in metric_log_name: + continue + dataset_name, metric_name = metric_log_name.replace("Test - ", "").split("/", 1) + if metric == metric_name: + test_metrics_names.append(metric_name) + test_metric_history = client.get_metric_history(run.info.run_id, metric_log_name) + pd_convertible_metric_history = [ + { + "metric_name": mm.key, + "step": mm.step, + "value": mm.value, + } + for mm in test_metric_history + ] + for_pd_collect += pd_convertible_metric_history + df_test_metrics = pd.DataFrame.from_records(for_pd_collect) + metric_value = df_test_metrics.loc[:, 'value'].mean() + return metric_value + + +def check_existing_experiments( + logger: logging.RootLogger, + storage_uri: str, + experiment_name: str, + exp_parent_run_name: str, + task_names: list, + n_trials: int, +) -> Dict[str, Any]: + """ + checks if experiment has been completed (i.e. 
both task run and nested individual runs are complete) + Args: + logger: logging.RootLogger to save logs to file + storage_uri: folder containing mlflow log data + experiment_name: name of experiment + exp_parent_run_name: run name of the top level experiment run + task_names: list of task names that should be completed + n_trials: number of trials (runs) expected in HPO of each task + Returns: + output: dict with: + no_existing_runs: bool, if True, there are no existing runs + incomplete_run_to_finish: str | None, run id of the experiment run to finish + finished_run: str | None, run id of the finished experiment run + experiment_id: str | None, experiment id it experiment already exists + + """ + client = mlflow.tracking.MlflowClient(tracking_uri=storage_uri) + experiment_info = client.get_experiment_by_name(experiment_name) + + output = { + "no_existing_runs": True, + "incomplete_run_to_finish": None, + "finished_run": None, + "experiment_id": None, + } + if experiment_info is None: + return output + + experiment_id = experiment_info.experiment_id + logger.info(f"Checking existing experiment") + logger.info(f"experiment_id: {experiment_id}") + logger.info(f"experiment_name: {experiment_name}") + output["experiment_id"] = experiment_id + experiment_parent_run_data = client.search_runs( + experiment_ids=[experiment_id], + filter_string=f'tags."mlflow.runName" LIKE "{exp_parent_run_name}"', + ) + if len(experiment_parent_run_data) >= 1: + logger.info("there is at least one experiment parent run") + finished_run_id = None + incomplete_runs = [] + + # check if one of the runs is complete + for run in experiment_parent_run_data: + ( + completed_task_run_names, + all_tasks_in_experiment_finished, + _, + ) = check_existing_task_parent_runs( + logger=logger, + exp_parent_run_id=run.info.run_id, + storage_uri=storage_uri, + experiment_name=experiment_name, + n_trials=n_trials, + ) + logger.info(f"tasks that should be completed: {task_names}") + logger.info(f"completed_task_run_names: {completed_task_run_names}") + logger.info( + f"all_tasks_in_experiment_finished: {all_tasks_in_experiment_finished}" + ) + all_expected_tasks_completed = [ + item for item in task_names if item in completed_task_run_names + ] + all_expected_tasks_completed = len(task_names) == len( + all_expected_tasks_completed + ) + if all_expected_tasks_completed: + finished_run_id = run.info.run_id + logger.info( + f"The following run FINISHED and will be used for repeated experiments: {finished_run_id}" + ) + else: + incomplete_tasks = [ + item for item in task_names if item not in completed_task_run_names + ] + logger.info( + f"The following run {run.info.run_id} is incomplete, with status {run.info.status} and missing tasks: {incomplete_tasks}" + ) + incomplete_runs.append(run.info.run_id) + + if finished_run_id is not None: + # delete all incomplete runs + delete_nested_experiment_parent_runs( + logger=logger, + delete_runs=incomplete_runs, + experiment_info=experiment_info, + client=client, + leave_one=False, + ) + output["finished_run"] = finished_run_id + output["no_existing_runs"] = False + else: + # delete all incomplete runs, leave one + logger.info(f"incomplete_runs: {incomplete_runs}") + output["incomplete_run_to_finish"] = delete_nested_experiment_parent_runs( + logger=logger, + delete_runs=incomplete_runs, + experiment_info=experiment_info, + client=client, + leave_one=True, + ) + output["no_existing_runs"] = False + return output + + +def delete_nested_experiment_parent_runs( + logger: logging.RootLogger, + 
delete_runs: list, + experiment_info: mlflow.entities.experiment.Experiment, + client: mlflow.tracking.client.MlflowClient, + leave_one: bool = True, +) -> str | None: + """ + if there are multiple runs for a single experiment, + will delete all runs except the one with the most nested runs (most complete) + Args: + logger: logging.RootLogger to save logs to file + delete_runs: list of runs to delete + experiment_info: info of experiment + client: mlflow client pointing to correct storage uri + leave_one: if True, will not delete the most complete experiment. If False, will delete all experiments + Returns: + run id of the experiment run that was not deleted or None + """ + experiment_id = experiment_info.experiment_id + exp_parent_run_ids = [] + counts = [] + runs_in_experiment = [] + logger.info(f"Deleting from experiment_id:{experiment_id} ") + logger.info(f"delete_runs:{delete_runs} ") + + for exp_parent_run_id in delete_runs: + runs = [] + runs.append(exp_parent_run_id) + task_parent_run_data = client.search_runs( + experiment_ids=[experiment_id], + filter_string=f'tags."mlflow.parentRunId" LIKE "{exp_parent_run_id}"', + ) + for task_parent_run in task_parent_run_data: + task_parent_run_id = task_parent_run.info.run_id + runs.append(task_parent_run_id) + individual_run_data = client.search_runs( + experiment_ids=[experiment_id], + filter_string=f'tags."mlflow.parentRunId" LIKE "{task_parent_run_id}"', + ) + for individual_run in individual_run_data: + runs.append(individual_run.info.run_id) + exp_parent_run_ids.append(exp_parent_run_id) + counts.append(len(runs)) + runs_in_experiment.append(runs) + + if leave_one and (len(counts) > 0): + index_to_keep = counts.index(max(counts)) + incomplete_run_to_finish = exp_parent_run_ids[index_to_keep] + runs_in_experiment.pop(index_to_keep) + else: + incomplete_run_to_finish = None + + logger.info(f"Deleting runs:{runs_in_experiment} ") + logger.info( + f"experiment_info.artifact_location:{experiment_info.artifact_location}" + ) + for runs in runs_in_experiment: + for run_id in runs: + client.delete_run(run_id) + os.system(f"rm -r {experiment_info.artifact_location}/{run_id}") + return incomplete_run_to_finish + + +def check_existing_task_parent_runs( + logger: logging.RootLogger, + exp_parent_run_id: str, + storage_uri: str, + experiment_name: str, + n_trials: int = 5, +): + """ + checks if tasks have been completed (both task run and nested individual runs are complete) + Args: + logger: logging.RootLogger to save logs to file + exp_parent_run_id: run id of the experiment run being used (top level run id) + storage_uri: folder containing mlflow log data + experiment_name: name of experiment + n_trials: number of trials (runs) expected in HPO of each task + Returns: + complete_task_run_names: list of task names that have been completed + all_tasks_finished: bool showing if all tasks have been completed + task_run_to_id_match: dict matching task names to the task run id + + """ + client = mlflow.tracking.MlflowClient(tracking_uri=storage_uri) + experiment_info = client.get_experiment_by_name(experiment_name) + experiment_id = experiment_info.experiment_id + task_parent_run_data = client.search_runs( + experiment_ids=[experiment_id], + filter_string=f'tags."mlflow.parentRunId" LIKE "{exp_parent_run_id}"', + ) + complete_task_run_names = [] + all_tasks_finished = [] + # TO DO: make sure we only have one task_parent_run for each name (needed for repeated exps) + task_run_to_id_match = {} + for task_parent_run in task_parent_run_data: + 
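+        # Collect the status of this task parent run and of its nested trial runs;
+        # RUNNING and FAILED trial runs are skipped below, and the task only counts
+        # as complete if every remaining status is FINISHED.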
task_run_statuses = [] + task_run_ids = [] + task_run_statuses.append(task_parent_run.info.status) + task_run_ids.append(task_parent_run.info.run_id) + + individual_run_data = client.search_runs( + experiment_ids=[experiment_id], + filter_string=f'tags."mlflow.parentRunId" LIKE "{task_parent_run.info.run_id}"', + ) + for individual_run in individual_run_data: + if (individual_run.info.status == "RUNNING") or ( + individual_run.info.status == "FAILED" + ): + continue + task_run_statuses.append(individual_run.info.status) + task_run_ids.append(individual_run.info.run_id) + + task_run_to_id_match[task_parent_run.info.run_name] = ( + task_parent_run.info.run_id + ) + task_run_statuses = list(set(task_run_statuses)) + + condition_1 = len(task_run_statuses) == 1 + condition_2 = task_run_statuses[0] == "FINISHED" + # condition_3 = len(task_run_ids) == (n_trials+1) + if condition_1 and condition_2: # and condition_3: + complete_task_run_names.append(task_parent_run.info.run_name) + task_parent_status = True + else: + task_parent_status = False + all_tasks_finished.append(task_parent_status) + + if all(all_tasks_finished) and (len(all_tasks_finished) > 0): + all_tasks_finished = True + else: + all_tasks_finished = False + complete_task_run_names = list(set(complete_task_run_names)) + return complete_task_run_names, all_tasks_finished, task_run_to_id_match + + +def sync_mlflow_optuna( + optuna_db_path: str, + storage_uri: str, + experiment_name: str, + task_run_id: str | None, + task: TaskSpec, + n_trials: int, + logger: logging.RootLogger, +) -> str | None: + """ + syncs the number of completed trials in mflow and optuna + Args: + optuna_db_path: path to optuna database + storage_uri: path to mlflow storage folder + experiment_name: name on experiment in mlflow + task_run_id: run_id of the task + task: name of the task + logger: logging.RootLogger to save logs to file + Returns: + task_run_id: run id of the task to be continued (if one exists) or None + """ + logger.info( + f"sync_mlflow_optuna - {optuna_db_path=} {storage_uri=} {task_run_id=} {experiment_name=} {task_run_id=}" + ) + # check number of successful mlflow runs in task + client = mlflow.tracking.MlflowClient(tracking_uri=storage_uri) + completed_in_mlflow_for_task = [] + all_mlflow_runs_for_task = [] + if task_run_id is not None: + all_mlflow_runs_for_task.append(task_run_id) + logger.info(f"sync_mlflow_optuna - {task_run_id=}") + experiment_info = client.get_experiment_by_name(experiment_name) + assert isinstance(experiment_info, Experiment), ( + f"Error! 
Unexpected type of {experiment_info=}" + ) + individual_run_data = client.search_runs( + experiment_ids=[experiment_info.experiment_id], + filter_string=f'tags."mlflow.parentRunId" LIKE "{task_run_id}"', + ) + for individual_run in individual_run_data: + if individual_run.info.status == "FINISHED": + completed_in_mlflow_for_task.append(individual_run.info.run_id) + all_mlflow_runs_for_task.append(individual_run.info.run_id) + + # check number of successful optuna trials in the database + study_names = optuna.study.get_all_study_names( + storage="sqlite:///{}.db".format(optuna_db_path) + ) + if task.name in study_names: + loaded_study = optuna.load_study( + study_name=task.name, storage="sqlite:///{}.db".format(optuna_db_path) + ) + logger.info(f"loaded_study has : {len(loaded_study.trials)} trials") + incomplete = 0 + for trial in loaded_study.trials: + if (trial.state == optuna.trial.TrialState.FAIL) | ( + trial.state == optuna.trial.TrialState.RUNNING + ): + incomplete += 1 + logger.info(f"{incomplete} trials are incomplete") + successful_optuna_trials = len(loaded_study.trials) - incomplete + too_many_trials = successful_optuna_trials > n_trials + no_existing_task = task_run_id is None + optuna_mlflow_mismatch = ( + len(completed_in_mlflow_for_task) != successful_optuna_trials + ) + logger.info( + f"successful optuna trials {successful_optuna_trials} . mlflow runs {len(completed_in_mlflow_for_task)}" + ) + + if too_many_trials or no_existing_task or optuna_mlflow_mismatch: + logger.info(f"deleting study with name {task.name}") + logger.info(f"too_many_trials {too_many_trials}") + logger.info(f"no_existing_task {no_existing_task}") + + # delete optuna study in database + optuna.delete_study( + study_name=task.name, storage="sqlite:///{}.db".format(optuna_db_path) + ) + + # delete any existing mlflow runs + if len(all_mlflow_runs_for_task) > 0: + for item in all_mlflow_runs_for_task: + logger.info(f"deleting {item}") + client.delete_run(item) + assert isinstance(experiment_info, Experiment), ( + f"Error! Unexpected type of {experiment_info=}" + ) + os.system(f"rm -r {experiment_info.artifact_location}/{item}") + task_run_id = None + else: + # delete any existing mlflow runs + if len(all_mlflow_runs_for_task) > 0: + for item in all_mlflow_runs_for_task: + logger.info(f"deleting {item}") + client.delete_run(item) + assert isinstance(experiment_info, Experiment), ( + f"Error! 
Unexpected type of {experiment_info=}" + ) + os.system(f"rm -r {experiment_info.artifact_location}/{item}") + task_run_id = None + logging.info(f"sync_mlflow_optuna returns {task_run_id=}") + return task_run_id diff --git a/gridfm_graphkit/tasks/__init__.py b/gridfm_graphkit/tasks/__init__.py index e69de29..2133993 100644 --- a/gridfm_graphkit/tasks/__init__.py +++ b/gridfm_graphkit/tasks/__init__.py @@ -0,0 +1,7 @@ +from gridfm_graphkit.tasks.feature_reconstruction_task import FeatureReconstructionTask + + + +__all__ = ( + "FeatureReconstructionTask", +) \ No newline at end of file diff --git a/gridfm_graphkit/tasks/feature_reconstruction_task.py b/gridfm_graphkit/tasks/feature_reconstruction_task.py index cb6963b..cff1f23 100644 --- a/gridfm_graphkit/tasks/feature_reconstruction_task.py +++ b/gridfm_graphkit/tasks/feature_reconstruction_task.py @@ -1,5 +1,4 @@ import torch -from torch.optim.lr_scheduler import ReduceLROnPlateau import lightning as L from pytorch_lightning.utilities import rank_zero_only import numpy as np @@ -45,7 +44,7 @@ class FeatureReconstructionTask(L.LightningModule): predict_step(batch, batch_idx, dataloader_idx=0): Run inference and return denormalized outputs + node masks. configure_optimizers(): - Setup Adam optimizer and ReduceLROnPlateau scheduler. + Setup optimizer and scheduler. on_fit_start(): Save normalization statistics at the beginning of training. on_test_end(): @@ -110,7 +109,7 @@ def on_fit_start(self): ) def shared_step(self, batch): - output = self.forward( + output = self( x=batch.x, pe=batch.pe, edge_index=batch.edge_index, @@ -224,7 +223,7 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): loss_dict["Test loss"] = loss_dict.pop("loss").detach() for metric, value in loss_dict.items(): - metric_name = f"{dataset_name}/{metric}" + metric_name = f"Test - {dataset_name}/{metric}" if "p.u." in metric: # Denormalize metrics expressed in p.u. 
value *= self.node_normalizers[dataloader_idx].baseMVA @@ -235,7 +234,7 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): batch_size=batch.num_graphs, add_dataloader_idx=False, sync_dist=True, - logger=False, + logger=True, ) return @@ -283,7 +282,7 @@ def on_test_end(self): pass if "/" in full_key: - dataset_name, metric = full_key.split("/", 1) + dataset_name, metric = full_key.replace("Test - ", "").split("/", 1) if dataset_name not in grouped_metrics: grouped_metrics[dataset_name] = {} grouped_metrics[dataset_name][metric] = value @@ -335,24 +334,29 @@ def on_test_end(self): df.to_csv(csv_path, index=False) def configure_optimizers(self): - self.optimizer = torch.optim.Adam( + if self.args.optimizer.type is None: + self.args.optimizer.type = "Adam" + optimizer = getattr(torch.optim, self.args.optimizer.type) + self.optimizer = optimizer( self.model.parameters(), lr=self.args.optimizer.learning_rate, - betas=(self.args.optimizer.beta1, self.args.optimizer.beta2), + **self.args.optimizer.optimizer_params, #unpack all other optim parameters ) + if self.args.optimizer.scheduler_type is None: + return {"optimizer": self.optimizer} - self.scheduler = ReduceLROnPlateau( + #TODO: add interval handling for scheduler + scheduler = getattr(torch.optim.lr_scheduler, self.args.optimizer.scheduler_type ) + self.scheduler = scheduler( self.optimizer, - mode="min", - factor=self.args.optimizer.lr_decay, - patience=self.args.optimizer.lr_patience, + **self.args.optimizer.scheduler_params ) config_optim = { "optimizer": self.optimizer, "lr_scheduler": { "scheduler": self.scheduler, "monitor": "Validation loss", - "reduce_on_plateau": True, + # "reduce_on_plateau": True, }, } return config_optim diff --git a/gridfm_graphkit/training/callbacks.py b/gridfm_graphkit/training/callbacks.py index e755133..2a1761a 100644 --- a/gridfm_graphkit/training/callbacks.py +++ b/gridfm_graphkit/training/callbacks.py @@ -1,6 +1,12 @@ from lightning.pytorch.callbacks import Callback from pytorch_lightning.utilities.rank_zero import rank_zero_only from lightning.pytorch.loggers import MLFlowLogger +from lightning.pytorch.callbacks import ( + Timer, + EarlyStopping, + ModelCheckpoint, + LearningRateMonitor, + ) import os import torch @@ -47,3 +53,44 @@ def on_validation_end(self, trainer, pl_module): # Save the model's state_dict model_path = os.path.join(model_dir, self.filename) torch.save(pl_module.state_dict(), model_path) + + + + +def get_training_callbacks(callbacks): + #TODO: make monitored metric configurable + callback_list = [] + if callbacks.tol is not None and callbacks.patience is not None: + early_stop_callback = EarlyStopping( + monitor="Validation loss", + min_delta=callbacks.tol, + patience=callbacks.patience, + verbose=False, + mode="min", + ) + callback_list.append(early_stop_callback) + + save_best_model_callback = SaveBestModelStateDict( + monitor="Validation loss", + mode="min", + filename="best_model_state_dict.pt", + ) + callback_list.append(save_best_model_callback) + + checkpoint_callback = ModelCheckpoint( + monitor="Validation loss", # or whichever metric you track + mode="min", + save_last=True, + save_top_k=0, + ) + callback_list.append(checkpoint_callback) + + if callbacks.max_run_duration is not None: + max_run_callback = Timer(duration=callbacks.max_run_duration) + callback_list.append(max_run_callback) + + if callbacks.monitor_learning_rate: + learning_rate_callback = LearningRateMonitor(logging_interval="epoch") + callback_list.append(learning_rate_callback) + + return 
callback_list
diff --git a/gridfm_graphkit/utils/types.py b/gridfm_graphkit/utils/types.py
new file mode 100644
index 0000000..11a119d
--- /dev/null
+++ b/gridfm_graphkit/utils/types.py
@@ -0,0 +1,236 @@
+"""
+This module defines all the types expected at input. Used for type checking by jsonargparse.
+"""
+
+from typing import Literal
+import enum
+from dataclasses import dataclass, field
+from typing import Any, Union
+from gridfm_graphkit.tasks import (
+    FeatureReconstructionTask,
+)
+
+
+valid_task_types = type[FeatureReconstructionTask]
+
+direction_type_to_optuna = {"min": "minimize", "max": "maximize"}
+
+
+class TaskTypeEnum(enum.Enum):
+    """
+    Enum for the type of task to be performed. Currently only feature reconstruction is supported.
+    """
+
+    feature_reconstruction = "feature_reconstruction"
+
+    def get_class_from_enum(
+        self,
+    ) -> valid_task_types:
+        match self.value:
+            case TaskTypeEnum.feature_reconstruction.value:
+                return FeatureReconstructionTask
+            case _:
+                raise TypeError("Task type does not exist")
+
+
+class ParameterTypeEnum(enum.Enum):
+    """
+    Enum for the type of parameter allowed in ParameterBounds. integer or real.
+    """
+
+    integer = "int"
+    real = "real"
+
+
+@dataclass
+class ParameterBounds:
+    """
+    Dataclass defining a numerical range to search over.
+
+    Args:
+        min (float | int): Minimum.
+        max (float | int): Maximum.
+        type (ParameterTypeEnum): Whether the range is in the space of integers or real numbers.
+        log (bool): Whether to search over the log space (useful for parameters that vary wildly in scale, e.g. learning rate).
+    """
+
+    min: float | int
+    max: float | int
+    type: ParameterTypeEnum
+    log: bool = False
+
+    def __post_init__(self):
+        if not isinstance(self.type, ParameterTypeEnum):
+            self.type = ParameterTypeEnum(self.type)
+
+
+optimization_space_type = dict[
+    str, Union[list, dict, ParameterBounds, "optimization_space_type"]
+]
+
+
+@dataclass
+class HyperParameterOptmizerSpec:
+    """
+    Parameters passed to define hyperparameter optimization. Only used with the 'iterate' subcommand.
+
+    These parameters are combined with any specified defaults to generate the final task parameters.
+
+    Args:
+        experiment_name (str): Name of the MLflow experiment.
+        run_name (str): Name of the top-level (experiment parent) MLflow run.
+        results_folder (str): Folder where MLflow logs and other results are written.
+        save_models (bool): Whether to save model weights for each run. Defaults to False.
+        n_trials (int): Number of HPO trials per task. Defaults to 5.
+        num_repetitions (int): Number of repeated runs per task after HPO. Defaults to 2.
+        repeat_on_best (bool): Whether to run the repeated experiments with the best hyperparameters found during HPO. Defaults to True.
+        bayesian_search (bool): Whether to use Bayesian search over the optimization space. Defaults to True.
+        continue_existing_experiment (bool): Whether to resume an existing, incomplete experiment with the same name. Defaults to True.
+        test_models (bool): Whether to evaluate trained models on the test set. Defaults to False.
+        report_on_best_val (bool): Whether to report metrics at the best validation step. Defaults to True.
+        run_id (str | None): Run id of an existing experiment run to continue. Defaults to None.
+        optimization_space (dict | None): Hyperparameters to optimize and the values/ranges to search over. Defaults to None.
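+
+    Example (illustrative values only):
+
+        hpo_spec:
+          experiment_name: my_experiment
+          run_name: top_run
+          results_folder: "results/iterate"
+          n_trials: 10
+          optimization_space:
+            batch_size: [8, 16, 32]
+            learning_rate:
+              min: 0.000006
+              max: 0.001
+              type: real
+              log: true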
+ """ + + experiment_name: str + run_name: str + + results_folder: str + save_models: bool = False + n_trials: int = 5 + num_repetitions: int = 2 + repeat_on_best: bool = True + bayesian_search: bool = True + continue_existing_experiment: bool = True + test_models: bool = False + report_on_best_val: bool = True + run_id: str | None = None + optimization_space: dict | None = None + + +@dataclass +class TrainingSpec: + """ + Parameters passed to define lightning trainer + + """ + + batch_size: int + epochs: int + losses: list[str] + loss_weights: list[float] + accelerator: str + devices: str + strategy: str + log_every_n_steps: int = 1 + enable_progress_bar: bool = False + + +@dataclass +class ModelSpec: + """ + Parameters passed to define Model + + """ + + attention_head: int + dropout: float + edge_dim: int + hidden_size: int + input_dim: int + num_layers: int + output_dim: int + pe_dim: int + type: str + model_path: str + + +@dataclass +class OptimizerSpec: + """ + Parameters passed to define Optimization and Scheduling parameters. Learning rate will be overwritten for 'iterate' subcommand. + + """ + + learning_rate: float + type: str + optimizer_params: dict + scheduler_type: str | None + scheduler_params: dict | None + + +@dataclass +class DataSpec: + """ + Parameters passed to define training data. Ignored for 'iterate' subcommand. + + """ + + networks: list[str] + scenarios: list[int] + normalization: str + baseMVA: int + mask_type: str + mask_value: float + mask_ratio: float + mask_dim: int + learn_mask: bool + val_ratio: float + test_ratio: float + workers: int + data_path: str + + +@dataclass +class CallbackSpec: + """ + Parameters passed to define training callbacks + + Args: + patience (int): patience for early stopping + tol (int): ... + + """ + + # TODO: use dicts for each callback type + patience: int | None = None + tol: int | None = None + max_run_duration: int | None = None + monitor_learning_rate: bool = True + optuna_early_prune: bool = False # only processed with iterate command + + +@dataclass +class TaskSpec: + """ + Parameters passed to define each of the tasks. Including DataSpec per task. Only used with 'iterate' subcommand. + + These parameters are combined with any specified defaults to generate the final task parameters. + + Args: + name (str): Name for this task + type (TaskTypeEnum): Type of task. + metric (str): Metric to be optimized. Defaults to "val/loss". + direction (str): One of min or max. Direction to optimize the metric in. + data: datamodule (BaseDataModule | GeoBenchDataModule): Datamodule to be used. 
+ + """ + + name: str + type: TaskTypeEnum = field(repr=False) + data: DataSpec # = field(repr=False) + metric: str = "val/constraint_violations" + direction: Literal["min", "max"] = "min" + + +def recursive_merge(first_dict: dict[str, Any], second_dict: dict[str, Any]): + # consider using deepmerge instead of this + for key, val in second_dict.items(): + if key not in first_dict: + first_dict[key] = val + else: + # if it is a dictionary, recurse deeper + if isinstance(val, dict): + recursive_merge(first_dict[key], val) + # if it is not further nested, just replace the value + else: + first_dict[key] = val diff --git a/pyproject.toml b/pyproject.toml index 51c8665..07d4ce8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,8 @@ dependencies = [ "pyyaml", "lightning", "seaborn", + "jsonargparse>4.0", + "optuna>4.0", ] [project.optional-dependencies] diff --git a/tests/config/iterate_test_case30_ieee_base.yaml b/tests/config/iterate_test_case30_ieee_base.yaml new file mode 100644 index 0000000..7d56216 --- /dev/null +++ b/tests/config/iterate_test_case30_ieee_base.yaml @@ -0,0 +1,73 @@ +seed: 42 +model: + attention_head: 8 + dropout: 0.1 + edge_dim: 2 + hidden_size: 256 + input_dim: 9 + num_layers: 8 + output_dim: 6 + pe_dim: 20 + type: GPSTransformer + model_path: "gridfm-graphkit/examples/models/GridFM_v0_2.pth" +training: + batch_size: 1 + epochs: 1 + losses: ["MaskedMSE", "PBE"] + loss_weights: [0.01, 0.99] + accelerator: auto + devices: auto + strategy: auto +optimizer: + type: Adam + learning_rate: 0.0001 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 +callbacks: + patience: 100 + tol: 0 + optuna_early_prune: True + +hpo_spec: + experiment_name: GPSTransformer_test_exp + run_name: top_run + optimization_space: + batch_size: [8, 16, 32] + learning_rate: + min: 0.000006 + max: 0.001 + type: real + log: true + n_trials: 1 + bayesian_search: True + results_folder: "results/iterate" + save_models: False + num_repetitions: 1 + repeat_on_best: True + report_on_best_val: True + continue_existing_experiment: True + +tasks: + - name: feature_reconstruction_base + type: feature_reconstruction + metric: "Validation loss" + direction: min + data: + data_path: "data/" + networks: ["case30_ieee"] + scenarios: [1023] + normalization: "baseMVAnorm" + baseMVA: 100 + mask_type: "rnd" + mask_value: 0.0 + mask_ratio: 0.5 + mask_dim: 6 + learn_mask: False + val_ratio: 0.1 + test_ratio: 0.1 + workers: 4 \ No newline at end of file diff --git a/tests/test_iterate.py b/tests/test_iterate.py new file mode 100644 index 0000000..60ed6b6 --- /dev/null +++ b/tests/test_iterate.py @@ -0,0 +1,173 @@ +import itertools +from pathlib import Path + +import yaml +from gridfm_graphkit.__main__ import main +import pytest +import sys + +CONFIG_FILES = [ + "configs/iterate_test_case30_ieee_base.yaml", +] + +MODELS = ["examples/models/GridFM_v0_2.pth"] +INPUT_TEST_MAIN = list(itertools.product(MODELS, CONFIG_FILES)) + + +def get_test_ids() -> list[str]: + test_case_ids = list() + for model, config in INPUT_TEST_MAIN: + # get the filename + model = model.split("/")[-1].replace(".pth", "") + config = config.split("/")[-1].replace(".yaml", "") + # append to list of test ids + test_case_ids.append(f"{config}_{model}") + return test_case_ids + + +def validate_hpo_results( + experiment_name: str, + results_folder: Path, + n_trials: int, + n_tasks: int, + iterate_info: dict, + ): + # check that experiment was created + 
mlflow_output_path = / "hpo_mlflow_output" + assert mlflow_output_path.exists(), f"Error! Directory does not exist: {mlflow_output_path}" + hpo_exp_path = mlflow_output_path / iterate_info["hpo_experiment_id"] + assert hpo_exp_path.exists(), f"Error! Directory does not exist: {hpo_exp_path}" + meta_yaml_path = hpo_exp_path / "meta.yaml" + assert meta_yaml_path.exists(), ( + f"Error! meta.yaml file {meta_yaml_path} does not exist" + ) + + # open file and check that the experiment name/id is the same + experiment_name_found: bool = False + finished_run_id_found: bool = False + experiment_id_found: bool = False + experiment_id = iterate_info["hpo_experiment_id"] + finished_run_id = iterate_info["hpo_finished_run_id"] + with open(meta_yaml_path, mode="r") as f: + # read all the lines + lines = f.readlines() + # try to find experiment id and name in these lines + for line in lines: + if experiment_name in line: + experiment_name_found = True + if finished_run_id in line: + finished_run_id_found = True + if experiment_id in line: + experiment_id_found = True + assert experiment_name_found and experiment_id_found and finished_run_id_found, ( + f"Error! Both experiment name ({experiment_name=}), finished run id ({finished_run_id=}), \ + and experiment id ({experiment_id}) must be in the {meta_yaml_path=}." + ) + + # check number of runs created + expected_num_runs = (n_trials*n_tasks) + n_tasks + 1 + run_folders = [ f.path for f in os.scandir(folder) if f.is_dir() ] + assert len(run_folders)==expected_num_runs, ( + f"Error! Expected {expected_num_runs} to be created for HPO experiment. Found {len(run_folders)} runs." + ) + + + + +def validate_repeated_results( + experiment_name: str, + results_folder: Path, + n_trials: int, + n_tasks: int, + num_repetitions: int, + iterate_info: dict, + ): + # check that epxeriment was created + mlflow_output_path = / "repeated_mlflow_output" + assert mlflow_output_path.exists(), f"Error! Directory does not exist: {mlflow_output_path}" + hpo_exp_path = mlflow_output_path / iterate_info["hpo_experiment_id"] + assert hpo_exp_path.exists(), f"Error! Directory does not exist: {hpo_exp_path}" + meta_yaml_path = hpo_exp_path / "meta.yaml" + assert meta_yaml_path.exists(), ( + f"Error! meta.yaml file {meta_yaml_path} does not exist" + ) + + # open file and check that the experiment name is the same + experiment_name_found: bool = False + experiment_id_found: bool = False + experiment_id = iterate_info["repeated_experiment_id"] + with open(meta_yaml_path, mode="r") as f: + # read all the lines + lines = f.readlines() + # try to find experiment id and name in these lines + + for line in lines: + if experiment_name in line: + experiment_name_found = True + if experiment_id in line: + experiment_id_found = True + assert experiment_name_found and experiment_id_found, ( + f"Error! Both experiment name ({experiment_name=}) and experiment id ({experiment_id=}) \ + must be in the {meta_yaml_path=}." + ) + + # check number of runs created + expected_num_runs = (n_trials*n_tasks) + 1 + run_folders = [ f.path for f in os.scandir(folder) if f.is_dir() ] + assert len(run_folders)==expected_num_runs, ( + f"Error! Expected {expected_num_runs} to be created for repeated experiment. Found {len(run_folders)} runs." 
+ ) + + + + + +@pytest.mark.parametrize( + "model, config", + INPUT_TEST_MAIN, + ids=get_test_ids(), +) +def test_iterate( + model: str, + config: str, +): + test_dir = Path(__file__).parents[0] + home_dir = test_dir.parents[0] + config_file: Path = test_dir / config + assert config_file.exists() + with open(config_file, "r") as file: + config_data = yaml.safe_load(file) + experiment_name = config_data["hpo_spec"]["experiment_name"] + results_folder = config_data["hpo_spec"]["results_folder"] + results_folder = Path(results_folder) + num_repetitions = config_data["hpo_spec"]["num_repetitions"] + n_trials = config_data["hpo_spec"]["n_trials"] + n_tasks = len(config_data["tasks"]) + + + #send command to sys + arguments = ["gridfm_graphkit", "iterate", "--config", str(config_file.resolve())] + model_path = home_dir / model + results_folder = home_dir / "test_reults" + arguments.append["--model.model_path", f"{str(model_path)}"] + arguments.append["--hpo_spec.results_folder", f"{str(results_folder)}"] + + sys.argv = arguments + iterate_info = main() + assert isinstance(iterate_info, dict), f"Error! {iterate_info=} is not a dict" + validate_hpo_results( + experiment_name=experiment_name, + results_folder=results_folder, + n_trials=n_trials, + n_tasks=n_tasks, + iterate_info=iterate_info, + ) + + validate_repeated_results( + experiment_name=experiment_name, + results_folder=results_folder, + num_repetitions=num_repetitions, + n_trials=n_trials, + n_tasks=n_tasks, + iterate_info=iterate_info, + ) \ No newline at end of file