From 422f79305accc7bae44141c2937c9731ef3d6b6d Mon Sep 17 00:00:00 2001 From: Naomi Simumba Date: Tue, 23 Dec 2025 13:50:28 -0500 Subject: [PATCH 01/16] update examples Signed-off-by: Naomi Simumba --- examples/config/case118_ieee_base.yaml | 12 ++- examples/config/case240_pserc_base.yaml | 12 ++- examples/config/case24_ieee_rts_base.yaml | 12 ++- examples/config/case300_ieee_base.yaml | 12 ++- examples/config/case30_ieee_base.yaml | 12 ++- examples/config/case30_ieee_base_hpo.yaml | 86 +++++++++++++++++++++ examples/config/case39_epri_base.yaml | 12 ++- examples/config/case57_ieee_base.yaml | 12 ++- examples/config/case89_pegase_base.yaml | 12 ++- examples/config/gridFMv0.1_pretraining.yaml | 12 ++- examples/config/gridFMv0.2_pretraining.yaml | 12 ++- 11 files changed, 166 insertions(+), 40 deletions(-) create mode 100644 examples/config/case30_ieee_base_hpo.yaml diff --git a/examples/config/case118_ieee_base.yaml b/examples/config/case118_ieee_base.yaml index 41de315..dec4efa 100644 --- a/examples/config/case118_ieee_base.yaml +++ b/examples/config/case118_ieee_base.yaml @@ -31,11 +31,15 @@ training: devices: auto strategy: auto optimizer: + type: Adam learning_rate: 0.0001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 callbacks: patience: 100 tol: 0 diff --git a/examples/config/case240_pserc_base.yaml b/examples/config/case240_pserc_base.yaml index 6758061..4ee4b82 100644 --- a/examples/config/case240_pserc_base.yaml +++ b/examples/config/case240_pserc_base.yaml @@ -31,11 +31,15 @@ training: devices: auto strategy: auto optimizer: + type: Adam learning_rate: 0.0001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 callbacks: patience: 100 tol: 0 diff --git a/examples/config/case24_ieee_rts_base.yaml b/examples/config/case24_ieee_rts_base.yaml index b912b46..e2540db 100644 --- a/examples/config/case24_ieee_rts_base.yaml +++ b/examples/config/case24_ieee_rts_base.yaml @@ -31,11 +31,15 @@ training: devices: auto strategy: auto optimizer: + type: Adam learning_rate: 0.0001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 callbacks: patience: 100 tol: 0 diff --git a/examples/config/case300_ieee_base.yaml b/examples/config/case300_ieee_base.yaml index d2fe4c4..1717bae 100644 --- a/examples/config/case300_ieee_base.yaml +++ b/examples/config/case300_ieee_base.yaml @@ -31,11 +31,15 @@ training: devices: auto strategy: auto optimizer: + type: Adam learning_rate: 0.0001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 callbacks: patience: 100 tol: 0 diff --git a/examples/config/case30_ieee_base.yaml b/examples/config/case30_ieee_base.yaml index a884933..cece70e 100644 --- a/examples/config/case30_ieee_base.yaml +++ b/examples/config/case30_ieee_base.yaml @@ -31,11 +31,15 @@ training: devices: auto strategy: auto optimizer: + type: Adam learning_rate: 0.0001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + 
scheduler_params: + mode: min + factor: 0.7 + patience: 10 callbacks: patience: 100 tol: 0 diff --git a/examples/config/case30_ieee_base_hpo.yaml b/examples/config/case30_ieee_base_hpo.yaml new file mode 100644 index 0000000..f0ba5cd --- /dev/null +++ b/examples/config/case30_ieee_base_hpo.yaml @@ -0,0 +1,86 @@ +seed: 42 +# data: +# networks: ["case30_ieee"] +# scenarios: [1023] +# normalization: "baseMVAnorm" +# baseMVA: 100 +# mask_type: "rnd" +# mask_value: 0.0 +# mask_ratio: 0.5 +# mask_dim: 6 +# learn_mask: False +# val_ratio: 0.1 +# test_ratio: 0.1 +# workers: 4 +model: + attention_head: 8 + dropout: 0.1 + edge_dim: 2 + hidden_size: 256 + input_dim: 9 + num_layers: 8 + output_dim: 6 + pe_dim: 20 + type: GPSTransformer + model_path: "/dccstor/sentinel1/nsimumba/neso_gridfm/gridfm-graphkit/examples/models/GridFM_v0_2.pth" +training: + batch_size: 1 + epochs: 2 + losses: ["MaskedMSE", "PBE"] + loss_weights: [0.01, 0.99] + accelerator: auto + devices: auto + strategy: auto +optimizer: + type: Adam + learning_rate: 0.0001 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 +callbacks: + patience: 100 + tol: 0 + optuna_early_prune: True + +hpo_spec: + experiment_name: 1#GPSTransformer_test_exp + run_name: top_run + optimization_space: + batch_size: [8, 16, 32] + learning_rate: + min: 0.000006 + max: 0.001 + type: real + log: true + n_trials: 2 + bayesian_search: True + results_folder: "/dccstor/sentinel1/nsimumba/neso_gridfm/results/iterate" + save_models: False + num_repetitions: 2 + repeat_on_best: True + report_on_best_val: True + continue_existing_experiment: True + +tasks: #cannot overwrite any parameters here + - name: feature_reconstruction_base + type: feature_reconstruction + metric: "Validation loss" + direction: min + data: + data_path: "/dccstor/gridfm/PowerGraph_TP" + networks: ["case30_ieee"] + scenarios: [1023] + normalization: "baseMVAnorm" + baseMVA: 100 + mask_type: "rnd" + mask_value: 0.0 + mask_ratio: 0.5 + mask_dim: 6 + learn_mask: False + val_ratio: 0.1 + test_ratio: 0.1 + workers: 4 \ No newline at end of file diff --git a/examples/config/case39_epri_base.yaml b/examples/config/case39_epri_base.yaml index 076a25a..3c3747c 100644 --- a/examples/config/case39_epri_base.yaml +++ b/examples/config/case39_epri_base.yaml @@ -31,11 +31,15 @@ training: devices: auto strategy: auto optimizer: + type: Adam learning_rate: 0.0001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 callbacks: patience: 100 tol: 0 diff --git a/examples/config/case57_ieee_base.yaml b/examples/config/case57_ieee_base.yaml index 96e7c3c..172c3f5 100644 --- a/examples/config/case57_ieee_base.yaml +++ b/examples/config/case57_ieee_base.yaml @@ -31,11 +31,15 @@ training: devices: auto strategy: auto optimizer: + type: Adam learning_rate: 0.0001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 callbacks: patience: 100 tol: 0 diff --git a/examples/config/case89_pegase_base.yaml b/examples/config/case89_pegase_base.yaml index 0eef554..8175c7d 100644 --- a/examples/config/case89_pegase_base.yaml +++ b/examples/config/case89_pegase_base.yaml @@ -31,11 +31,15 @@ training: devices: auto strategy: auto optimizer: + type: Adam 
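+  # NOTE: per the configure_optimizers change later in this series, `type` and
+  # `scheduler_type` appear to be resolved against torch.optim and
+  # torch.optim.lr_scheduler, with `optimizer_params` / `scheduler_params`
+  # forwarded as keyword arguments to the resulting classes.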
learning_rate: 0.0001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 callbacks: patience: 100 tol: 0 diff --git a/examples/config/gridFMv0.1_pretraining.yaml b/examples/config/gridFMv0.1_pretraining.yaml index c489e95..51ec4f6 100644 --- a/examples/config/gridFMv0.1_pretraining.yaml +++ b/examples/config/gridFMv0.1_pretraining.yaml @@ -37,11 +37,15 @@ model: pe_dim: 20 type: GNN_TransformerConv optimizer: - beta1: 0.9 - beta2: 0.999 + type: Adam learning_rate: 1.0e-05 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 seed: 200 training: batch_size: 64 diff --git a/examples/config/gridFMv0.2_pretraining.yaml b/examples/config/gridFMv0.2_pretraining.yaml index 8a804db..9fc5265 100644 --- a/examples/config/gridFMv0.2_pretraining.yaml +++ b/examples/config/gridFMv0.2_pretraining.yaml @@ -37,11 +37,15 @@ model: pe_dim: 20 type: GPSTransformer optimizer: - beta1: 0.9 - beta2: 0.999 + type: Adam learning_rate: 0.0001 - lr_decay: 0.7 - lr_patience: 10 + optimizer_params: + betas: [0.9, 0.999] + scheduler_type: ReduceLROnPlateau + scheduler_params: + mode: min + factor: 0.7 + patience: 10 seed: 0 training: batch_size: 64 From 5447dc8720ff545e11415128bfc784a42a46b5eb Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Tue, 23 Dec 2025 14:08:07 -0500 Subject: [PATCH 02/16] switch to jsonargparse Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- gridfm_graphkit/__main__.py | 70 +++++++++++++++++++++++++------------ 1 file changed, 47 insertions(+), 23 deletions(-) diff --git a/gridfm_graphkit/__main__.py b/gridfm_graphkit/__main__.py index f2e3d62..d7b59fb 100644 --- a/gridfm_graphkit/__main__.py +++ b/gridfm_graphkit/__main__.py @@ -1,18 +1,25 @@ -import argparse from datetime import datetime -from gridfm_graphkit.cli import main_cli +from gridfm_graphkit.cli import main_cli, iterate_cli +from jsonargparse import ArgumentParser, Namespace -def main(): - parser = argparse.ArgumentParser( - prog="gridfm_graphkit", - description="gridfm-graphkit CLI", +from gridfm_graphkit.utils.types import ( + HyperParameterOptmizerSpec, TaskSpec, CallbackSpec, + OptimizerSpec, ModelSpec, TrainingSpec, DataSpec ) - subparsers = parser.add_subparsers(dest="command", required=True) + + +def main(): + # parser = argparse.ArgumentParser( + # prog="gridfm_graphkit", + # description="gridfm-graphkit CLI", + # ) + # subparsers = parser.add_subparsers(dest="command", required=True) exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") # ---- TRAIN SUBCOMMAND ---- - train_parser = subparsers.add_parser("train", help="Run training") + # train_parser = subparsers.add_parser("train", help="Run training") + train_parser = ArgumentParser() train_parser.add_argument("--config", type=str, required=True) train_parser.add_argument("--exp_name", type=str, default=exp_name) train_parser.add_argument("--run_name", type=str, default="run") @@ -20,7 +27,8 @@ def main(): train_parser.add_argument("--data_path", type=str, default="data") # ---- FINETUNE SUBCOMMAND ---- - finetune_parser = subparsers.add_parser("finetune", help="Run fine-tuning") + # finetune_parser = subparsers.add_parser("finetune", help="Run fine-tuning") + finetune_parser = ArgumentParser() 
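+    # jsonargparse's ArgumentParser keeps the argparse-style add_argument API, so the
+    # existing string/int arguments work unchanged; dataclass-typed arguments (used by
+    # the iterate parser below) are parsed and validated from the YAML file passed via
+    # the `--config` action.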
finetune_parser.add_argument("--config", type=str, required=True) finetune_parser.add_argument("--model_path", type=str, required=True) finetune_parser.add_argument("--exp_name", type=str, default=exp_name) @@ -28,20 +36,9 @@ def main(): finetune_parser.add_argument("--log_dir", type=str, default="mlruns") finetune_parser.add_argument("--data_path", type=str, default="data") - # ---- EVALUATE SUBCOMMAND ---- - evaluate_parser = subparsers.add_parser( - "evaluate", - help="Evaluate model performance", - ) - evaluate_parser.add_argument("--model_path", type=str, default=None) - evaluate_parser.add_argument("--config", type=str, required=True) - evaluate_parser.add_argument("--exp_name", type=str, default=exp_name) - evaluate_parser.add_argument("--run_name", type=str, default="run") - evaluate_parser.add_argument("--log_dir", type=str, default="mlruns") - evaluate_parser.add_argument("--data_path", type=str, default="data") - # ---- PREDICT SUBCOMMAND ---- - predict_parser = subparsers.add_parser("predict", help="Evaluate model performance") + # predict_parser = subparsers.add_parser("predict", help="Evaluate model performance") + predict_parser = ArgumentParser() predict_parser.add_argument("--model_path", type=str, required=None) predict_parser.add_argument("--config", type=str, required=True) predict_parser.add_argument("--exp_name", type=str, default=exp_name) @@ -50,8 +47,35 @@ def main(): predict_parser.add_argument("--data_path", type=str, default="data") predict_parser.add_argument("--output_path", type=str, default="data") + # ---- ITERATE SUBCOMMAND ---- + # iterate_parser = subparsers.add_parser("iterate", help="Run model benchmarking") + iterate_parser = ArgumentParser() + iterate_parser.add_argument("--seed", type=int) + iterate_parser.add_argument("--hpo_spec", type=HyperParameterOptmizerSpec) + iterate_parser.add_argument("--tasks", type=list[TaskSpec]) + iterate_parser.add_argument("--model", type=ModelSpec) + iterate_parser.add_argument("--optimizer", type=OptimizerSpec) + iterate_parser.add_argument("--training", type=TrainingSpec) + iterate_parser.add_argument("--callbacks", type=CallbackSpec) + iterate_parser.add_argument("--config", action="config") + + parser = ArgumentParser( + prog="gridfm_graphkit", + description="gridfm-graphkit CLI", + ) + subcommands = parser.add_subcommands() + subcommands.add_subcommand('train', train_parser) + subcommands.add_subcommand('finetune', finetune_parser) + subcommands.add_subcommand('predict', predict_parser) + subcommands.add_subcommand('iterate', iterate_parser) + args = parser.parse_args() - main_cli(args) + if args.subcommand == "iterate": + # config = args.iterate.config + # config_args: Namespace = iterate_parser.instantiate_classes(config) + iterate_cli(args.iterate) + else: + main_cli(args) if __name__ == "__main__": From 2bb7cf73d88b875f9975c69c4dbf3212b09a88e4 Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Tue, 23 Dec 2025 14:09:17 -0500 Subject: [PATCH 03/16] export tasks Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- gridfm_graphkit/tasks/__init__.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/gridfm_graphkit/tasks/__init__.py b/gridfm_graphkit/tasks/__init__.py index e69de29..2133993 100644 --- a/gridfm_graphkit/tasks/__init__.py +++ b/gridfm_graphkit/tasks/__init__.py @@ -0,0 +1,7 @@ +from gridfm_graphkit.tasks.feature_reconstruction_task import FeatureReconstructionTask + + + +__all__ = ( + "FeatureReconstructionTask", 
+) \ No newline at end of file From 67daf106fc52a58753c39be9893c62c49201ca5a Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Tue, 23 Dec 2025 14:09:53 -0500 Subject: [PATCH 04/16] reformat callbacks Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- gridfm_graphkit/training/callbacks.py | 47 +++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/gridfm_graphkit/training/callbacks.py b/gridfm_graphkit/training/callbacks.py index e755133..2a1761a 100644 --- a/gridfm_graphkit/training/callbacks.py +++ b/gridfm_graphkit/training/callbacks.py @@ -1,6 +1,12 @@ from lightning.pytorch.callbacks import Callback from pytorch_lightning.utilities.rank_zero import rank_zero_only from lightning.pytorch.loggers import MLFlowLogger +from lightning.pytorch.callbacks import ( + Timer, + EarlyStopping, + ModelCheckpoint, + LearningRateMonitor, + ) import os import torch @@ -47,3 +53,44 @@ def on_validation_end(self, trainer, pl_module): # Save the model's state_dict model_path = os.path.join(model_dir, self.filename) torch.save(pl_module.state_dict(), model_path) + + + + +def get_training_callbacks(callbacks): + #TODO: make monitored metric configurable + callback_list = [] + if callbacks.tol is not None and callbacks.patience is not None: + early_stop_callback = EarlyStopping( + monitor="Validation loss", + min_delta=callbacks.tol, + patience=callbacks.patience, + verbose=False, + mode="min", + ) + callback_list.append(early_stop_callback) + + save_best_model_callback = SaveBestModelStateDict( + monitor="Validation loss", + mode="min", + filename="best_model_state_dict.pt", + ) + callback_list.append(save_best_model_callback) + + checkpoint_callback = ModelCheckpoint( + monitor="Validation loss", # or whichever metric you track + mode="min", + save_last=True, + save_top_k=0, + ) + callback_list.append(checkpoint_callback) + + if callbacks.max_run_duration is not None: + max_run_callback = Timer(duration=callbacks.max_run_duration) + callback_list.append(max_run_callback) + + if callbacks.monitor_learning_rate: + learning_rate_callback = LearningRateMonitor(logging_interval="epoch") + callback_list.append(learning_rate_callback) + + return callback_list From 60e2f9ffc312b457597eff591911651ead386bd8 Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Tue, 23 Dec 2025 14:10:31 -0500 Subject: [PATCH 05/16] update requirements Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 51c8665..07d4ce8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,8 @@ dependencies = [ "pyyaml", "lightning", "seaborn", + "jsonargparse>4.0", + "optuna>4.0", ] [project.optional-dependencies] From cda22de21e9897a9c947c3fb1433ee33586c5c1d Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Tue, 23 Dec 2025 14:11:27 -0500 Subject: [PATCH 06/16] configurable optim Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- .../tasks/feature_reconstruction_task.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/gridfm_graphkit/tasks/feature_reconstruction_task.py b/gridfm_graphkit/tasks/feature_reconstruction_task.py index cb6963b..d156c4d 100644 --- a/gridfm_graphkit/tasks/feature_reconstruction_task.py +++ 
b/gridfm_graphkit/tasks/feature_reconstruction_task.py @@ -335,24 +335,29 @@ def on_test_end(self): df.to_csv(csv_path, index=False) def configure_optimizers(self): - self.optimizer = torch.optim.Adam( + if self.args.optimizer.type is None: + self.args.optimizer.type = "Adam" + optimizer = getattr(torch.optim, self.args.optimizer.type) + self.optimizer = optimizer( self.model.parameters(), lr=self.args.optimizer.learning_rate, - betas=(self.args.optimizer.beta1, self.args.optimizer.beta2), + **self.args.optimizer.optimizer_params, #unpack all other optim parameters ) + if self.args.optimizer.scheduler_type is None: + return {"optimizer": self.optimizer} - self.scheduler = ReduceLROnPlateau( + #TODO: add interval handling for scheduler + scheduler = getattr(torch.optim.lr_scheduler, self.args.optimizer.scheduler_type ) + self.scheduler = scheduler( self.optimizer, - mode="min", - factor=self.args.optimizer.lr_decay, - patience=self.args.optimizer.lr_patience, + **self.args.optimizer.scheduler_params ) config_optim = { "optimizer": self.optimizer, "lr_scheduler": { "scheduler": self.scheduler, "monitor": "Validation loss", - "reduce_on_plateau": True, + # "reduce_on_plateau": True, }, } return config_optim From 13805e50caa4c381469f6fd12a386f19cd040909 Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Tue, 23 Dec 2025 14:13:31 -0500 Subject: [PATCH 07/16] add input classes Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- gridfm_graphkit/utils/types.py | 257 +++++++++++++++++++++++++++++++++ 1 file changed, 257 insertions(+) create mode 100644 gridfm_graphkit/utils/types.py diff --git a/gridfm_graphkit/utils/types.py b/gridfm_graphkit/utils/types.py new file mode 100644 index 0000000..5a29b49 --- /dev/null +++ b/gridfm_graphkit/utils/types.py @@ -0,0 +1,257 @@ +""" +This module defines all the types expected at input. Used for type checking by jsonargparse. +""" + +from ast import Dict +from typing import Literal +import copy +import enum +from dataclasses import dataclass, field, replace +from typing import Any, Optional, Union +from gridfm_graphkit.tasks import ( + FeatureReconstructionTask, + # ContingencyAnalysisTask, +) +from gridfm_graphkit.datasets.powergrid_datamodule import LitGridDataModule + + +import logging + + +valid_task_types = type[ + FeatureReconstructionTask + # | ContingencyAnalysisTask +] + +direction_type_to_optuna = {"min": "minimize", "max": "maximize"} + + + +@dataclass +class TaskTypeEnum(enum.Enum): + """ + Enum for the type of task to be performed. segmentation, regression or classification. + """ + + feature_reconstruction = "feature_reconstruction" + # contingency_analysis = "contingency_analysis" + + def get_class_from_enum( + self, + ) -> valid_task_types: + match self.value: + case TaskTypeEnum.feature_reconstruction.value: + return FeatureReconstructionTask + case TaskTypeEnum.contingency_analysis.value: + return ContingencyAnalysisTask + case _: + raise TypeError("Task type does not exist") + + +class ParameterTypeEnum(enum.Enum): + """ + Enum for the type of parameter allowed in ParameterBounds. integer or real. + """ + + integer = "int" + real = "real" + + +@dataclass +class ParameterBounds: + """ + Dataclass defining a numerical range to search over. + + Args: + min (float | int): Minimum. + max (float | int): Maximum. + type (ParameterTypeEnum): Whether the range is in the space of integers or real numbers. 
+ log (bool): Whether to search over the log space (useful for parameters that vary wildly in scale, e.g. learning rate) + """ + + min: float | int + max: float | int + type: ParameterTypeEnum + log: bool = False + + def __post_init__(self): + if not isinstance(self.type, ParameterTypeEnum): + self.type = ParameterTypeEnum(self.type) + + + +optimization_space_type = dict[ + str, Union[list, dict, ParameterBounds, "optimization_space_type"] +] + + + + + +@dataclass +class HyperParameterOptmizerSpec: + """ + Parameters passed to define hyperparameter optimization. Only used with 'iterate' subcommand. + + These parameters are combined with any specified defaults to generate the final task parameters. + + Args: + name (str): Name for this task + type (TaskTypeEnum): Type of task. + terratorch_task (dict): Arguments for the Terratorch Task. + datamodule (BaseDataModule | GeoBenchDataModule): Datamodule to be used. + direction (str): One of min or max. Direction to optimize the metric in. + metric (str): Metric to be optimized. Defaults to "val/loss". + early_prune (bool): Whether to prune unpromising runs early. Defaults to False. + early_stop_patience (int, None): Whether to use Lightning early stopping of runs. Defaults to None, which does not do early stopping. + optimization_except (str[str]): HyperParameters from the optimization space to be ignored for this task. + max_run_duration (str, None): maximum allowed run duration in the form DD:HH:MM:SS; will stop a run after this + amount of time. Defaults to None, which doesn't stop runs by time. + """ + experiment_name: str + run_name: str + + results_folder: str + save_models: bool = False + n_trials: int = 5 + num_repetitions: int = 2 + repeat_on_best: bool = True + bayesian_search: bool = True + continue_existing_experiment: bool = True + test_models: bool = False + report_on_best_val: bool = True + run_id: str | None = None + optimization_space: dict | None = None + + + + +@dataclass +class TrainingSpec: + """ + Parameters passed to define lightning trainer + + """ + batch_size: int + epochs: int + losses: list[str] + loss_weights: list[float] + accelerator: str + devices: str + strategy: str + log_every_n_steps: int = 1 + enable_progress_bar: bool = False + + + + +@dataclass +class ModelSpec: + """ + Parameters passed to define Model + + """ + attention_head: int + dropout: float + edge_dim: int + hidden_size: int + input_dim: int + num_layers: int + output_dim: int + pe_dim: int + type: str + model_path: str + + + + +@dataclass +class OptimizerSpec: + """ + Parameters passed to define Optimization and Scheduling parameters. Learning rate will be overwritten for 'iterate' subcommand. + + """ + learning_rate: float + type: str + optimizer_params: dict + scheduler_type: str | None + scheduler_params: dict | None + + + +@dataclass +class DataSpec: + """ + Parameters passed to define training data. Ignored for 'iterate' subcommand. + + """ + networks: list[str] + scenarios: list[int] + normalization: str + baseMVA: int + mask_type: str + mask_value: float + mask_ratio: float + mask_dim: int + learn_mask: bool + val_ratio: float + test_ratio: float + workers: int + data_path: str + + + +@dataclass +class CallbackSpec: + """ + Parameters passed to define training callbacks + + Args: + patience (int): patience for early stopping + tol (int): ... 
+ + """ + #TODO: use dicts for each callback type + patience: int | None = None + tol: int | None = None + max_run_duration: int | None = None + monitor_learning_rate: bool = True + optuna_early_prune: bool = False #only processed with iterate command + + + + +@dataclass +class TaskSpec: + """ + Parameters passed to define each of the tasks. Including DataSpec per task. Only used with 'iterate' subcommand. + + These parameters are combined with any specified defaults to generate the final task parameters. + + Args: + name (str): Name for this task + type (TaskTypeEnum): Type of task. + metric (str): Metric to be optimized. Defaults to "val/loss". + direction (str): One of min or max. Direction to optimize the metric in. + data: datamodule (BaseDataModule | GeoBenchDataModule): Datamodule to be used. + + """ + name: str + type: TaskTypeEnum = field(repr=False) + data: DataSpec # = field(repr=False) + metric: str = "val/constraint_violations" + direction: Literal["min", "max"] = "min" + + +def recursive_merge(first_dict: dict[str, Any], second_dict: dict[str, Any]): + # consider using deepmerge instead of this + for key, val in second_dict.items(): + if key not in first_dict: + first_dict[key] = val + else: + # if it is a dictionary, recurse deeper + if isinstance(val, dict): + recursive_merge(first_dict[key], val) + # if it is not further nested, just replace the value + else: + first_dict[key] = val \ No newline at end of file From 4927666e209168a509470f54b7f9df3b4e27c6e1 Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Tue, 23 Dec 2025 14:15:18 -0500 Subject: [PATCH 08/16] add hpo cli fn Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- gridfm_graphkit/cli.py | 91 +++++++++++++++++++++++++++--------------- 1 file changed, 58 insertions(+), 33 deletions(-) diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index a7507c1..a9669d7 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -1,6 +1,6 @@ from gridfm_graphkit.datasets.powergrid_datamodule import LitGridDataModule from gridfm_graphkit.io.param_handler import NestedNamespace -from gridfm_graphkit.training.callbacks import SaveBestModelStateDict +from gridfm_graphkit.training.callbacks import get_training_callbacks import numpy as np import os import yaml @@ -8,37 +8,12 @@ import random import pandas as pd -from gridfm_graphkit.tasks.feature_reconstruction_task import FeatureReconstructionTask -from lightning.pytorch.callbacks.early_stopping import EarlyStopping -from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint +from gridfm_graphkit.tasks import FeatureReconstructionTask + from lightning.pytorch.loggers import MLFlowLogger import lightning as L -def get_training_callbacks(args): - early_stop_callback = EarlyStopping( - monitor="Validation loss", - min_delta=args.callbacks.tol, - patience=args.callbacks.patience, - verbose=False, - mode="min", - ) - - save_best_model_callback = SaveBestModelStateDict( - monitor="Validation loss", - mode="min", - filename="best_model_state_dict.pt", - ) - - checkpoint_callback = ModelCheckpoint( - monitor="Validation loss", # or whichever metric you track - mode="min", - save_last=True, - save_top_k=0, - ) - - return [early_stop_callback, save_best_model_callback, checkpoint_callback] - def main_cli(args): logger = MLFlowLogger( @@ -47,6 +22,9 @@ def main_cli(args): run_name=args.run_name, ) + subcommand = args.subcommand + args = args[subcommand] + with open(args.config, "r") 
as f: base_config = yaml.safe_load(f) @@ -62,7 +40,7 @@ def main_cli(args): litGrid.node_normalizers, litGrid.edge_normalizers, ) - if args.command != "train": + if subcommand != "train": print(f"Loading model weights from {args.model_path}") state_dict = torch.load(args.model_path) model.load_state_dict(state_dict) @@ -75,15 +53,15 @@ def main_cli(args): log_every_n_steps=1, default_root_dir=args.log_dir, max_epochs=config_args.training.epochs, - callbacks=get_training_callbacks(config_args), + callbacks=get_training_callbacks(config_args.callbacks), ) - if args.command == "train" or args.command == "finetune": + if subcommand == "train" or subcommand == "finetune": trainer.fit(model=model, datamodule=litGrid) - if args.command != "predict": + if subcommand != "predict": trainer.test(model=model, datamodule=litGrid) - if args.command == "predict": + if subcommand == "predict": predictions = trainer.predict(model=model, datamodule=litGrid) all_outputs = [] all_scenarios = [] @@ -120,3 +98,50 @@ def main_cli(args): df.to_csv(csv_path, index=False) print(f"Saved predictions to {csv_path}") + + + + + +from gridfm_graphkit.utils.types import ( + HyperParameterOptmizerSpec, TaskSpec, CallbackSpec, + OptimizerSpec, ModelSpec, TrainingSpec, DataSpec + ) + +from gridfm_graphkit.iterate import run_iterate_experiments + +from jsonargparse import Namespace + + +DEFAULT_SEED = 42 + + + + + + + +def iterate_cli(config_args): + #validate inputs + if config_args.seed is not None: + assert isinstance(config_args.seed, int), ( + "seed must be an integer" + ) + else: + config_init.seed = DEFAULT_SEED + torch.manual_seed(config_args.seed) + random.seed(config_args.seed) + np.random.seed(config_args.seed) + + + run_iterate_experiments( + args=config_args, #TODO + model_spec=config_args.model, + training_spec=config_args.training, + optimizer_spec=config_args.optimizer, + callbacks_spec=config_args.callbacks, + hpo_spec=config_args.hpo_spec, + tasks=config_args.tasks, + seed=config_args.seed, + ) + From 7df86941cf208c0d3fbbe54a3f020ffbae722bf7 Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Tue, 23 Dec 2025 14:36:49 -0500 Subject: [PATCH 09/16] add iterator Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- gridfm_graphkit/iterate/__init__.py | 13 + gridfm_graphkit/iterate/hpo.py | 800 +++++++++++++++++++++++ gridfm_graphkit/iterate/model_fitting.py | 448 +++++++++++++ gridfm_graphkit/iterate/utils.py | 441 +++++++++++++ 4 files changed, 1702 insertions(+) create mode 100644 gridfm_graphkit/iterate/__init__.py create mode 100644 gridfm_graphkit/iterate/hpo.py create mode 100644 gridfm_graphkit/iterate/model_fitting.py create mode 100644 gridfm_graphkit/iterate/utils.py diff --git a/gridfm_graphkit/iterate/__init__.py b/gridfm_graphkit/iterate/__init__.py new file mode 100644 index 0000000..cf52487 --- /dev/null +++ b/gridfm_graphkit/iterate/__init__.py @@ -0,0 +1,13 @@ +from gridfm_graphkit.iterate.hpo import ( + run_hpo_experiments, + run_repeated_experiments, + run_iterate_experiments) +# from gridfm_graphkit.tasks.contingency_analysis import ContingencyAnalysisTask + + + +__all__ = ( + "run_hpo_experiments", + "run_repeated_experiments", + "run_iterate_experiments" +) \ No newline at end of file diff --git a/gridfm_graphkit/iterate/hpo.py b/gridfm_graphkit/iterate/hpo.py new file mode 100644 index 0000000..d55ab97 --- /dev/null +++ b/gridfm_graphkit/iterate/hpo.py @@ -0,0 +1,800 @@ +import os +from pathlib import Path +import 
logging +from jsonargparse import Namespace + + + + + +from functools import partial +from typing import Any, Dict + +import mlflow +import optuna +import pandas as pd +import torch +from optuna.pruners import HyperbandPruner +from optuna.samplers import BaseSampler, RandomSampler + +from gridfm_graphkit.iterate.model_fitting import fit_model, fit_model_with_hparams + +from gridfm_graphkit.utils.types import ( + HyperParameterOptmizerSpec, + TaskSpec, + CallbackSpec, + OptimizerSpec, + ModelSpec, + TrainingSpec, + DataSpec, + direction_type_to_optuna, + optimization_space_type + + ) + +from gridfm_graphkit.iterate.utils import ( + parse_optimization_space, + check_existing_task_parent_runs, + check_existing_experiments, + unflatten, + get_logger, + sync_mlflow_optuna, + ) + + + +def run_iterate_experiments( + args, #TODO: remove + model_spec: ModelSpec, + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + hpo_spec: HyperParameterOptmizerSpec, + tasks: list[TaskSpec], + seed: int = 42, + ) -> bool: + """ + runs full benchmarking (hpo + repeated) for a model across multiple tasks + + Args: + + + Return: + + """ + #create folders and initialize logger + base = Path(args.hpo_spec.results_folder) + HPO_EXP_FOLDER = base / "hpo_mlflow_output" + REPEATED_EXP_FOLDER = base / "repeated_mlflow_output" + REPEATED_CSV_FOLDER = base / "repeated_csv_output" + LOG_FOLDER = base / "logs" + folders = [HPO_EXP_FOLDER, REPEATED_EXP_FOLDER, REPEATED_CSV_FOLDER, LOG_FOLDER] + for f in folders: + os.makedirs(str(f), exist_ok=True) + logger = get_logger(log_folder=str(LOG_FOLDER)) + + benchmarking_completed = False + try: + # run hpo on model across multiple tasks + hpo_output = run_hpo_experiments( + args=args, #TODO: remove args + logger=logger, + model_spec=model_spec, + training_spec=training_spec, + optimizer_spec=optimizer_spec, + callbacks_spec=callbacks_spec, + hpo_spec=hpo_spec, + tasks=tasks, + seed=seed, + storage_uri=HPO_EXP_FOLDER, + ) + + if args.hpo_spec.num_repetitions >= 1: + # run repeated experiments + run_repeated_experiments( + args=args, #TODO: remove args + logger=logger, + model_spec=model_spec, + training_spec=training_spec, + optimizer_spec=optimizer_spec, + callbacks_spec=callbacks_spec, + hpo_spec=hpo_spec, + tasks=tasks, + seed=seed, + repeated_storage_uri=REPEATED_EXP_FOLDER, + hpo_storage_uri=HPO_EXP_FOLDER, + csv_folder=REPEATED_CSV_FOLDER, + experiment_id=hpo_output["experiment_id"], + parent_run_id=hpo_output["finished_run_id"], + ) + except Exception as e: + logger.info(f"Could not complete due to error {e}") + raise + + + +def run_hpo_experiments( + args: Namespace, #TODO: remove + logger: logging.RootLogger, + model_spec: ModelSpec, + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + hpo_spec: HyperParameterOptmizerSpec, + tasks: list[TaskSpec], + seed: int, + storage_uri: Path, +) -> Dict[str, str]: + """Highest level function to run hpo only for a model across multiple tasks + + Args: + args: Namespace, #TODO: remove + logger: logging.RootLogger, + model_spec: ModelSpec, + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + hpo_spec: HyperParameterOptmizerSpec, + tasks: list[TaskSpec], + seed: int, + storage_uri: Path, + + + Return: + + + """ + # https://mlflow.org/docs/latest/ml/tracking/system-metrics/#using-the-environment-variable-to-control-system-metrics-logging + if os.getenv("MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING") is None: + 
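+        # Opt in to MLflow system-metrics logging (host CPU/GPU/memory utilization per run)
+        # unless the caller has already set the environment variable explicitly.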
os.environ["MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING"] = "true" + + model_type: str = model_spec.type + run_id: str = hpo_spec.run_id + experiment_name: str = hpo_spec.experiment_name + task_names = [task.name for task in tasks] + run_name = f"top_run_{hpo_spec.experiment_name}" if hpo_spec.run_name is None else hpo_spec.run_name + optimization_space = parse_optimization_space(hpo_spec.optimization_space) + completed_task_run_names = [] + optimize_hyperparams = True + task_run_to_id_match = {} + + storage_uri = str(storage_uri) + logger.info(f"Setting tracking URI: {storage_uri}") + mlflow.set_tracking_uri(storage_uri) + logger.info(f"Setting experiment name: {experiment_name}") + mlflow.set_experiment(experiment_name) + experiment_id = mlflow.get_experiment_by_name(experiment_name).experiment_id + + if hpo_spec.continue_existing_experiment: + # find status of existing runs, and delete incomplete runs except one with the most complete tasks + existing_experiments = check_existing_experiments( + logger=logger, + storage_uri=storage_uri, + experiment_name=experiment_name, + exp_parent_run_name=run_name, + task_names=task_names, + n_trials=hpo_spec.n_trials, + ) + if existing_experiments["no_existing_runs"]: + logger.info("\nStarting new experiment from scratch") + else: + if (existing_experiments["incomplete_run_to_finish"] is not None) and ( + run_id is None + ): + logger.info("Continuing previous experiment parent run") + run_id = existing_experiments["incomplete_run_to_finish"] + logger.debug(f"incomplete_run_to_finish: {run_id=}") + experiment_id = existing_experiments["experiment_id"] + optimize_hyperparams = True + + if existing_experiments["finished_run"] is not None: + optimize_hyperparams = False + finished_run_id = existing_experiments["finished_run"] + logger.debug(f"finished_run: {run_id=}") + run_id = existing_experiments["finished_run"] + + # get previously completed tasks + completed_task_run_names, _, task_run_to_id_match = ( + check_existing_task_parent_runs( + logger, run_id, storage_uri, experiment_name, hpo_spec.n_trials + ) + ) + else: + logger.info("Starting new experiment from scratch") + + # only run hyperparameter optimization (HPO) if there are no experiments with finished HPO + if optimize_hyperparams: + if hpo_spec.bayesian_search: + sampler = None # defaults to TPESampler + else: + sampler = RandomSampler() + experiment_id, finished_run_id = _run_hpo( + args=args, #TODO + model_spec=model_spec, + training_spec=training_spec, + optimizer_spec=optimizer_spec, + callbacks_spec=callbacks_spec, + run_name=run_name, + run_id=run_id, + tasks=tasks, + task_names=task_names, + completed_task_run_names=completed_task_run_names, + task_run_to_id_match=task_run_to_id_match, + storage_uri=storage_uri, + experiment_name=experiment_name, + n_trials=hpo_spec.n_trials, + save_models=hpo_spec.save_models, + sampler=sampler, + test_models=hpo_spec.test_models, + optimization_space=optimization_space, + logger=logger, + seed=seed, + ) + logger.info("HPO complete") + return {"experiment_id": experiment_id, "finished_run_id": finished_run_id} + + + + +def _run_hpo( + args: Namespace, + model_spec: ModelSpec, + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + tasks: list, + task_names: list[str], + completed_task_run_names: list[str], + task_run_to_id_match: dict, + storage_uri: str, + experiment_name: str, + optimization_space: optimization_space_type, + n_trials: int, + logger: logging.RootLogger, + sampler: BaseSampler | RandomSampler, + 
description: str | None = None, + run_name: str | None = None, + run_id: str | None = None, + save_models: bool = False, + test_models: bool = False, + seed: int = 42, +) -> tuple[str, str]: + """ + run HPO for multiple tasks under a single experiment. + + Args: + arg: Namespace, contains all parameters to be passed to model and datamodule. To be removed + model_spec: ModelSpec, contains all parameters to intiali model + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + tasks: list, + task_names: list[str], + completed_task_run_names: list[str], + task_run_to_id_match: dict, + storage_uri: str, + experiment_name: str, + optimization_space: optimization_space_type, + n_trials: int, + logger: logging.RootLogger, + sampler: BaseSampler | RandomSampler, + description: str | None = None, + run_name: str | None = None, + run_id: str | None = None, + save_models: bool = False, + test_models: bool = False, + seed: int = 42, + + + """ + logger.info( + f"Running hyperparameter optimization: {run_name=} {run_id=}" + ) + if run_id is not None: + run_name = None + + with mlflow.start_run( + run_name=run_name, run_id=run_id, description=description + ) as run: + for task in tasks: + # only run task if it was not completed before + task_run_name = task.name + if task_run_name in completed_task_run_names: + logger.info(f"{task_run_name} already completed") + continue + else: + logger.info(f"{task_run_name} not completed. starting now") + + task_run_id = ( + task_run_to_id_match[task_run_name] + if task_run_name in task_run_to_id_match + else None + ) + best_value, metric_name, hparams = _run_hpo_per_task( + args=args, #TODO + model_spec=model_spec, + training_spec=training_spec, + optimizer_spec=optimizer_spec, + callbacks_spec=callbacks_spec, + logger=logger, + task=task, + storage_uri=str(storage_uri), + experiment_name=experiment_name, + experiment_run_id=run.info.run_id, + task_run_id=task_run_id, + optimization_space=optimization_space, + n_trials=n_trials, + save_models=save_models, + sampler=sampler, + test_models=test_models, + seed=seed, + ) + experiment_id = run.info.experiment_id + + # check completion of HPO for all tasks before proceeding to next stage + existing_experiments = check_existing_experiments( + logger=logger, + storage_uri=storage_uri, + experiment_name=experiment_name, + exp_parent_run_name=run_name, + task_names=task_names, + n_trials=n_trials, + ) + if existing_experiments["finished_run"] is not None: + finished_run_id = existing_experiments["finished_run"] + else: + logger.info("HPO is not complete. 
Please re-run this experiment") + raise RuntimeError + + return experiment_id, finished_run_id + + + + + + +def _run_hpo_per_task( + args: Namespace, #TODO: remove args + model_spec: ModelSpec, + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + logger: logging.RootLogger, + task: TaskSpec, + storage_uri: str, + experiment_name: str, + experiment_run_id: str, + task_run_id: str | None = None, + optimization_space: optimization_space_type | None = None, + n_trials: int = 1, + save_models: bool = False, + sampler: BaseSampler | None = None, + test_models: bool = False, + seed: int = 42, +): + """ + Performs HPO on a single task + + Args: + args: Namespace, #TODO: remove args + model_spec: ModelSpec, + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + logger: logging.RootLogger, + task: TaskSpec, + storage_uri: str, + experiment_name: str, + experiment_run_id: str, + task_run_id: str | None = None, + optimization_space: optimization_space_type | None = None, + n_trials: int = 1, + save_models: bool = False, + sampler: BaseSampler | None = None, + test_models: bool = False, + seed: int = 42, + + """ + logger.info( + f"starting backbone benchmark on task {task.name} {task_run_id=} {experiment_name=}" + ) + if storage_uri.startswith("http"): + optuna_db_path = Path(".") / "optuna_db" + else: + optuna_db_path = Path(storage_uri).parents[0] / "optuna_db" + + if not os.path.exists(optuna_db_path): + os.makedirs(optuna_db_path) + optuna_db_path = optuna_db_path / f"{experiment_name}_{experiment_run_id}" + optuna_db_path = str(optuna_db_path) + + task_run_id = sync_mlflow_optuna( + optuna_db_path=optuna_db_path, + storage_uri=storage_uri, + experiment_name=experiment_name, + task_run_id=task_run_id, + task=task, + n_trials=n_trials, + logger=logger, + ) + if task_run_id is not None: + # run_name is used only when run_id is unspecified. 
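+        # Resuming the task run returned by sync_mlflow_optuna keeps the MLflow run and the
+        # Optuna study (created with load_if_exists=True) pointing at the same task.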
+ run_name = None + else: + run_name = task.name + logger.info(f"start run: {run_name=} {task_run_id=}") + with mlflow.start_run(run_name=run_name, nested=True, run_id=task_run_id) as run: + logger.info(f"starting task run with id: {run.info.run_id}") + if training_spec.epochs is None: + raise Exception("Must specify epochs for training") + + # if no optimization params, just run it + if optimization_space is None: + return ( + *fit_model( + args=args, + model_spec=model_spec, + training_spec=training_spec, + optimizer_spec=optimizer_spec, + callbacks_spec=callbacks_spec, + task=task, + run_name=f"{run_name}_no_optim", + experiment_name=experiment_name, + storage_uri=storage_uri, + parent_run_id=run.info.run_id, + save_models=save_models, + test_models=test_models, + seed=seed, + logger=logger, + ), + ) + + # if optimization parameters specified, do hyperparameter tuning + study = optuna.create_study( + sampler=sampler, + direction=direction_type_to_optuna[ + task.direction + ], # in the future may want to allow user to specify this + pruner=HyperbandPruner(), + study_name=task.name, + storage="sqlite:///{}.db".format(optuna_db_path), + load_if_exists=True, + ) + + objective = partial( + fit_model_with_hparams, + args, + model_spec, + training_spec, + optimizer_spec, + callbacks_spec, + task, + optimization_space, + run_name, + experiment_name, + storage_uri, + run.info.run_id, + logger, + save_models, + test_models, + seed, + ) + + n_trials = n_trials - len(study.trials) + for trial in study.trials: + if (trial.state == optuna.trial.TrialState.FAIL) | ( + trial.state == optuna.trial.TrialState.RUNNING + ): + n_trials = n_trials + 1 + + study.optimize( + objective, + n_trials=n_trials, + catch=[torch.cuda.OutOfMemoryError], + ) + + tags = { + "seed": str(seed), + "n_trials": str(n_trials), + "model_spec": vars(model_spec), + "training_spec": vars(training_spec), + "optimizer_spec": vars(optimizer_spec), + "callbacks_spec": vars(callbacks_spec), + "data": vars(task.data), + } + mlflow.set_tags(tags) + + best_params = unflatten(study.best_trial.params) + mlflow.log_params(best_params) # unflatten + mlflow.log_metric(f"best_{task.metric}", study.best_value) + return study.best_value, task.metric, best_params + + + +def run_repeated_experiments( + args: Namespace, #TODO: remove args + logger: logging.RootLogger, + model_spec: ModelSpec, + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + hpo_spec: HyperParameterOptmizerSpec, + tasks: list[TaskSpec], + seed: int, + repeated_storage_uri: Path, + hpo_storage_uri: Path, + csv_folder: Path, + experiment_id: str, + parent_run_id: str, +): + """Repeat best experiments from a benchmark run. Only works with a ray cluster. 
+ + Args: + + + """ + + + # if backbone_import: + # importlib.import_module(backbone_import) + + experiment_name = hpo_spec.experiment_name + num_repetitions = hpo_spec.num_repetitions + #find completed HPO tasks + mlflow.set_tracking_uri(str(hpo_storage_uri)) + mlflow.set_experiment(experiment_name) + + runs: list[mlflow.entities.Run] = mlflow.search_runs( + filter_string=f"tags.mlflow.parentRunId='{parent_run_id}'", output_format="list" + ) # type: ignore + logger.info(f"parent_run_id {parent_run_id}") + logger.info(f"Found runs: {[run.info.run_name for run in runs]}") + + task_names = [task.name for task in tasks] + logger.info(f"Will only run the following: {task_names}") + + table_columns = [ + "Task", + "Metric", + "Score", + "mlflow_run_name", + "mlflow_run_id", + "mlflow_run_status", + ] + table_entries = [] + + mlflow.set_tracking_uri(repeated_storage_uri) + mlflow.set_experiment(experiment_name) + experiment_id = mlflow.get_experiment_by_name(experiment_name).experiment_id + output_path = csv_folder / f"{experiment_name}_repeated_exp_mlflow.csv" + if not os.path.isabs(output_path): + raise Exception( + f"output_path must be absolute." + ) + + # backbone_name = defaults.terratorch_task["model_args"]["backbone"] + with mlflow.start_run(run_name=experiment_name, run_id=None) as run: + for task in tasks: + logger.info(f"\n\ntask: {task.name}") + matching_runs = [ + run for run in runs if run.info.run_name.endswith(task.name) + ] # type: ignore + if len(matching_runs) == 0: + msg = f"No runs found for task {task.name}. Skipping." + warnings.warn(msg) + continue + if len(matching_runs) > 1: + msg = f"More than 1 run found for task {task.name}" + raise Exception(msg) + + # check if there are already results for this task and exp in the folder + if os.path.exists(output_path): + logger.info("there are previous results from repeated experiments") + existing_output = pd.read_csv(output_path, index_col=False) + existing_output = existing_output[table_columns] + existing_task_output = existing_output.loc[ + existing_output["Task"] == task.name + ].copy() + rows, cols = existing_task_output.shape + logger.info(f"rows: {rows} \t cols: {cols}") + if rows > num_repetitions: + logger.info("task has complete results, will not re-run") + continue + past_seeds = [ + int(item.split("_")[-1]) + for item in existing_task_output["mlflow_run_name"].tolist() + ] + else: + past_seeds = [] + logger.info(f"past_seeds for task: {past_seeds}") + + # get best parameters + best_params = matching_runs[0].data.params + best_params = {k: literal_eval(v) for k, v in best_params.items()} + + training_spec = combine_with_defaults(task, defaults) + lightning_task_class = training_spec.task.type.get_class_from_enum() + + experiment_info = mlflow.get_experiment_by_name(experiment_name) + seeds = [randint(1, 5000) for i in range(num_repetitions * 5)] + seeds = [seed for seed in seeds if seed not in past_seeds] + + for seed in seeds: + if len(past_seeds) >= num_repetitions: + break + + seed_run_name = f"{task.name}_{seed}" + logger.info(f"now trying: {seed_run_name}") + seed_run_data = mlflow.search_runs( + experiment_ids=[experiment_info.experiment_id], + filter_string=f'tags."mlflow.runName" LIKE "{seed_run_name}"', + output_format="list", + ) # type: ignore + if len(seed_run_data) > 0: + continue + + score = non_remote_fit( + experiment_name=repeated_experiment_name, + parent_run_id=run.info.run_id, + storage_uri=repeated_storage_uri, + task=task, + training_spec=training_spec, + 
lightning_task_class=lightning_task_class, + best_params=best_params, + seed=seed, + backbone_import=backbone_import, + save_models=save_models, + report_on_best_val=report_on_best_val, + ) + # check if run with name finished successfully + logger.info(f"score: {score}") + # TODO improve this sleep command - try to get a better estimate than this + time.sleep(60) + seed_run_data = mlflow.search_runs( + experiment_ids=[experiment_info.experiment_id], + filter_string=f'tags."mlflow.runName" LIKE "{seed_run_name}"', + output_format="list", + ) # type: ignore + + logger.info(f"run for task {task.name} seed {seed} complete") + if len(seed_run_data) > 0: + if seed_run_data[0].info.status != "FINISHED": + mlflow.delete_run(seed_run_data[0].info.run_id) + continue + past_seeds.append(seed) + new_data = pd.DataFrame( + { + "Task": [task.name], + "Metric": [task.metric.split("/")[-1]], + "Score": [score], + "mlflow_run_name": [seed_run_name], + "mlflow_run_id": [seed_run_data[0].info.run_id], + "mlflow_run_status": [seed_run_data[0].info.status], + } + ) + logger.info( + f"completed seeds so far for this task: {len(past_seeds)}" + ) + if os.path.exists(output_path): + logger.info( + "there are previous results from repeated experiments" + ) + + existing_output = pd.read_csv(output_path, index_col=False) + existing_output = existing_output[table_columns] + existing_output.reset_index(inplace=True) + existing_task_output = existing_output.loc[ + existing_output["Task"] == task.name + ].copy() + rows, cols = existing_task_output.shape + logger.info(f"rows: {rows} \t cols: {cols}") + if rows == 0: + logger.info("no past results for this task") + existing_output = pd.concat( + [existing_output, new_data], axis=0 + ) + existing_output.reset_index(inplace=True) + existing_output = existing_output.drop( + columns=["index", "level_0"] + ) + existing_output.to_csv(output_path, index=False) + else: + new_data.to_csv(output_path, index=False) + + +def _run_repeated_per_task( + args: Namespace, #TODO: remove args + model_spec: ModelSpec, + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + logger: logging.RootLogger, + task: TaskSpec, + storage_uri: str, + experiment_name: str, + experiment_run_id: str, + task_run_id: str | None = None, + optimization_space: optimization_space_type | None = None, + n_trials: int = 1, + save_models: bool = False, + sampler: BaseSampler | None = None, + test_models: bool = False, + seed: int = 42, +): + """ + Performs HPO on a single task + + Args: + args: Namespace, #TODO: remove args + model_spec: ModelSpec, + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + logger: logging.RootLogger, + task: TaskSpec, + storage_uri: str, + experiment_name: str, + experiment_run_id: str, + task_run_id: str | None = None, + optimization_space: optimization_space_type | None = None, + n_trials: int = 1, + save_models: bool = False, + sampler: BaseSampler | None = None, + test_models: bool = False, + seed: int = 42, + + """ + logger.info( + f"starting backbone benchmark on task {task.name} {task_run_id=} {experiment_name=}" + ) + if storage_uri.startswith("http"): + optuna_db_path = Path(".") / "optuna_db" + else: + optuna_db_path = Path(storage_uri).parents[0] / "optuna_db" + + if not os.path.exists(optuna_db_path): + os.makedirs(optuna_db_path) + optuna_db_path = optuna_db_path / f"{experiment_name}_{experiment_run_id}" + optuna_db_path = str(optuna_db_path) + + task_run_id = sync_mlflow_optuna( + 
optuna_db_path=optuna_db_path, + storage_uri=storage_uri, + experiment_name=experiment_name, + task_run_id=task_run_id, + task=task, + n_trials=n_trials, + logger=logger, + ) + if task_run_id is not None: + # run_name is used only when run_id is unspecified. + run_name = None + else: + run_name = task.name + logger.info(f"start run: {run_name=} {task_run_id=}") + with mlflow.start_run(run_name=run_name, nested=True, run_id=task_run_id) as run: + logger.info(f"starting task run with id: {run.info.run_id}") + if training_spec.epochs is None: + raise Exception("Must specify epochs for training") + + + # if no optimization params, just run it + if optimization_space is None: + return ( + *fit_model( + args=args, + model_spec=model_spec, + training_spec=training_spec, + optimizer_spec=optimizer_spec, + callbacks_spec=callbacks_spec, + task=task, + run_name=run.info.run_name, + experiment_name=experiment_name, + storage_uri=storage_uri, + parent_run_id=run.info.run_id, + save_models=save_models, + test_models=test_models, + seed=seed, + logger=logger, + ), + ) diff --git a/gridfm_graphkit/iterate/model_fitting.py b/gridfm_graphkit/iterate/model_fitting.py new file mode 100644 index 0000000..79f26a4 --- /dev/null +++ b/gridfm_graphkit/iterate/model_fitting.py @@ -0,0 +1,448 @@ +""" +This module contains all the logic for fitting models +""" + +import abc +import copy +import dataclasses +import importlib +import os +import shutil +import types +import uuid +import torch +import warnings +from abc import abstractmethod +from functools import wraps +from typing import Callable +import pandas as pd +import lightning.pytorch as pl +import mlflow +import optuna +from lightning import Callback, Trainer +from lightning.pytorch.callbacks import ( + ModelCheckpoint, + +) +from jsonargparse import Namespace + +import logging +from lightning.pytorch.loggers.mlflow import MLFlowLogger + +from optuna.integration import PyTorchLightningPruningCallback + + +from gridfm_graphkit.training.callbacks import get_training_callbacks +from gridfm_graphkit.datasets.powergrid_datamodule import LitGridDataModule + +from gridfm_graphkit.utils.types import ( + OptimizerSpec, + ModelSpec, + TrainingSpec, + TaskSpec, + CallbackSpec, + valid_task_types, + ParameterBounds, + ParameterTypeEnum, + optimization_space_type, + recursive_merge, + ) + +from gridfm_graphkit.tasks import ( + FeatureReconstructionTask, + # ContingencyAnalysisTask, +) + + + + +from gridfm_graphkit.iterate.utils import get_logger + +LOGGER = get_logger() + + +os.environ["TUNE_DISABLE_AUTO_CALLBACK_LOGGERS"] = ( + "1" # disable tune loggers, will add csv and json manually. 
If this is not here, it will log to tensorboard automatically +) + + +class ParameterPicker(abc.ABC): + @abstractmethod + def pick_categorical(self, variable, choices): + pass + + @abstractmethod + def pick_int(self, variable, low, high): + pass + + @abstractmethod + def pick_float(self, variable, low, high, log=False): + pass + + +class OptunaParameterPicker(ParameterPicker): + def __init__(self, trial: optuna.Trial): + super().__init__() + self.trial = trial + + def pick_categorical(self, variable, choices): + return self.trial.suggest_categorical(variable, choices) + + def pick_int(self, variable, low, high): + return self.trial.suggest_int(variable, low, high) + + def pick_float(self, variable, low, high, log=False): + return self.trial.suggest_float(variable, low, high, log=log) + + +def inject_hparams( + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + config: dict): + assert isinstance(config, dict), ( + f"Error! Unexpected config type: {config}" + ) + training_spec_with_hparams = copy.deepcopy(training_spec) + optimizer_spec_with_hparams = copy.deepcopy(optimizer_spec) + + recursive_merge(training_spec_with_hparams, config) + recursive_merge(optimizer_spec_with_hparams, config) + + return training_spec_with_hparams, optimizer_spec_with_hparams + + +def generate_parameters( + parameter_picker: ParameterPicker, + current_hparams: dict, + hparam_space: dict, + ignore_keys: set[str] | None = None, + dictionary_position: list[str] | None = None, +): + if ignore_keys is None: + ignore_keys = set() + if dictionary_position is None: + dictionary_position = [] + _generate_parameters( + parameter_picker, + current_hparams, + hparam_space, + ignore_keys, + dictionary_position, + ) + + +def _generate_parameters( + parameter_picker: ParameterPicker, + current_hparams: dict, + hparam_space: dict, + ignore_keys: set[str], + dictionary_position: list[str], +): + for parameter, space in hparam_space.items(): + if parameter in ignore_keys: + continue + # if its a dictionary, continue to recurse + if isinstance(space, dict): + if parameter not in current_hparams: + current_hparams[parameter] = {} + dictionary_position.append(parameter) + _generate_parameters( + parameter_picker, + current_hparams[parameter], + hparam_space[parameter], + ignore_keys, + dictionary_position, + ) + dictionary_position.pop() + # if not, get a value from the parameter_picker and insert it with the name prepended by the dictionary position + # this is important so that the full path of the parameter is used + # this will avoid confusion between parameters with the same name but from different components + else: + full_parameter_name = ".".join(dictionary_position + [parameter]) + if isinstance(space, list): + suggestion = parameter_picker.pick_categorical( + full_parameter_name, space + ) + current_hparams[parameter] = suggestion + elif isinstance(space, ParameterBounds): + match space.type: + case ParameterTypeEnum.integer: + current_hparams[parameter] = parameter_picker.pick_int( + full_parameter_name, + int(space.min), + int(space.max), + ) + case ParameterTypeEnum.real: + current_hparams[parameter] = parameter_picker.pick_float( + full_parameter_name, space.min, space.max, log=space.log + ) + case _: + raise Exception( + f"Type {space.type} not recognized. 
Suggest one of {[e.value for e in ParameterTypeEnum]}" + ) + else: + raise Exception( + "Leaves of optimization space must be lists or ParameterBounds" + ) + + + + + + + + +""" +single node - optuna +""" +def launch_training( + trainer: Trainer, + model: FeatureReconstructionTask, #TODO: create basetask in tasks folder + optimizer_spec: OptimizerSpec, + datamodule: LitGridDataModule, + run_name: str, + experiment_name: str, + metric: str, + storage_uri: str, + parent_run_id: str, + direction: str, + test_models: bool, + delete_models_after_testing: bool, +) -> float: + + with mlflow.start_run(run_name=run_name, nested=True) as run: + mlflow.set_tag("mlflow.parentRunId", parent_run_id) + # explicitly log batch_size. Since it is not a model param, it will not be logged + mlflow.log_param("batch_size", datamodule.batch_size) + + trainer.logger = MLFlowLogger( + experiment_name=experiment_name, + run_id=run.info.run_id, + save_dir=storage_uri, + log_model=not delete_models_after_testing, + ) + trainer.fit(model, datamodule=datamodule) + if test_models: + test_metrics = trainer.test( + model, + ckpt_path="best", + datamodule=datamodule) + test_metrics =test_metrics[0] + if delete_models_after_testing: + # delete the checkpoints folder in the run + ckpts_folder = os.path.join( + trainer.logger.save_dir, + str(trainer.logger.name), + trainer.logger.version, + "checkpoints", + ) + shutil.rmtree(ckpts_folder) + + client = mlflow.tracking.MlflowClient( + tracking_uri=storage_uri, + ) + + if not metric.lower().startswith("val"): + raise Exception( + f"Metric {metric} does not start with `val`. Please choose a validation metric" + ) + for_pd_collect = [] + val_metrics_names = [] + + print(f'{client.get_run(run.info.run_id)=}') + + for cname in client.get_run(run.info.run_id).data.metrics: + print(f'{cname=}') + + for metric_name in client.get_run(run.info.run_id).data.metrics: + if metric_name.lower().startswith("val"): + val_metrics_names.append(metric_name) + val_metric_history = client.get_metric_history( + run.info.run_id, metric_name + ) + pd_convertible_metric_history = [ + { + "metric_name": mm.key, + "step": mm.step, + "value": mm.value, + } + for mm in val_metric_history + ] + for_pd_collect += pd_convertible_metric_history + df_val_metrics = pd.DataFrame.from_records(for_pd_collect) + df_val_metrics = df_val_metrics.set_index( + ["metric_name", "step"], verify_integrity=True + ) + series_val_metrics = df_val_metrics["value"] + assert metric in series_val_metrics, ( + f"Error! {metric} is not in {series_val_metrics}" + ) + if direction == "max": + best_step = series_val_metrics[metric].idxmax() + elif direction == "min": + best_step = series_val_metrics[metric].idxmin() + else: + raise Exception( + f"Error! 
Direction must be either `max` or `min` but got {direction}" + ) + + for val_metric_name in val_metrics_names: + mlflow.log_metric( + f"best_step_{val_metric_name}", + series_val_metrics[(val_metric_name, best_step)], + ) + + return series_val_metrics[(metric, best_step)] + + +def fit_model( + args: Namespace, + model_spec: ModelSpec, + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + task: TaskSpec, + run_name: str, + experiment_name: str, + storage_uri: str, + parent_run_id: str, + logger: logging.RootLogger, + save_models: bool = False, + test_models: bool = False, + seed: int = 42, + trial: optuna.Trial | None = None, +) -> tuple[float, str]: + pl.seed_everything(seed, workers=True) + training_spec_copy = copy.deepcopy(training_spec) + + #get callbacks + callbacks: list[Callback] = get_training_callbacks(callbacks_spec) + if callbacks_spec.optuna_early_prune and trial is not None: + callbacks.append( + PyTorchLightningPruningCallback(trial, monitor="Validation loss") + ) + if len(callbacks) > 0: + warnings.warn( + "Callbacks passed to trainer. Make sure these are stateless, as they will not be reinitialized for each task!" + ) + + delete_models_after_testing = False + if test_models and not save_models: + # we need to save the models during training to be able to test but can be deleted afterwards + save_models = True + delete_models_after_testing = True + if save_models: + callbacks.append( + ModelCheckpoint(monitor=task.metric, mode=task.direction) + ) + enable_checkpointing = False + if any([isinstance(cb, ModelCheckpoint) for cb in callbacks]): + enable_checkpointing=True + + # # initialize datamodule + args.data = task.data + datamodule = LitGridDataModule(args, task.data.data_path) + + #initialize model + lightning_task_class: valid_task_types = task.type.get_class_from_enum() + model = lightning_task_class( + args,#TODO: load model, training, optim separataly + datamodule.node_normalizers, + datamodule.edge_normalizers, + ) + logger.info(f"Loading model weights from {model_spec.model_path}") + state_dict = torch.load(model_spec.model_path) + model.load_state_dict(state_dict) + + # initialize trainer + trainer = Trainer( + accelerator=training_spec_copy.accelerator, + devices=training_spec_copy.devices, + strategy=training_spec_copy.strategy, + log_every_n_steps=training_spec_copy.log_every_n_steps, + # default_root_dir=args.log_dir, + max_epochs=training_spec_copy.epochs, + callbacks=callbacks, + enable_checkpointing=enable_checkpointing, + enable_progress_bar=training_spec_copy.enable_progress_bar, + ) + + logger.info( + f"launch_training {trainer=} {lightning_task_class=} {datamodule=} \ + {run_name=} {experiment_name=} {task.metric=} {storage_uri=} {task.direction=}" + ) + return ( + launch_training( + trainer=trainer, + model=model, + optimizer_spec=optimizer_spec, + datamodule=datamodule, + run_name=run_name, + experiment_name=experiment_name, + metric=task.metric, + storage_uri=storage_uri, + parent_run_id=parent_run_id, + direction=task.direction, + test_models=test_models, + delete_models_after_testing=delete_models_after_testing, + ) + ) + + +def fit_model_with_hparams( + args: Namespace, + model_spec: ModelSpec, + training_spec: TrainingSpec, + optimizer_spec: OptimizerSpec, + callbacks_spec: CallbackSpec, + task: TaskSpec, + optimization_space: optimization_space_type, + run_name: str, + experiment_name: str, + storage_uri: str, + parent_run_id: str, + logger: logging.RootLogger, + save_models: bool = False, + 
test_models: bool = False, + seed: int = 42, + trial: optuna.Trial | None = None, +) -> float: + """ + Generate parameters using the optuna trial from the given parameters. + Then inject these into the given task. + It is important to make sure to not overwrite the task passed in the arguments, or these updates may affect + subsequent trials. + """ + current_hparams: dict[str, int | float | str | bool] = {} + generate_parameters( + OptunaParameterPicker(trial), + current_hparams, + optimization_space, + ) + + training_spec_with_hparams, optimizer_spec_with_hparams = inject_hparams( + training_spec, optimizer_spec, current_hparams + ) + + output = fit_model( + args=args, + model_spec=model_spec, + training_spec=training_spec, + optimizer_spec=optimizer_spec, + callbacks_spec=callbacks_spec, + task=task, + run_name=f"{run_name}_{trial.number}", + experiment_name=experiment_name, + storage_uri=storage_uri, + parent_run_id=parent_run_id, + save_models=save_models, + test_models=test_models, + seed=seed, + logger=logger, + trial=trial, + ) # return only the metric value for optuna + + print(f'{output=}') + + return output + diff --git a/gridfm_graphkit/iterate/utils.py b/gridfm_graphkit/iterate/utils.py new file mode 100644 index 0000000..833b2eb --- /dev/null +++ b/gridfm_graphkit/iterate/utils.py @@ -0,0 +1,441 @@ +import os +from typing import Any, Dict + +import mlflow +import optuna +import logging +import datetime +from mlflow.entities.experiment import Experiment +from gridfm_graphkit.utils.types import ( + optimization_space_type, + TaskSpec, + + ParameterBounds, + HyperParameterOptmizerSpec, + CallbackSpec, + OptimizerSpec, + ModelSpec, + TrainingSpec, + DataSpec, + direction_type_to_optuna + ) + + + +# Custom function to parse the optimization space argument +def parse_optimization_space(space: dict | None) -> optimization_space_type | None: + if space is None: + return None + parsed_space: optimization_space_type = {} + for key, value in space.items(): + if isinstance(value, dict): + try: + bounds = ParameterBounds(**value) + parsed_space[key] = bounds + except TypeError: + # Recursively parse nested optimization spaces + parsed_space[key] = parse_optimization_space(value) + elif isinstance(value, list): + # If it's a list, leave it as is + parsed_space[key] = value + else: + raise ValueError(f"Invalid type for {key}: {value}") + return parsed_space + + +def unflatten(dictionary: Dict[str, Any]): + resultDict: Dict = {} + for key, value in dictionary.items(): + parts = key.split(".") + d = resultDict + for part in parts[:-1]: + if part not in d: + d[part] = {} + d = d[part] + d[parts[-1]] = value + return resultDict + + +def get_logger(log_level="INFO", log_folder="./experiment_logs") -> logging.RootLogger: + # set up logging file + if not os.path.exists(log_folder): + os.makedirs(log_folder) + current_time = datetime.datetime.now() + current_time = ( + str(current_time).replace(" ", "_").replace(":", "-").replace(".", "-") + ) + log_file = f"{log_folder}/{current_time}" + logger = logging.getLogger() + logger.setLevel(log_level) + handler = logging.FileHandler(log_file) + handler.setLevel(log_level) + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + handler.setFormatter(formatter) + logger.addHandler(handler) + logging.basicConfig(level=logging.CRITICAL) + return logger + + +def check_existing_experiments( + logger: logging.RootLogger, + storage_uri: str, + experiment_name: str, + exp_parent_run_name: str, + task_names: list, + n_trials: 
int, +) -> Dict[str, Any]: + """ + checks if experiment has been completed (i.e. both task run and nested individual runs are complete) + Args: + logger: logging.RootLogger to save logs to file + storage_uri: folder containing mlflow log data + experiment_name: name of experiment + exp_parent_run_name: run name of the top level experiment run + task_names: list of task names that should be completed + n_trials: number of trials (runs) expected in HPO of each task + Returns: + output: dict with: + no_existing_runs: bool, if True, there are no existing runs + incomplete_run_to_finish: str | None, run id of the experiment run to finish + finished_run: str | None, run id of the finished experiment run + experiment_id: str | None, experiment id it experiment already exists + + """ + client = mlflow.tracking.MlflowClient(tracking_uri=storage_uri) + experiment_info = client.get_experiment_by_name(experiment_name) + + output = { + "no_existing_runs": True, + "incomplete_run_to_finish": None, + "finished_run": None, + "experiment_id": None, + } + if experiment_info is None: + return output + + experiment_id = experiment_info.experiment_id + logger.info(f"\nexperiment_id: {experiment_id}") + logger.info(f"experiment_name: {experiment_name}") + output["experiment_id"] = experiment_id + experiment_parent_run_data = client.search_runs( + experiment_ids=[experiment_id], + filter_string=f'tags."mlflow.runName" LIKE "{exp_parent_run_name}"', + ) + if len(experiment_parent_run_data) >= 1: + logger.info("there is at least one experiment parent run") + finished_run_id = None + incomplete_runs = [] + + # check if one of the runs is complete + for run in experiment_parent_run_data: + ( + completed_task_run_names, + all_tasks_in_experiment_finished, + _, + ) = check_existing_task_parent_runs( + logger=logger, + exp_parent_run_id=run.info.run_id, + storage_uri=storage_uri, + experiment_name=experiment_name, + n_trials=n_trials, + ) + logger.info(f"tasks that should be completed: {task_names}") + logger.info(f"completed_task_run_names: {completed_task_run_names}") + logger.info( + f"all_tasks_in_experiment_finished: {all_tasks_in_experiment_finished}" + ) + all_expected_tasks_completed = [ + item for item in task_names if item in completed_task_run_names + ] + all_expected_tasks_completed = len(task_names) == len( + all_expected_tasks_completed + ) + if all_expected_tasks_completed: + finished_run_id = run.info.run_id + logger.info( + f"The following run FINISHED and will be used for repeated experiments: {finished_run_id}" + ) + else: + incomplete_tasks = [ + item for item in task_names if item not in completed_task_run_names + ] + logger.info( + f"The following run {run.info.run_id} is incomplete, with status {run.info.status} and missing tasks: {incomplete_tasks}" + ) + incomplete_runs.append(run.info.run_id) + + if finished_run_id is not None: + # delete all incomplete runs + delete_nested_experiment_parent_runs( + logger=logger, + delete_runs=incomplete_runs, + experiment_info=experiment_info, + client=client, + leave_one=False, + ) + output["finished_run"] = finished_run_id + output["no_existing_runs"] = False + else: + # delete all incomplete runs, leave one + logger.info(f"incomplete_runs: {incomplete_runs}") + output["incomplete_run_to_finish"] = delete_nested_experiment_parent_runs( + logger=logger, + delete_runs=incomplete_runs, + experiment_info=experiment_info, + client=client, + leave_one=True, + ) + output["no_existing_runs"] = False + return output + + + + +def 
delete_nested_experiment_parent_runs( + logger: logging.RootLogger, + delete_runs: list, + experiment_info: mlflow.entities.experiment.Experiment, + client: mlflow.tracking.client.MlflowClient, + leave_one: bool = True, +) -> str | None: + """ + if there are multiple runs for a single experiment, + will delete all runs except the one with the most nested runs (most complete) + Args: + logger: logging.RootLogger to save logs to file + delete_runs: list of runs to delete + experiment_info: info of experiment + client: mlflow client pointing to correct storage uri + leave_one: if True, will not delete the most complete experiment. If False, will delete all experiments + Returns: + run id of the experiment run that was not deleted or None + """ + experiment_id = experiment_info.experiment_id + exp_parent_run_ids = [] + counts = [] + runs_in_experiment = [] + logger.info(f"Deleting from experiment_id:{experiment_id} ") + logger.info(f"delete_runs:{delete_runs} ") + + for exp_parent_run_id in delete_runs: + runs = [] + runs.append(exp_parent_run_id) + task_parent_run_data = client.search_runs( + experiment_ids=[experiment_id], + filter_string=f'tags."mlflow.parentRunId" LIKE "{exp_parent_run_id}"', + ) + for task_parent_run in task_parent_run_data: + task_parent_run_id = task_parent_run.info.run_id + runs.append(task_parent_run_id) + individual_run_data = client.search_runs( + experiment_ids=[experiment_id], + filter_string=f'tags."mlflow.parentRunId" LIKE "{task_parent_run_id}"', + ) + for individual_run in individual_run_data: + runs.append(individual_run.info.run_id) + exp_parent_run_ids.append(exp_parent_run_id) + counts.append(len(runs)) + runs_in_experiment.append(runs) + + if leave_one and (len(counts) > 0): + index_to_keep = counts.index(max(counts)) + incomplete_run_to_finish = exp_parent_run_ids[index_to_keep] + runs_in_experiment.pop(index_to_keep) + else: + incomplete_run_to_finish = None + + logger.info(f"Deleting runs:{runs_in_experiment} ") + logger.info( + f"experiment_info.artifact_location:{experiment_info.artifact_location}" + ) + for runs in runs_in_experiment: + for run_id in runs: + client.delete_run(run_id) + os.system(f"rm -r {experiment_info.artifact_location}/{run_id}") + return incomplete_run_to_finish + + + +def check_existing_task_parent_runs( + logger: logging.RootLogger, + exp_parent_run_id: str, + storage_uri: str, + experiment_name: str, + n_trials: int = 5, +): + """ + checks if tasks have been completed (both task run and nested individual runs are complete) + Args: + logger: logging.RootLogger to save logs to file + exp_parent_run_id: run id of the experiment run being used (top level run id) + storage_uri: folder containing mlflow log data + experiment_name: name of experiment + n_trials: number of trials (runs) expected in HPO of each task + Returns: + complete_task_run_names: list of task names that have been completed + all_tasks_finished: bool showing if all tasks have been completed + task_run_to_id_match: dict matching task names to the task run id + + """ + client = mlflow.tracking.MlflowClient(tracking_uri=storage_uri) + experiment_info = client.get_experiment_by_name(experiment_name) + experiment_id = experiment_info.experiment_id + task_parent_run_data = client.search_runs( + experiment_ids=[experiment_id], + filter_string=f'tags."mlflow.parentRunId" LIKE "{exp_parent_run_id}"', + ) + complete_task_run_names = [] + all_tasks_finished = [] + # TO DO: make sure we only have one task_parent_run for each name (needed for repeated exps) + 
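+    # A task parent run counts as complete only when its own status and the
+    # statuses of all nested trial runs (ignoring RUNNING and FAILED ones)
+    # reduce to the single value "FINISHED".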
task_run_to_id_match = {} + for task_parent_run in task_parent_run_data: + task_run_statuses = [] + task_run_ids = [] + task_run_statuses.append(task_parent_run.info.status) + task_run_ids.append(task_parent_run.info.run_id) + + individual_run_data = client.search_runs( + experiment_ids=[experiment_id], + filter_string=f'tags."mlflow.parentRunId" LIKE "{task_parent_run.info.run_id}"', + ) + for individual_run in individual_run_data: + if (individual_run.info.status == "RUNNING") or ( + individual_run.info.status == "FAILED" + ): + continue + task_run_statuses.append(individual_run.info.status) + task_run_ids.append(individual_run.info.run_id) + + task_run_to_id_match[task_parent_run.info.run_name] = ( + task_parent_run.info.run_id + ) + task_run_statuses = list(set(task_run_statuses)) + + condition_1 = len(task_run_statuses) == 1 + condition_2 = task_run_statuses[0] == "FINISHED" + # condition_3 = len(task_run_ids) == (n_trials+1) + if condition_1 and condition_2: # and condition_3: + complete_task_run_names.append(task_parent_run.info.run_name) + task_parent_status = True + else: + task_parent_status = False + all_tasks_finished.append(task_parent_status) + + if all(all_tasks_finished) and (len(all_tasks_finished) > 0): + all_tasks_finished = True + else: + all_tasks_finished = False + complete_task_run_names = list(set(complete_task_run_names)) + return complete_task_run_names, all_tasks_finished, task_run_to_id_match + + + + + + +def sync_mlflow_optuna( + optuna_db_path: str, + storage_uri: str, + experiment_name: str, + task_run_id: str | None, + task: TaskSpec, + n_trials: int, + logger: logging.RootLogger, +) -> str | None: + """ + syncs the number of completed trials in mflow and optuna + Args: + optuna_db_path: path to optuna database + storage_uri: path to mlflow storage folder + experiment_name: name on experiment in mlflow + task_run_id: run_id of the task + task: name of the task + logger: logging.RootLogger to save logs to file + Returns: + task_run_id: run id of the task to be continued (if one exists) or None + """ + logger.info( + f"sync_mlflow_optuna - {optuna_db_path=} {storage_uri=} {task_run_id=} {experiment_name=} {task_run_id=}" + ) + # check number of successful mlflow runs in task + client = mlflow.tracking.MlflowClient(tracking_uri=storage_uri) + completed_in_mlflow_for_task = [] + all_mlflow_runs_for_task = [] + if task_run_id is not None: + all_mlflow_runs_for_task.append(task_run_id) + logger.info(f"sync_mlflow_optuna - {task_run_id=}") + experiment_info = client.get_experiment_by_name(experiment_name) + assert isinstance(experiment_info, Experiment), ( + f"Error! 
Unexpected type of {experiment_info=}" + ) + individual_run_data = client.search_runs( + experiment_ids=[experiment_info.experiment_id], + filter_string=f'tags."mlflow.parentRunId" LIKE "{task_run_id}"', + ) + for individual_run in individual_run_data: + if individual_run.info.status == "FINISHED": + completed_in_mlflow_for_task.append(individual_run.info.run_id) + all_mlflow_runs_for_task.append(individual_run.info.run_id) + + # check number of successful optuna trials in the database + study_names = optuna.study.get_all_study_names( + storage="sqlite:///{}.db".format(optuna_db_path) + ) + if task.name in study_names: + loaded_study = optuna.load_study( + study_name=task.name, storage="sqlite:///{}.db".format(optuna_db_path) + ) + logger.info(f"loaded_study has : {len(loaded_study.trials)} trials") + incomplete = 0 + for trial in loaded_study.trials: + if (trial.state == optuna.trial.TrialState.FAIL) | ( + trial.state == optuna.trial.TrialState.RUNNING + ): + incomplete += 1 + logger.info(f"{incomplete} trials are incomplete") + successful_optuna_trials = len(loaded_study.trials) - incomplete + too_many_trials = successful_optuna_trials > n_trials + no_existing_task = task_run_id is None + optuna_mlflow_mismatch = ( + len(completed_in_mlflow_for_task) != successful_optuna_trials + ) + logger.info( + f"successful optuna trials {successful_optuna_trials} . mlflow runs {len(completed_in_mlflow_for_task)}" + ) + + if too_many_trials or no_existing_task or optuna_mlflow_mismatch: + logger.info(f"deleting study with name {task.name}") + logger.info(f"too_many_trials {too_many_trials}") + logger.info(f"no_existing_task {no_existing_task}") + + # delete optuna study in database + optuna.delete_study( + study_name=task.name, storage="sqlite:///{}.db".format(optuna_db_path) + ) + + # delete any existing mlflow runs + if len(all_mlflow_runs_for_task) > 0: + for item in all_mlflow_runs_for_task: + logger.info(f"deleting {item}") + client.delete_run(item) + assert isinstance(experiment_info, Experiment), ( + f"Error! Unexpected type of {experiment_info=}" + ) + os.system(f"rm -r {experiment_info.artifact_location}/{item}") + task_run_id = None + else: + # delete any existing mlflow runs + if len(all_mlflow_runs_for_task) > 0: + for item in all_mlflow_runs_for_task: + logger.info(f"deleting {item}") + client.delete_run(item) + assert isinstance(experiment_info, Experiment), ( + f"Error! 
Unexpected type of {experiment_info=}" + ) + os.system(f"rm -r {experiment_info.artifact_location}/{item}") + task_run_id = None + logging.info(f"sync_mlflow_optuna returns {task_run_id=}") + return task_run_id + From cfc4c47162d1abf9580f4368961363bfcc86c639 Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Tue, 23 Dec 2025 15:39:57 -0500 Subject: [PATCH 10/16] formatting Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- gridfm_graphkit/cli.py | 27 ++-- gridfm_graphkit/iterate/__init__.py | 14 +- gridfm_graphkit/iterate/model_fitting.py | 160 ++++++----------------- gridfm_graphkit/iterate/utils.py | 75 ++++++++--- gridfm_graphkit/utils/types.py | 65 ++++----- 5 files changed, 135 insertions(+), 206 deletions(-) diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index a9669d7..91bb93b 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -1,18 +1,24 @@ from gridfm_graphkit.datasets.powergrid_datamodule import LitGridDataModule from gridfm_graphkit.io.param_handler import NestedNamespace from gridfm_graphkit.training.callbacks import get_training_callbacks +from gridfm_graphkit.iterate import run_iterate_experiments import numpy as np import os import yaml import torch import random import pandas as pd - from gridfm_graphkit.tasks import FeatureReconstructionTask - from lightning.pytorch.loggers import MLFlowLogger import lightning as L +from jsonargparse import Namespace + +DEFAULT_SEED = 42 + + + + def main_cli(args): @@ -103,23 +109,6 @@ def main_cli(args): -from gridfm_graphkit.utils.types import ( - HyperParameterOptmizerSpec, TaskSpec, CallbackSpec, - OptimizerSpec, ModelSpec, TrainingSpec, DataSpec - ) - -from gridfm_graphkit.iterate import run_iterate_experiments - -from jsonargparse import Namespace - - -DEFAULT_SEED = 42 - - - - - - def iterate_cli(config_args): #validate inputs diff --git a/gridfm_graphkit/iterate/__init__.py b/gridfm_graphkit/iterate/__init__.py index cf52487..9599e06 100644 --- a/gridfm_graphkit/iterate/__init__.py +++ b/gridfm_graphkit/iterate/__init__.py @@ -1,13 +1,9 @@ from gridfm_graphkit.iterate.hpo import ( - run_hpo_experiments, - run_repeated_experiments, - run_iterate_experiments) + run_hpo_experiments, + run_repeated_experiments, + run_iterate_experiments, +) # from gridfm_graphkit.tasks.contingency_analysis import ContingencyAnalysisTask - -__all__ = ( - "run_hpo_experiments", - "run_repeated_experiments", - "run_iterate_experiments" -) \ No newline at end of file +__all__ = ("run_hpo_experiments", "run_repeated_experiments", "run_iterate_experiments") diff --git a/gridfm_graphkit/iterate/model_fitting.py b/gridfm_graphkit/iterate/model_fitting.py index 79f26a4..1a1241b 100644 --- a/gridfm_graphkit/iterate/model_fitting.py +++ b/gridfm_graphkit/iterate/model_fitting.py @@ -4,25 +4,17 @@ import abc import copy -import dataclasses -import importlib import os import shutil -import types -import uuid import torch import warnings from abc import abstractmethod -from functools import wraps -from typing import Callable -import pandas as pd import lightning.pytorch as pl import mlflow import optuna from lightning import Callback, Trainer from lightning.pytorch.callbacks import ( ModelCheckpoint, - ) from jsonargparse import Namespace @@ -36,9 +28,9 @@ from gridfm_graphkit.datasets.powergrid_datamodule import LitGridDataModule from gridfm_graphkit.utils.types import ( - OptimizerSpec, - ModelSpec, - TrainingSpec, + OptimizerSpec, + ModelSpec, + 
TrainingSpec, TaskSpec, CallbackSpec, valid_task_types, @@ -46,7 +38,7 @@ ParameterTypeEnum, optimization_space_type, recursive_merge, - ) +) from gridfm_graphkit.tasks import ( FeatureReconstructionTask, @@ -54,9 +46,7 @@ ) - - -from gridfm_graphkit.iterate.utils import get_logger +from gridfm_graphkit.iterate.utils import get_logger, get_best_val LOGGER = get_logger() @@ -96,12 +86,9 @@ def pick_float(self, variable, low, high, log=False): def inject_hparams( - training_spec: TrainingSpec, - optimizer_spec: OptimizerSpec, - config: dict): - assert isinstance(config, dict), ( - f"Error! Unexpected config type: {config}" - ) + training_spec: TrainingSpec, optimizer_spec: OptimizerSpec, config: dict +): + assert isinstance(config, dict), f"Error! Unexpected config type: {config}" training_spec_with_hparams = copy.deepcopy(training_spec) optimizer_spec_with_hparams = copy.deepcopy(optimizer_spec) @@ -186,18 +173,14 @@ def _generate_parameters( ) - - - - - - """ single node - optuna """ + + def launch_training( trainer: Trainer, - model: FeatureReconstructionTask, #TODO: create basetask in tasks folder + model: FeatureReconstructionTask, # TODO: create basetask in tasks folder optimizer_spec: OptimizerSpec, datamodule: LitGridDataModule, run_name: str, @@ -209,7 +192,6 @@ def launch_training( test_models: bool, delete_models_after_testing: bool, ) -> float: - with mlflow.start_run(run_name=run_name, nested=True) as run: mlflow.set_tag("mlflow.parentRunId", parent_run_id) # explicitly log batch_size. Since it is not a model param, it will not be logged @@ -222,12 +204,17 @@ def launch_training( log_model=not delete_models_after_testing, ) trainer.fit(model, datamodule=datamodule) + + output = get_best_val( + storage_uri=storage_uri, + run=run, + metric=metric, + direction=direction, + ) + if test_models: - test_metrics = trainer.test( - model, - ckpt_path="best", - datamodule=datamodule) - test_metrics =test_metrics[0] + test_metrics = trainer.test(model, ckpt_path="best", datamodule=datamodule) + output = test_metrics if delete_models_after_testing: # delete the checkpoints folder in the run ckpts_folder = os.path.join( @@ -238,61 +225,7 @@ def launch_training( ) shutil.rmtree(ckpts_folder) - client = mlflow.tracking.MlflowClient( - tracking_uri=storage_uri, - ) - - if not metric.lower().startswith("val"): - raise Exception( - f"Metric {metric} does not start with `val`. Please choose a validation metric" - ) - for_pd_collect = [] - val_metrics_names = [] - - print(f'{client.get_run(run.info.run_id)=}') - - for cname in client.get_run(run.info.run_id).data.metrics: - print(f'{cname=}') - - for metric_name in client.get_run(run.info.run_id).data.metrics: - if metric_name.lower().startswith("val"): - val_metrics_names.append(metric_name) - val_metric_history = client.get_metric_history( - run.info.run_id, metric_name - ) - pd_convertible_metric_history = [ - { - "metric_name": mm.key, - "step": mm.step, - "value": mm.value, - } - for mm in val_metric_history - ] - for_pd_collect += pd_convertible_metric_history - df_val_metrics = pd.DataFrame.from_records(for_pd_collect) - df_val_metrics = df_val_metrics.set_index( - ["metric_name", "step"], verify_integrity=True - ) - series_val_metrics = df_val_metrics["value"] - assert metric in series_val_metrics, ( - f"Error! {metric} is not in {series_val_metrics}" - ) - if direction == "max": - best_step = series_val_metrics[metric].idxmax() - elif direction == "min": - best_step = series_val_metrics[metric].idxmin() - else: - raise Exception( - f"Error! 
Direction must be either `max` or `min` but got {direction}" - ) - - for val_metric_name in val_metrics_names: - mlflow.log_metric( - f"best_step_{val_metric_name}", - series_val_metrics[(val_metric_name, best_step)], - ) - - return series_val_metrics[(metric, best_step)] + return output def fit_model( @@ -311,11 +244,11 @@ def fit_model( test_models: bool = False, seed: int = 42, trial: optuna.Trial | None = None, -) -> tuple[float, str]: +) -> dict: pl.seed_everything(seed, workers=True) training_spec_copy = copy.deepcopy(training_spec) - #get callbacks + # get callbacks callbacks: list[Callback] = get_training_callbacks(callbacks_spec) if callbacks_spec.optuna_early_prune and trial is not None: callbacks.append( @@ -332,21 +265,19 @@ def fit_model( save_models = True delete_models_after_testing = True if save_models: - callbacks.append( - ModelCheckpoint(monitor=task.metric, mode=task.direction) - ) + callbacks.append(ModelCheckpoint(monitor=task.metric, mode=task.direction)) enable_checkpointing = False if any([isinstance(cb, ModelCheckpoint) for cb in callbacks]): - enable_checkpointing=True + enable_checkpointing = True - # # initialize datamodule + # # initialize datamodule args.data = task.data datamodule = LitGridDataModule(args, task.data.data_path) - #initialize model + # initialize model lightning_task_class: valid_task_types = task.type.get_class_from_enum() model = lightning_task_class( - args,#TODO: load model, training, optim separataly + args, # TODO: load model, training, optim separataly datamodule.node_normalizers, datamodule.edge_normalizers, ) @@ -371,21 +302,19 @@ def fit_model( f"launch_training {trainer=} {lightning_task_class=} {datamodule=} \ {run_name=} {experiment_name=} {task.metric=} {storage_uri=} {task.direction=}" ) - return ( - launch_training( - trainer=trainer, - model=model, - optimizer_spec=optimizer_spec, - datamodule=datamodule, - run_name=run_name, - experiment_name=experiment_name, - metric=task.metric, - storage_uri=storage_uri, - parent_run_id=parent_run_id, - direction=task.direction, - test_models=test_models, - delete_models_after_testing=delete_models_after_testing, - ) + return launch_training( + trainer=trainer, + model=model, + optimizer_spec=optimizer_spec, + datamodule=datamodule, + run_name=run_name, + experiment_name=experiment_name, + metric=task.metric, + storage_uri=storage_uri, + parent_run_id=parent_run_id, + direction=task.direction, + test_models=test_models, + delete_models_after_testing=delete_models_after_testing, ) @@ -424,7 +353,7 @@ def fit_model_with_hparams( training_spec, optimizer_spec, current_hparams ) - output = fit_model( + output = fit_model( args=args, model_spec=model_spec, training_spec=training_spec, @@ -442,7 +371,4 @@ def fit_model_with_hparams( trial=trial, ) # return only the metric value for optuna - print(f'{output=}') - return output - diff --git a/gridfm_graphkit/iterate/utils.py b/gridfm_graphkit/iterate/utils.py index 833b2eb..1e8a7cc 100644 --- a/gridfm_graphkit/iterate/utils.py +++ b/gridfm_graphkit/iterate/utils.py @@ -5,21 +5,13 @@ import optuna import logging import datetime +import pandas as pd from mlflow.entities.experiment import Experiment from gridfm_graphkit.utils.types import ( optimization_space_type, TaskSpec, - ParameterBounds, - HyperParameterOptmizerSpec, - CallbackSpec, - OptimizerSpec, - ModelSpec, - TrainingSpec, - DataSpec, - direction_type_to_optuna - ) - +) # Custom function to parse the optimization space argument @@ -78,6 +70,61 @@ def get_logger(log_level="INFO", 
log_folder="./experiment_logs") -> logging.Root return logger +def get_best_val( + storage_uri: str, + run: mlflow.entities.Run, + metric: str, + direction: str, +): + client = mlflow.tracking.MlflowClient( + tracking_uri=storage_uri, + ) + + if not metric.lower().startswith("val"): + raise Exception( + f"Metric {metric} does not start with `val`. Please choose a validation metric" + ) + for_pd_collect = [] + val_metrics_names = [] + + for metric_name in client.get_run(run.info.run_id).data.metrics: + if metric_name.lower().startswith("val"): + val_metrics_names.append(metric_name) + val_metric_history = client.get_metric_history(run.info.run_id, metric_name) + pd_convertible_metric_history = [ + { + "metric_name": mm.key, + "step": mm.step, + "value": mm.value, + } + for mm in val_metric_history + ] + for_pd_collect += pd_convertible_metric_history + df_val_metrics = pd.DataFrame.from_records(for_pd_collect) + df_val_metrics = df_val_metrics.set_index( + ["metric_name", "step"], verify_integrity=True + ) + series_val_metrics = df_val_metrics["value"] + assert metric in series_val_metrics, ( + f"Error! {metric} is not in {series_val_metrics}" + ) + if direction == "max": + best_step = series_val_metrics[metric].idxmax() + elif direction == "min": + best_step = series_val_metrics[metric].idxmin() + else: + raise Exception( + f"Error! Direction must be either `max` or `min` but got {direction}" + ) + + for val_metric_name in val_metrics_names: + mlflow.log_metric( + f"best_step_{val_metric_name}", + series_val_metrics[(val_metric_name, best_step)], + ) + return series_val_metrics[(metric, best_step)] + + def check_existing_experiments( logger: logging.RootLogger, storage_uri: str, @@ -191,8 +238,6 @@ def check_existing_experiments( return output - - def delete_nested_experiment_parent_runs( logger: logging.RootLogger, delete_runs: list, @@ -257,7 +302,6 @@ def delete_nested_experiment_parent_runs( return incomplete_run_to_finish - def check_existing_task_parent_runs( logger: logging.RootLogger, exp_parent_run_id: str, @@ -331,10 +375,6 @@ def check_existing_task_parent_runs( return complete_task_run_names, all_tasks_finished, task_run_to_id_match - - - - def sync_mlflow_optuna( optuna_db_path: str, storage_uri: str, @@ -438,4 +478,3 @@ def sync_mlflow_optuna( task_run_id = None logging.info(f"sync_mlflow_optuna returns {task_run_id=}") return task_run_id - diff --git a/gridfm_graphkit/utils/types.py b/gridfm_graphkit/utils/types.py index 5a29b49..11a119d 100644 --- a/gridfm_graphkit/utils/types.py +++ b/gridfm_graphkit/utils/types.py @@ -2,31 +2,20 @@ This module defines all the types expected at input. Used for type checking by jsonargparse. 
""" -from ast import Dict from typing import Literal -import copy import enum -from dataclasses import dataclass, field, replace -from typing import Any, Optional, Union +from dataclasses import dataclass, field +from typing import Any, Union from gridfm_graphkit.tasks import ( FeatureReconstructionTask, - # ContingencyAnalysisTask, ) -from gridfm_graphkit.datasets.powergrid_datamodule import LitGridDataModule -import logging - - -valid_task_types = type[ - FeatureReconstructionTask - # | ContingencyAnalysisTask -] +valid_task_types = type[FeatureReconstructionTask] direction_type_to_optuna = {"min": "minimize", "max": "maximize"} - @dataclass class TaskTypeEnum(enum.Enum): """ @@ -34,7 +23,6 @@ class TaskTypeEnum(enum.Enum): """ feature_reconstruction = "feature_reconstruction" - # contingency_analysis = "contingency_analysis" def get_class_from_enum( self, @@ -42,8 +30,6 @@ def get_class_from_enum( match self.value: case TaskTypeEnum.feature_reconstruction.value: return FeatureReconstructionTask - case TaskTypeEnum.contingency_analysis.value: - return ContingencyAnalysisTask case _: raise TypeError("Task type does not exist") @@ -79,15 +65,11 @@ def __post_init__(self): self.type = ParameterTypeEnum(self.type) - optimization_space_type = dict[ str, Union[list, dict, ParameterBounds, "optimization_space_type"] ] - - - @dataclass class HyperParameterOptmizerSpec: """ @@ -108,9 +90,10 @@ class HyperParameterOptmizerSpec: max_run_duration (str, None): maximum allowed run duration in the form DD:HH:MM:SS; will stop a run after this amount of time. Defaults to None, which doesn't stop runs by time. """ + experiment_name: str run_name: str - + results_folder: str save_models: bool = False n_trials: int = 5 @@ -124,14 +107,13 @@ class HyperParameterOptmizerSpec: optimization_space: dict | None = None - - @dataclass class TrainingSpec: """ Parameters passed to define lightning trainer - + """ + batch_size: int epochs: int losses: list[str] @@ -143,14 +125,13 @@ class TrainingSpec: enable_progress_bar: bool = False - - @dataclass class ModelSpec: """ Parameters passed to define Model - + """ + attention_head: int dropout: float edge_dim: int @@ -163,14 +144,13 @@ class ModelSpec: model_path: str - - @dataclass class OptimizerSpec: """ - Parameters passed to define Optimization and Scheduling parameters. Learning rate will be overwritten for 'iterate' subcommand. - + Parameters passed to define Optimization and Scheduling parameters. Learning rate will be overwritten for 'iterate' subcommand. + """ + learning_rate: float type: str optimizer_params: dict @@ -178,13 +158,13 @@ class OptimizerSpec: scheduler_params: dict | None - @dataclass class DataSpec: """ - Parameters passed to define training data. Ignored for 'iterate' subcommand. - + Parameters passed to define training data. Ignored for 'iterate' subcommand. + """ + networks: list[str] scenarios: list[int] normalization: str @@ -200,7 +180,6 @@ class DataSpec: data_path: str - @dataclass class CallbackSpec: """ @@ -209,16 +188,15 @@ class CallbackSpec: Args: patience (int): patience for early stopping tol (int): ... 
- + """ - #TODO: use dicts for each callback type + + # TODO: use dicts for each callback type patience: int | None = None tol: int | None = None max_run_duration: int | None = None monitor_learning_rate: bool = True - optuna_early_prune: bool = False #only processed with iterate command - - + optuna_early_prune: bool = False # only processed with iterate command @dataclass @@ -236,9 +214,10 @@ class TaskSpec: data: datamodule (BaseDataModule | GeoBenchDataModule): Datamodule to be used. """ + name: str type: TaskTypeEnum = field(repr=False) - data: DataSpec # = field(repr=False) + data: DataSpec # = field(repr=False) metric: str = "val/constraint_violations" direction: Literal["min", "max"] = "min" @@ -254,4 +233,4 @@ def recursive_merge(first_dict: dict[str, Any], second_dict: dict[str, Any]): recursive_merge(first_dict[key], val) # if it is not further nested, just replace the value else: - first_dict[key] = val \ No newline at end of file + first_dict[key] = val From 9998ba28ea89f2e71afeec72f3cc7ab61a9a1161 Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Tue, 23 Dec 2025 15:45:41 -0500 Subject: [PATCH 11/16] formatting Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- gridfm_graphkit/cli.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index 91bb93b..fbd236c 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -17,10 +17,6 @@ DEFAULT_SEED = 42 - - - - def main_cli(args): logger = MLFlowLogger( save_dir=args.log_dir, @@ -106,25 +102,18 @@ def main_cli(args): print(f"Saved predictions to {csv_path}") - - - - def iterate_cli(config_args): - #validate inputs + # validate inputs if config_args.seed is not None: - assert isinstance(config_args.seed, int), ( - "seed must be an integer" - ) + assert isinstance(config_args.seed, int), "seed must be an integer" else: config_init.seed = DEFAULT_SEED torch.manual_seed(config_args.seed) random.seed(config_args.seed) np.random.seed(config_args.seed) - run_iterate_experiments( - args=config_args, #TODO + args=config_args, # TODO model_spec=config_args.model, training_spec=config_args.training, optimizer_spec=config_args.optimizer, @@ -132,5 +121,4 @@ def iterate_cli(config_args): hpo_spec=config_args.hpo_spec, tasks=config_args.tasks, seed=config_args.seed, - ) - + ) From 56dc7fbe6908d4eff90f62e5c981ba096036b601 Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Wed, 24 Dec 2025 15:22:56 -0500 Subject: [PATCH 12/16] update readme Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- README.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/README.md b/README.md index 78e8661..d80ca67 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ Available commands: * `finetune` – Fine-tune an existing pre-trained model * `evaluate` – Evaluate model performance on a dataset * `predict` – Run inference and save predictions +* `iterate` – Run hyperparameter optimization (HPO) and repeated experiments on multiple datasets/tasks --- @@ -144,3 +145,28 @@ gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model. | `--output_path` | `str` | Directory where predictions are saved. 
| `data` | --- + +--- + +## Running Iterate + +```bash +gridfm_graphkit iterate --config path/to/config.yaml +``` + +### Arguments + +| Argument | Type | Description | Default | +| --------------- | ----- | --------------------------------------------- | --------- | +| `--config` | `str` | Path to `iterate` config file. | `None` | +| `--seed` | `int`. | Seed for reproducibility. | `None` | +| `--hpo_spec` | `namespace` | Parameters for HPO/repeated experiments | `None` | +| `--tasks` | `namespace` | MLflow run name. | `None` | +| `--model` | `namespace` | MLflow logging directory. | `None` | +| `--optimizer` | `namespace` | Dataset directory. | `None` | +| `--training` | `namespace` | Directory where predictions are saved. | `None` | +| `--callbacks` | `namespace` | Directory where predictions are saved. | `None` | + +--- +**Note:** Namespace inputs can be provided in the config or as command line arguments. If provided on the command line, namespaces inputs can be provided with `.` notation, e.g. `--hpo_spec.experiment_name my_exp`. Run `gridfm_graphkit iterate -h` for full list of inputs allows in each namespace. + From c1b104ccd96451252ba88da7bdeecabdabe87d25 Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Wed, 24 Dec 2025 15:24:15 -0500 Subject: [PATCH 13/16] add returns Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- gridfm_graphkit/__main__.py | 19 +++++++++++++++---- gridfm_graphkit/cli.py | 6 +----- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/gridfm_graphkit/__main__.py b/gridfm_graphkit/__main__.py index d7b59fb..6a23fd0 100644 --- a/gridfm_graphkit/__main__.py +++ b/gridfm_graphkit/__main__.py @@ -36,6 +36,16 @@ def main(): finetune_parser.add_argument("--log_dir", type=str, default="mlruns") finetune_parser.add_argument("--data_path", type=str, default="data") + # ---- EVALUATE SUBCOMMAND ---- + # evaluate_parser = subparsers.add_parser("evaluate", help="Evaluate model performance") + evaluate_parser = ArgumentParser() + evaluate_parser.add_argument("--config", type=str, required=True) + evaluate_parser.add_argument("--model_path", type=str, required=True) + evaluate_parser.add_argument("--exp_name", type=str, default=exp_name) + evaluate_parser.add_argument("--run_name", type=str, default="run") + evaluate_parser.add_argument("--log_dir", type=str, default="mlruns") + evaluate_parser.add_argument("--data_path", type=str, default="data") + # ---- PREDICT SUBCOMMAND ---- # predict_parser = subparsers.add_parser("predict", help="Evaluate model performance") predict_parser = ArgumentParser() @@ -50,6 +60,7 @@ def main(): # ---- ITERATE SUBCOMMAND ---- # iterate_parser = subparsers.add_parser("iterate", help="Run model benchmarking") iterate_parser = ArgumentParser() + iterate_parser.add_argument("--config", action="config") iterate_parser.add_argument("--seed", type=int) iterate_parser.add_argument("--hpo_spec", type=HyperParameterOptmizerSpec) iterate_parser.add_argument("--tasks", type=list[TaskSpec]) @@ -57,7 +68,7 @@ def main(): iterate_parser.add_argument("--optimizer", type=OptimizerSpec) iterate_parser.add_argument("--training", type=TrainingSpec) iterate_parser.add_argument("--callbacks", type=CallbackSpec) - iterate_parser.add_argument("--config", action="config") + parser = ArgumentParser( prog="gridfm_graphkit", @@ -66,14 +77,14 @@ def main(): subcommands = parser.add_subcommands() subcommands.add_subcommand('train', train_parser) 
subcommands.add_subcommand('finetune', finetune_parser) + subcommands.add_subcommand('evaluate', finetune_parser) subcommands.add_subcommand('predict', predict_parser) subcommands.add_subcommand('iterate', iterate_parser) args = parser.parse_args() if args.subcommand == "iterate": - # config = args.iterate.config - # config_args: Namespace = iterate_parser.instantiate_classes(config) - iterate_cli(args.iterate) + experiment_ids = iterate_cli(args.iterate) + return experiment_ids else: main_cli(args) diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index fbd236c..9694fd2 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -14,8 +14,6 @@ from jsonargparse import Namespace -DEFAULT_SEED = 42 - def main_cli(args): logger = MLFlowLogger( @@ -106,13 +104,11 @@ def iterate_cli(config_args): # validate inputs if config_args.seed is not None: assert isinstance(config_args.seed, int), "seed must be an integer" - else: - config_init.seed = DEFAULT_SEED torch.manual_seed(config_args.seed) random.seed(config_args.seed) np.random.seed(config_args.seed) - run_iterate_experiments( + return run_iterate_experiments( args=config_args, # TODO model_spec=config_args.model, training_spec=config_args.training, From 79fbe90de2f205198a9c9d3017df681fca438fa9 Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Wed, 24 Dec 2025 20:19:16 -0500 Subject: [PATCH 14/16] repeated metric Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- gridfm_graphkit/iterate/hpo.py | 336 +++++++----------- gridfm_graphkit/iterate/model_fitting.py | 62 ++-- .../tasks/feature_reconstruction_task.py | 11 +- 3 files changed, 174 insertions(+), 235 deletions(-) diff --git a/gridfm_graphkit/iterate/hpo.py b/gridfm_graphkit/iterate/hpo.py index d55ab97..ffc6109 100644 --- a/gridfm_graphkit/iterate/hpo.py +++ b/gridfm_graphkit/iterate/hpo.py @@ -4,11 +4,12 @@ from jsonargparse import Namespace - - +import warnings +import time +from random import randint from functools import partial -from typing import Any, Dict +from typing import Dict import mlflow import optuna @@ -16,21 +17,24 @@ import torch from optuna.pruners import HyperbandPruner from optuna.samplers import BaseSampler, RandomSampler +from ast import literal_eval -from gridfm_graphkit.iterate.model_fitting import fit_model, fit_model_with_hparams +from gridfm_graphkit.iterate.model_fitting import ( + fit_model, + fit_model_with_hparams, + inject_hparams, +) from gridfm_graphkit.utils.types import ( - HyperParameterOptmizerSpec, - TaskSpec, + HyperParameterOptmizerSpec, + TaskSpec, CallbackSpec, - OptimizerSpec, - ModelSpec, - TrainingSpec, - DataSpec, + OptimizerSpec, + ModelSpec, + TrainingSpec, direction_type_to_optuna, - optimization_space_type - - ) + optimization_space_type, +) from gridfm_graphkit.iterate.utils import ( parse_optimization_space, @@ -39,12 +43,11 @@ unflatten, get_logger, sync_mlflow_optuna, - ) - +) def run_iterate_experiments( - args, #TODO: remove + args, # TODO: remove model_spec: ModelSpec, training_spec: TrainingSpec, optimizer_spec: OptimizerSpec, @@ -52,17 +55,26 @@ def run_iterate_experiments( hpo_spec: HyperParameterOptmizerSpec, tasks: list[TaskSpec], seed: int = 42, - ) -> bool: - """ - runs full benchmarking (hpo + repeated) for a model across multiple tasks +) -> dict[str, str, str]: + """Runs full benchmarking (hpo + repeated) for a model across multiple tasks. Args: - + args: Experiment configuration. 
Expected fields include `training.batch_size`, `optimizer.*`, etc. + model_spec: ModelSpec, Model configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + training_spec: TrainingSpec, Trainer configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + optimizer_spec: OptimizerSpec, Optimizer configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + callbacks_spec: CallbackSpec, Callbacks configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + hpo_spec: HyperParameterOptmizerSpec, Model configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + tasks: list[TaskSpec], Model configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + seed: int = 42, seed lightning Return: - + Dict: + hpo_experiment_id: str, + hpo_finished_run_id: str + repeated_experiment_id: str """ - #create folders and initialize logger + # create folders and initialize logger base = Path(args.hpo_spec.results_folder) HPO_EXP_FOLDER = base / "hpo_mlflow_output" REPEATED_EXP_FOLDER = base / "repeated_mlflow_output" @@ -72,12 +84,12 @@ def run_iterate_experiments( for f in folders: os.makedirs(str(f), exist_ok=True) logger = get_logger(log_folder=str(LOG_FOLDER)) + experiment_ids = {} - benchmarking_completed = False try: - # run hpo on model across multiple tasks + # run hpo on model across multiple tasks hpo_output = run_hpo_experiments( - args=args, #TODO: remove args + args=args, # TODO: remove args logger=logger, model_spec=model_spec, training_spec=training_spec, @@ -87,12 +99,13 @@ def run_iterate_experiments( tasks=tasks, seed=seed, storage_uri=HPO_EXP_FOLDER, - ) - + ) + + if args.hpo_spec.num_repetitions >= 1: # run repeated experiments - run_repeated_experiments( - args=args, #TODO: remove args + repeated_output = run_repeated_experiments( + args=args, # TODO: remove args logger=logger, model_spec=model_spec, training_spec=training_spec, @@ -104,17 +117,17 @@ def run_iterate_experiments( repeated_storage_uri=REPEATED_EXP_FOLDER, hpo_storage_uri=HPO_EXP_FOLDER, csv_folder=REPEATED_CSV_FOLDER, - experiment_id=hpo_output["experiment_id"], - parent_run_id=hpo_output["finished_run_id"], - ) + hpo_parent_run_id=hpo_output["hpo_finished_run_id"], + ) + hpo_output["repeated_experiment_id"] = repeated_output + return hpo_output except Exception as e: logger.info(f"Could not complete due to error {e}") raise - def run_hpo_experiments( - args: Namespace, #TODO: remove + args: Namespace, # TODO: remove logger: logging.RootLogger, model_spec: ModelSpec, training_spec: TrainingSpec, @@ -125,34 +138,37 @@ def run_hpo_experiments( seed: int, storage_uri: Path, ) -> Dict[str, str]: - """Highest level function to run hpo only for a model across multiple tasks + """Highest level function to run hpo only for a model across multiple tasks. Args: - args: Namespace, #TODO: remove - logger: logging.RootLogger, - model_spec: ModelSpec, - training_spec: TrainingSpec, - optimizer_spec: OptimizerSpec, - callbacks_spec: CallbackSpec, - hpo_spec: HyperParameterOptmizerSpec, - tasks: list[TaskSpec], - seed: int, - storage_uri: Path, - + args: Experiment configuration. Expected fields include `training.batch_size`, `optimizer.*`, etc. + logger: logging.RootLogger, Logger for experiment + model_spec: ModelSpec, Model configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + training_spec: TrainingSpec, Trainer configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + optimizer_spec: OptimizerSpec, Optimizer configuration. 
run "gridfm_graphkit iterate -h" to see allowed inputs + callbacks_spec: CallbackSpec, Callbacks configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + hpo_spec: HyperParameterOptmizerSpec, Model configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + tasks: list[TaskSpec], Model configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + seed: int = 42, seed lightning + storage_uri: Path, location to store mlflow output from HPO Return: - - + Dict: + hpo_experiment_id: str, + hpo_finished_run_id: str """ # https://mlflow.org/docs/latest/ml/tracking/system-metrics/#using-the-environment-variable-to-control-system-metrics-logging if os.getenv("MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING") is None: os.environ["MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING"] = "true" - model_type: str = model_spec.type run_id: str = hpo_spec.run_id experiment_name: str = hpo_spec.experiment_name task_names = [task.name for task in tasks] - run_name = f"top_run_{hpo_spec.experiment_name}" if hpo_spec.run_name is None else hpo_spec.run_name + run_name = ( + f"top_run_{hpo_spec.experiment_name}" + if hpo_spec.run_name is None + else hpo_spec.run_name + ) optimization_space = parse_optimization_space(hpo_spec.optimization_space) completed_task_run_names = [] optimize_hyperparams = True @@ -205,15 +221,15 @@ def run_hpo_experiments( # only run hyperparameter optimization (HPO) if there are no experiments with finished HPO if optimize_hyperparams: if hpo_spec.bayesian_search: - sampler = None # defaults to TPESampler + sampler = None # defaults to TPESampler else: sampler = RandomSampler() experiment_id, finished_run_id = _run_hpo( - args=args, #TODO + args=args, # TODO model_spec=model_spec, training_spec=training_spec, optimizer_spec=optimizer_spec, - callbacks_spec=callbacks_spec, + callbacks_spec=callbacks_spec, run_name=run_name, run_id=run_id, tasks=tasks, @@ -230,10 +246,8 @@ def run_hpo_experiments( logger=logger, seed=seed, ) - logger.info("HPO complete") - return {"experiment_id": experiment_id, "finished_run_id": finished_run_id} - - + logger.info("HPO complete\n\n\nß") + return {"hpo_experiment_id": experiment_id, "hpo_finished_run_id": finished_run_id} def _run_hpo( @@ -259,8 +273,7 @@ def _run_hpo( test_models: bool = False, seed: int = 42, ) -> tuple[str, str]: - """ - run HPO for multiple tasks under a single experiment. + """Run HPO for multiple tasks under a single experiment. Args: arg: Namespace, contains all parameters to be passed to model and datamodule. 
To be removed @@ -287,11 +300,8 @@ def _run_hpo( """ - logger.info( - f"Running hyperparameter optimization: {run_name=} {run_id=}" - ) - if run_id is not None: - run_name = None + logger.info(f"Running hyperparameter optimization: {run_name=} {run_id=}") + storage_uri = str(storage_uri) with mlflow.start_run( run_name=run_name, run_id=run_id, description=description @@ -311,14 +321,14 @@ def _run_hpo( else None ) best_value, metric_name, hparams = _run_hpo_per_task( - args=args, #TODO + args=args, # TODO model_spec=model_spec, training_spec=training_spec, optimizer_spec=optimizer_spec, - callbacks_spec=callbacks_spec, + callbacks_spec=callbacks_spec, logger=logger, task=task, - storage_uri=str(storage_uri), + storage_uri=storage_uri, experiment_name=experiment_name, experiment_run_id=run.info.run_id, task_run_id=task_run_id, @@ -349,12 +359,8 @@ def _run_hpo( return experiment_id, finished_run_id - - - - def _run_hpo_per_task( - args: Namespace, #TODO: remove args + args: Namespace, # TODO: remove args model_spec: ModelSpec, training_spec: TrainingSpec, optimizer_spec: OptimizerSpec, @@ -371,7 +377,7 @@ def _run_hpo_per_task( sampler: BaseSampler | None = None, test_models: bool = False, seed: int = 42, -): +) -> tuple[str, float, dict]: """ Performs HPO on a single task @@ -427,7 +433,7 @@ def _run_hpo_per_task( logger.info(f"starting task run with id: {run.info.run_id}") if training_spec.epochs is None: raise Exception("Must specify epochs for training") - + # if no optimization params, just run it if optimization_space is None: return ( @@ -510,9 +516,8 @@ def _run_hpo_per_task( return study.best_value, task.metric, best_params - def run_repeated_experiments( - args: Namespace, #TODO: remove args + args: Namespace, # TODO: remove args logger: logging.RootLogger, model_spec: ModelSpec, training_spec: TrainingSpec, @@ -524,30 +529,41 @@ def run_repeated_experiments( repeated_storage_uri: Path, hpo_storage_uri: Path, csv_folder: Path, - experiment_id: str, - parent_run_id: str, -): - """Repeat best experiments from a benchmark run. Only works with a ray cluster. + hpo_parent_run_id: str, +)-> str: + """Repeat best experiments from a benchmark run across multiple tasks. Args: + args: Experiment configuration. Expected fields include `training.batch_size`, `optimizer.*`, etc. + logger: logging.RootLogger, Logger for experiment + model_spec: ModelSpec, Model configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + training_spec: TrainingSpec, Trainer configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + optimizer_spec: OptimizerSpec, Optimizer configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + callbacks_spec: CallbackSpec, Callbacks configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + hpo_spec: HyperParameterOptmizerSpec, Model configuration. run "gridfm_graphkit iterate -h" to see allowed inputs + tasks: list[TaskSpec], Model configuration. 
run "gridfm_graphkit iterate -h" to see allowed inputs + seed: int = 42, seed lightning + repeated_storage_uri: Path, location to store mlflow output from repeated experiments + hpo_storage_uri: Path, location where mlflow output from completed HPO experiments is stored + csv_folder: Path, Location where csv files with repeated experimnets are stored + hpo_parent_run_id: str, run id of successful HPO experiment - + Return: + str: experiment_id of completed Repeated experiment """ - - # if backbone_import: - # importlib.import_module(backbone_import) + logger.info("Starting repeated experiments") experiment_name = hpo_spec.experiment_name num_repetitions = hpo_spec.num_repetitions - #find completed HPO tasks + # find completed HPO tasks mlflow.set_tracking_uri(str(hpo_storage_uri)) mlflow.set_experiment(experiment_name) runs: list[mlflow.entities.Run] = mlflow.search_runs( - filter_string=f"tags.mlflow.parentRunId='{parent_run_id}'", output_format="list" + filter_string=f"tags.mlflow.parentRunId='{hpo_parent_run_id}'", output_format="list" ) # type: ignore - logger.info(f"parent_run_id {parent_run_id}") + logger.info(f"hpo_parent_run_id {hpo_parent_run_id}") logger.info(f"Found runs: {[run.info.run_name for run in runs]}") task_names = [task.name for task in tasks] @@ -561,16 +577,12 @@ def run_repeated_experiments( "mlflow_run_id", "mlflow_run_status", ] - table_entries = [] mlflow.set_tracking_uri(repeated_storage_uri) mlflow.set_experiment(experiment_name) - experiment_id = mlflow.get_experiment_by_name(experiment_name).experiment_id output_path = csv_folder / f"{experiment_name}_repeated_exp_mlflow.csv" if not os.path.isabs(output_path): - raise Exception( - f"output_path must be absolute." - ) + raise Exception(f"output_path must be absolute. got: {output_path}") # backbone_name = defaults.terratorch_task["model_args"]["backbone"] with mlflow.start_run(run_name=experiment_name, run_id=None) as run: @@ -612,9 +624,10 @@ def run_repeated_experiments( best_params = matching_runs[0].data.params best_params = {k: literal_eval(v) for k, v in best_params.items()} - training_spec = combine_with_defaults(task, defaults) - lightning_task_class = training_spec.task.type.get_class_from_enum() - + training_spec_with_best_hparams, optimizer_spec_with_best_hparams = ( + inject_hparams(training_spec, optimizer_spec, best_params) + ) + experiment_info = mlflow.get_experiment_by_name(experiment_name) seeds = [randint(1, 5000) for i in range(num_repetitions * 5)] seeds = [seed for seed in seeds if seed not in past_seeds] @@ -631,21 +644,28 @@ def run_repeated_experiments( output_format="list", ) # type: ignore if len(seed_run_data) > 0: - continue + for item in seed_run_data: + logger.info(f"deleting existing run: {item}") + mlflow.delete_run(item.info.run_id) - score = non_remote_fit( - experiment_name=repeated_experiment_name, - parent_run_id=run.info.run_id, - storage_uri=repeated_storage_uri, - task=task, + score = fit_model( + args=args, + model_spec=model_spec, training_spec=training_spec, - lightning_task_class=lightning_task_class, - best_params=best_params, + optimizer_spec=optimizer_spec, + callbacks_spec=callbacks_spec, + task=task, + run_name=seed_run_name, + experiment_name=experiment_name, + storage_uri=repeated_storage_uri, + parent_run_id=run.info.run_id, + save_models=hpo_spec.save_models, + test_models=True, seed=seed, - backbone_import=backbone_import, - save_models=save_models, - report_on_best_val=report_on_best_val, + logger=logger, + # repeat_on_best=hpo_spec.repeat_on_best, ) + # 
+                # check if run with name finished successfully
                 logger.info(f"score: {score}")

                 # TODO improve this sleep command - try to get a better estimate than this
@@ -690,9 +710,7 @@ def run_repeated_experiments(
                 logger.info(f"rows: {rows} \t cols: {cols}")
                 if rows == 0:
                     logger.info("no past results for this task")
-                existing_output = pd.concat(
-                    [existing_output, new_data], axis=0
-                )
+                existing_output = pd.concat([existing_output, new_data], axis=0)
                 existing_output.reset_index(inplace=True)
                 existing_output = existing_output.drop(
                     columns=["index", "level_0"]
@@ -700,101 +718,5 @@ def run_repeated_experiments(
                 existing_output.to_csv(output_path, index=False)
             else:
                 new_data.to_csv(output_path, index=False)
-
-
-def _run_repeated_per_task(
-    args: Namespace, #TODO: remove args
-    model_spec: ModelSpec,
-    training_spec: TrainingSpec,
-    optimizer_spec: OptimizerSpec,
-    callbacks_spec: CallbackSpec,
-    logger: logging.RootLogger,
-    task: TaskSpec,
-    storage_uri: str,
-    experiment_name: str,
-    experiment_run_id: str,
-    task_run_id: str | None = None,
-    optimization_space: optimization_space_type | None = None,
-    n_trials: int = 1,
-    save_models: bool = False,
-    sampler: BaseSampler | None = None,
-    test_models: bool = False,
-    seed: int = 42,
-):
-    """
-    Performs HPO on a single task
-
-    Args:
-        args: Namespace, #TODO: remove args
-        model_spec: ModelSpec,
-        training_spec: TrainingSpec,
-        optimizer_spec: OptimizerSpec,
-        callbacks_spec: CallbackSpec,
-        logger: logging.RootLogger,
-        task: TaskSpec,
-        storage_uri: str,
-        experiment_name: str,
-        experiment_run_id: str,
-        task_run_id: str | None = None,
-        optimization_space: optimization_space_type | None = None,
-        n_trials: int = 1,
-        save_models: bool = False,
-        sampler: BaseSampler | None = None,
-        test_models: bool = False,
-        seed: int = 42,
-
-    """
-    logger.info(
-        f"starting backbone benchmark on task {task.name} {task_run_id=} {experiment_name=}"
-    )
-    if storage_uri.startswith("http"):
-        optuna_db_path = Path(".") / "optuna_db"
-    else:
-        optuna_db_path = Path(storage_uri).parents[0] / "optuna_db"
-
-    if not os.path.exists(optuna_db_path):
-        os.makedirs(optuna_db_path)
-    optuna_db_path = optuna_db_path / f"{experiment_name}_{experiment_run_id}"
-    optuna_db_path = str(optuna_db_path)
-
-    task_run_id = sync_mlflow_optuna(
-        optuna_db_path=optuna_db_path,
-        storage_uri=storage_uri,
-        experiment_name=experiment_name,
-        task_run_id=task_run_id,
-        task=task,
-        n_trials=n_trials,
-        logger=logger,
-    )
-    if task_run_id is not None:
-        # run_name is used only when run_id is unspecified.
-        run_name = None
-    else:
-        run_name = task.name
-    logger.info(f"start run: {run_name=} {task_run_id=}")
-    with mlflow.start_run(run_name=run_name, nested=True, run_id=task_run_id) as run:
-        logger.info(f"starting task run with id: {run.info.run_id}")
-        if training_spec.epochs is None:
-            raise Exception("Must specify epochs for training")
-
-
-        # if no optimization params, just run it
-        if optimization_space is None:
-            return (
-                *fit_model(
-                    args=args,
-                    model_spec=model_spec,
-                    training_spec=training_spec,
-                    optimizer_spec=optimizer_spec,
-                    callbacks_spec=callbacks_spec,
-                    task=task,
-                    run_name=run.info.run_name,
-                    experiment_name=experiment_name,
-                    storage_uri=storage_uri,
-                    parent_run_id=run.info.run_id,
-                    save_models=save_models,
-                    test_models=test_models,
-                    seed=seed,
-                    logger=logger,
-                ),
-            )
+    logger.info("Repeated experiments complete \n\n\n")
+    return experiment_info.experiment_id
diff --git a/gridfm_graphkit/iterate/model_fitting.py b/gridfm_graphkit/iterate/model_fitting.py
index 1a1241b..68cbbeb 100644
--- a/gridfm_graphkit/iterate/model_fitting.py
+++ b/gridfm_graphkit/iterate/model_fitting.py
@@ -42,13 +42,10 @@
 from gridfm_graphkit.tasks import (
     FeatureReconstructionTask,
-    # ContingencyAnalysisTask,
 )

-from gridfm_graphkit.iterate.utils import get_logger, get_best_val
-
-LOGGER = get_logger()
+from gridfm_graphkit.iterate.utils import get_best_validation_metric, get_test_metric


 os.environ["TUNE_DISABLE_AUTO_CALLBACK_LOGGERS"] = (
@@ -173,11 +170,6 @@ def _generate_parameters(
     )


-"""
-single node - optuna
-"""
-
-
 def launch_training(
     trainer: Trainer,
     model: FeatureReconstructionTask,  # TODO: create basetask in tasks folder
@@ -205,16 +197,27 @@ def launch_training(
     )

     trainer.fit(model, datamodule=datamodule)
-    output = get_best_val(
-        storage_uri=storage_uri,
-        run=run,
-        metric=metric,
-        direction=direction,
-    )
-
     if test_models:
-        test_metrics = trainer.test(model, ckpt_path="best", datamodule=datamodule)
-        output = test_metrics
+        trainer.test(
+            model=model,
+            # ckpt_path="best",
+            datamodule=datamodule
+        )
+        metric = metric.replace("Validation", "Test")
+        output = get_test_metric(
+            storage_uri=str(storage_uri),
+            run=run,
+            metric=metric,
+            direction=direction,
+        )
+    else:
+        output = get_best_validation_metric(
+            storage_uri=str(storage_uri),
+            run=run,
+            metric=metric,
+            direction=direction,
+        )
+
     if delete_models_after_testing:
         # delete the checkpoints folder in the run
         ckpts_folder = os.path.join(
@@ -244,7 +247,7 @@ def fit_model(
     test_models: bool = False,
     seed: int = 42,
     trial: optuna.Trial | None = None,
-) -> dict:
+) -> float:
     pl.seed_everything(seed, workers=True)

     training_spec_copy = copy.deepcopy(training_spec)
@@ -264,11 +267,25 @@ def fit_model(
         # we need to save the models during training to be able to test but can be deleted afterwards
         save_models = True
         delete_models_after_testing = True
+
     if save_models:
-        callbacks.append(ModelCheckpoint(monitor=task.metric, mode=task.direction))
-    enable_checkpointing = False
-    if any([isinstance(cb, ModelCheckpoint) for cb in callbacks]):
+        callbacks = [
+            cb
+            for cb in callbacks
+            if not (isinstance(cb, ModelCheckpoint) and cb.monitor == task.metric)
+        ]
+        callbacks.append(
+            ModelCheckpoint(
+                monitor=task.metric,
+                mode=task.direction,
+                save_top_k=1,
+            )
+        )
         enable_checkpointing = True
+    else:
+        callbacks = [cb for cb in callbacks if not isinstance(cb, ModelCheckpoint)]
+        enable_checkpointing = False
+
     # # initialize datamodule
     args.data = task.data
@@ -296,6 +313,7 @@ def fit_model(
         callbacks=callbacks,
         enable_checkpointing=enable_checkpointing,
         enable_progress_bar=training_spec_copy.enable_progress_bar,
+        # deterministic=True,
     )

     logger.info(
diff --git a/gridfm_graphkit/tasks/feature_reconstruction_task.py b/gridfm_graphkit/tasks/feature_reconstruction_task.py
index d156c4d..cff1f23 100644
--- a/gridfm_graphkit/tasks/feature_reconstruction_task.py
+++ b/gridfm_graphkit/tasks/feature_reconstruction_task.py
@@ -1,5 +1,4 @@
 import torch
-from torch.optim.lr_scheduler import ReduceLROnPlateau
 import lightning as L
 from pytorch_lightning.utilities import rank_zero_only
 import numpy as np
@@ -45,7 +44,7 @@ class FeatureReconstructionTask(L.LightningModule):
         predict_step(batch, batch_idx, dataloader_idx=0):
             Run inference and return denormalized outputs + node masks.
         configure_optimizers():
-            Setup Adam optimizer and ReduceLROnPlateau scheduler.
+            Setup optimizer and scheduler.
         on_fit_start():
             Save normalization statistics at the beginning of training.
         on_test_end():
@@ -110,7 +109,7 @@ def on_fit_start(self):
         )

     def shared_step(self, batch):
-        output = self.forward(
+        output = self(
             x=batch.x,
             pe=batch.pe,
             edge_index=batch.edge_index,
@@ -224,7 +223,7 @@ def test_step(self, batch, batch_idx, dataloader_idx=0):
         loss_dict["Test loss"] = loss_dict.pop("loss").detach()

         for metric, value in loss_dict.items():
-            metric_name = f"{dataset_name}/{metric}"
+            metric_name = f"Test - {dataset_name}/{metric}"
             if "p.u." in metric:
                 # Denormalize metrics expressed in p.u.
                 value *= self.node_normalizers[dataloader_idx].baseMVA
@@ -235,7 +234,7 @@ def test_step(self, batch, batch_idx, dataloader_idx=0):
                 batch_size=batch.num_graphs,
                 add_dataloader_idx=False,
                 sync_dist=True,
-                logger=False,
+                logger=True,
             )
         return
@@ -283,7 +282,7 @@ def on_test_end(self):
                 pass

             if "/" in full_key:
-                dataset_name, metric = full_key.split("/", 1)
+                dataset_name, metric = full_key.replace("Test - ", "").split("/", 1)
                 if dataset_name not in grouped_metrics:
                     grouped_metrics[dataset_name] = {}
                 grouped_metrics[dataset_name][metric] = value

From 0f3145b8c3f198aa5f20cfe294ed1f108fc6d546 Mon Sep 17 00:00:00 2001
From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com>
Date: Wed, 24 Dec 2025 20:20:22 -0500
Subject: [PATCH 15/16] add utils

Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com>
---
 examples/config/case30_ieee_base_hpo.yaml |  4 +-
 gridfm_graphkit/iterate/utils.py          | 48 ++++++++++++++++++++++-
 2 files changed, 48 insertions(+), 4 deletions(-)

diff --git a/examples/config/case30_ieee_base_hpo.yaml b/examples/config/case30_ieee_base_hpo.yaml
index f0ba5cd..9169b5b 100644
--- a/examples/config/case30_ieee_base_hpo.yaml
+++ b/examples/config/case30_ieee_base_hpo.yaml
@@ -47,7 +47,7 @@ callbacks:
   optuna_early_prune: True

 hpo_spec:
-  experiment_name: 1#GPSTransformer_test_exp
+  experiment_name: GPSTransformer_8
   run_name: top_run
   optimization_space:
     batch_size: [8, 16, 32]
     learning_rate:
@@ -65,7 +65,7 @@ hpo_spec:
   report_on_best_val: True
   continue_existing_experiment: True

-tasks: #cannot overwrite any parameters here
+tasks:
   - name: feature_reconstruction_base
     type: feature_reconstruction
     metric: "Validation loss"
     direction: min
     data:
       data_path: "/dccstor/gridfm/PowerGraph_TP"
diff --git a/gridfm_graphkit/iterate/utils.py b/gridfm_graphkit/iterate/utils.py
index 1e8a7cc..3ea0df8 100644
--- a/gridfm_graphkit/iterate/utils.py
+++ b/gridfm_graphkit/iterate/utils.py
@@ -70,7 +70,7 @@ def get_logger(log_level="INFO", log_folder="./experiment_logs") -> logging.Root
     return logger


-def get_best_val(
+def get_best_validation_metric(
     storage_uri: str,
     run: mlflow.entities.Run,
     metric: str,
@@ -125,6 +125,49 @@ def get_best_val(
     return series_val_metrics[(metric, best_step)]


+def get_test_metric(
+    storage_uri: str,
+    run: mlflow.entities.Run,
+    metric: str,
+    direction: str,
+):
+    """Return the mean value of the given test metric across the test dataloaders logged for the run."""
+    client = mlflow.tracking.MlflowClient(
+        tracking_uri=storage_uri,
+    )
+
+    if not metric.lower().startswith("test"):
+        raise Exception(
+            f"Metric {metric} does not start with `test`. Please choose a test metric"
+        )
+    for_pd_collect = []
+    test_metrics_names = []
+
+    for metric_log_name in client.get_run(run.info.run_id).data.metrics:
+        if "Test - " not in metric_log_name:
+            continue
+        dataset_name, metric_name = metric_log_name.replace("Test - ", "").split("/", 1)
+        if metric == metric_name:
+            test_metrics_names.append(metric_name)
+            test_metric_history = client.get_metric_history(run.info.run_id, metric_log_name)
+            pd_convertible_metric_history = [
+                {
+                    "metric_name": mm.key,
+                    "step": mm.step,
+                    "value": mm.value,
+                }
+                for mm in test_metric_history
+            ]
+            for_pd_collect += pd_convertible_metric_history
+    df_test_metrics = pd.DataFrame.from_records(for_pd_collect)
+    metric_value = df_test_metrics.loc[:, "value"].mean()
+    return metric_value
+
+
 def check_existing_experiments(
     logger: logging.RootLogger,
     storage_uri: str,
@@ -163,7 +206,8 @@ def check_existing_experiments(
         return output

     experiment_id = experiment_info.experiment_id
-    logger.info(f"\nexperiment_id: {experiment_id}")
+    logger.info("Checking existing experiment")
+    logger.info(f"experiment_id: {experiment_id}")
     logger.info(f"experiment_name: {experiment_name}")
     output["experiment_id"] = experiment_id
     experiment_parent_run_data = client.search_runs(

From 517d87bae583568e1a59716c2225d6b66c03b0cf Mon Sep 17 00:00:00 2001
From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com>
Date: Wed, 24 Dec 2025 20:21:33 -0500
Subject: [PATCH 16/16] tests

Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com>
---
 .../config/iterate_test_case30_ieee_base.yaml |  73 ++++++++
 tests/test_iterate.py                         | 173 ++++++++++++++++++
 2 files changed, 246 insertions(+)
 create mode 100644 tests/config/iterate_test_case30_ieee_base.yaml
 create mode 100644 tests/test_iterate.py

diff --git a/tests/config/iterate_test_case30_ieee_base.yaml b/tests/config/iterate_test_case30_ieee_base.yaml
new file mode 100644
index 0000000..7d56216
--- /dev/null
+++ b/tests/config/iterate_test_case30_ieee_base.yaml
@@ -0,0 +1,73 @@
+seed: 42
+model:
+  attention_head: 8
+  dropout: 0.1
+  edge_dim: 2
+  hidden_size: 256
+  input_dim: 9
+  num_layers: 8
+  output_dim: 6
+  pe_dim: 20
+  type: GPSTransformer
+  model_path: "gridfm-graphkit/examples/models/GridFM_v0_2.pth"
+training:
+  batch_size: 1
+  epochs: 1
+  losses: ["MaskedMSE", "PBE"]
+  loss_weights: [0.01, 0.99]
+  accelerator: auto
+  devices: auto
+  strategy: auto
+optimizer:
+  type: Adam
+  learning_rate: 0.0001
+  optimizer_params:
+    betas: [0.9, 0.999]
+  scheduler_type: ReduceLROnPlateau
+  scheduler_params:
+    mode: min
+    factor: 0.7
+    patience: 10
+callbacks:
+  patience: 100
+  tol: 0
+  optuna_early_prune: True
+
+hpo_spec:
+  experiment_name: GPSTransformer_test_exp
+  run_name: top_run
+  optimization_space:
+    batch_size: [8, 16, 32]
+    learning_rate:
+      min: 0.000006
+      max: 0.001
+      type: real
+      log: true
+  n_trials: 1
+  bayesian_search: True
+  results_folder: "results/iterate"
+  save_models: False
+  num_repetitions: 1
+  repeat_on_best: True
+  report_on_best_val: True
+  continue_existing_experiment: True
+
+tasks:
+  - name: feature_reconstruction_base
+    type: feature_reconstruction
+    metric: "Validation loss"
+    direction: min
+    data:
+      data_path: "data/"
+      networks: ["case30_ieee"]
+      scenarios: [1023]
+      normalization: "baseMVAnorm"
+      baseMVA: 100
+      mask_type: "rnd"
+      mask_value: 0.0
+      mask_ratio: 0.5
+      mask_dim: 6
+      learn_mask: False
+      val_ratio: 0.1
+      test_ratio: 0.1
+      workers: 4
\ No newline at end of file
diff --git a/tests/test_iterate.py b/tests/test_iterate.py
new file mode 100644
index 0000000..60ed6b6
--- /dev/null
+++ b/tests/test_iterate.py
@@ -0,0 +1,173 @@
+import itertools
+import os
+from pathlib import Path
+
+import yaml
+from gridfm_graphkit.__main__ import main
+import pytest
+import sys
+
+CONFIG_FILES = [
+    "config/iterate_test_case30_ieee_base.yaml",
+]
+
+MODELS = ["examples/models/GridFM_v0_2.pth"]
+INPUT_TEST_MAIN = list(itertools.product(MODELS, CONFIG_FILES))
+
+
+def get_test_ids() -> list[str]:
+    test_case_ids = list()
+    for model, config in INPUT_TEST_MAIN:
+        # get the filename
+        model = model.split("/")[-1].replace(".pth", "")
+        config = config.split("/")[-1].replace(".yaml", "")
+        # append to list of test ids
+        test_case_ids.append(f"{config}_{model}")
+    return test_case_ids
+
+
+def validate_hpo_results(
+    experiment_name: str,
+    results_folder: Path,
+    n_trials: int,
+    n_tasks: int,
+    iterate_info: dict,
+):
+    # check that the experiment was created
+    # (HPO mlflow output is assumed to live under results_folder / "hpo_mlflow_output")
+    mlflow_output_path = results_folder / "hpo_mlflow_output"
+    assert mlflow_output_path.exists(), f"Error! Directory does not exist: {mlflow_output_path}"
+    hpo_exp_path = mlflow_output_path / iterate_info["hpo_experiment_id"]
+    assert hpo_exp_path.exists(), f"Error! Directory does not exist: {hpo_exp_path}"
+    meta_yaml_path = hpo_exp_path / "meta.yaml"
+    assert meta_yaml_path.exists(), (
+        f"Error! meta.yaml file {meta_yaml_path} does not exist"
+    )
+
+    # open the file and check that the experiment name/id is the same
+    experiment_name_found: bool = False
+    finished_run_id_found: bool = False
+    experiment_id_found: bool = False
+    experiment_id = iterate_info["hpo_experiment_id"]
+    finished_run_id = iterate_info["hpo_finished_run_id"]
+    with open(meta_yaml_path, mode="r") as f:
+        # read all the lines
+        lines = f.readlines()
+        # try to find experiment id and name in these lines
+        for line in lines:
+            if experiment_name in line:
+                experiment_name_found = True
+            if finished_run_id in line:
+                finished_run_id_found = True
+            if experiment_id in line:
+                experiment_id_found = True
+    assert experiment_name_found and experiment_id_found and finished_run_id_found, (
+        f"Error! The experiment name ({experiment_name=}), finished run id ({finished_run_id=}), "
+        f"and experiment id ({experiment_id}) must all be in {meta_yaml_path=}."
+    )
+
+    # check number of runs created
+    # run folders are created directly under the experiment directory
+    expected_num_runs = (n_trials * n_tasks) + n_tasks + 1
+    run_folders = [f.path for f in os.scandir(hpo_exp_path) if f.is_dir()]
+    assert len(run_folders) == expected_num_runs, (
+        f"Error! Expected {expected_num_runs} runs to be created for the HPO experiment. Found {len(run_folders)} runs."
+    )
+
+
+def validate_repeated_results(
+    experiment_name: str,
+    results_folder: Path,
+    n_trials: int,
+    n_tasks: int,
+    num_repetitions: int,
+    iterate_info: dict,
+):
+    # check that the experiment was created
+    # (repeated-experiment mlflow output is assumed to live under results_folder / "repeated_mlflow_output")
+    mlflow_output_path = results_folder / "repeated_mlflow_output"
+    assert mlflow_output_path.exists(), f"Error! Directory does not exist: {mlflow_output_path}"
+    repeated_exp_path = mlflow_output_path / iterate_info["repeated_experiment_id"]
Directory does not exist: {hpo_exp_path}" + meta_yaml_path = hpo_exp_path / "meta.yaml" + assert meta_yaml_path.exists(), ( + f"Error! meta.yaml file {meta_yaml_path} does not exist" + ) + + # open file and check that the experiment name is the same + experiment_name_found: bool = False + experiment_id_found: bool = False + experiment_id = iterate_info["repeated_experiment_id"] + with open(meta_yaml_path, mode="r") as f: + # read all the lines + lines = f.readlines() + # try to find experiment id and name in these lines + + for line in lines: + if experiment_name in line: + experiment_name_found = True + if experiment_id in line: + experiment_id_found = True + assert experiment_name_found and experiment_id_found, ( + f"Error! Both experiment name ({experiment_name=}) and experiment id ({experiment_id=}) \ + must be in the {meta_yaml_path=}." + ) + + # check number of runs created + expected_num_runs = (n_trials*n_tasks) + 1 + run_folders = [ f.path for f in os.scandir(folder) if f.is_dir() ] + assert len(run_folders)==expected_num_runs, ( + f"Error! Expected {expected_num_runs} to be created for repeated experiment. Found {len(run_folders)} runs." + ) + + + + + +@pytest.mark.parametrize( + "model, config", + INPUT_TEST_MAIN, + ids=get_test_ids(), +) +def test_iterate( + model: str, + config: str, +): + test_dir = Path(__file__).parents[0] + home_dir = test_dir.parents[0] + config_file: Path = test_dir / config + assert config_file.exists() + with open(config_file, "r") as file: + config_data = yaml.safe_load(file) + experiment_name = config_data["hpo_spec"]["experiment_name"] + results_folder = config_data["hpo_spec"]["results_folder"] + results_folder = Path(results_folder) + num_repetitions = config_data["hpo_spec"]["num_repetitions"] + n_trials = config_data["hpo_spec"]["n_trials"] + n_tasks = len(config_data["tasks"]) + + + #send command to sys + arguments = ["gridfm_graphkit", "iterate", "--config", str(config_file.resolve())] + model_path = home_dir / model + results_folder = home_dir / "test_reults" + arguments.append["--model.model_path", f"{str(model_path)}"] + arguments.append["--hpo_spec.results_folder", f"{str(results_folder)}"] + + sys.argv = arguments + iterate_info = main() + assert isinstance(iterate_info, dict), f"Error! {iterate_info=} is not a dict" + validate_hpo_results( + experiment_name=experiment_name, + results_folder=results_folder, + n_trials=n_trials, + n_tasks=n_tasks, + iterate_info=iterate_info, + ) + + validate_repeated_results( + experiment_name=experiment_name, + results_folder=results_folder, + num_repetitions=num_repetitions, + n_trials=n_trials, + n_tasks=n_tasks, + iterate_info=iterate_info, + ) \ No newline at end of file