33 changes: 18 additions & 15 deletions configs/gridfm_graphkit_hpo.yaml
@@ -1,7 +1,8 @@
# HPO configuration for gridfm-graphkit HGNS PF case118
#
# Hyperparameters:
# gpu_num – number of GPUs to request from the WLM (launcher-level)
# compute – group: selects gpu_num + num_workers + batch_size together
# so that data-loading and batch sizes scale with GPU count
# bfloat16 – boolean flag (presence = --bfloat16, absence = flag omitted)
# tf32 – boolean flag (presence = --tf32, absence = flag omitted)
# compile – torch.compile mode; null disables the flag entirely
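#
# Illustrative only (exact CLI rendering is iterate's concern, and gpu_num is
# consumed at launcher level): a trial sampling compute=medium, bfloat16=true,
# tf32=false and compile="max-autotune" would request 2 GPUs from the WLM and
# append roughly the following to the training command:
#   --num_workers 16 --batch_size 32 --bfloat16 --compile max-autotune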
@@ -14,30 +15,32 @@
# extracted from [performance] lines in trial output

metrics:
- case118_ieee/layer_0_residual
- Validation loss
- last epoch time
- last epoch it/s

hpo:
  gpu_num:
    type: categorical
    choices: [1, 2, 4]

  bfloat16:
    type: flag  # store_true: true → --bfloat16, false → flag omitted

  tf32:
    type: flag  # store_true: true → --tf32, false → flag omitted
  compute:
    type: group  # one choice co-selects gpu_num, num_workers and batch_size
    choices:
      small:   # single GPU – conservative resources
        gpu_num: 1
        num_workers: 32
        batch_size: 64
      medium:  # two GPUs – doubled throughput
        gpu_num: 2
        num_workers: 16
        batch_size: 32
      large:   # four GPUs – full-node run
        gpu_num: 4
        num_workers: 8
        batch_size: 16

  compile:
    type: categorical
    choices: ["max-autotune", "default", "reduce-overhead", null]
    # null → --compile flag is omitted entirely

  num_workers:
    type: categorical
    choices: [8, 16, 24, 32]

  dataset:
    type: group  # one choice selects all bundled args together
    choices:
109 changes: 109 additions & 0 deletions examples/run_lsf_gridfm_example_postgres.sh
@@ -0,0 +1,109 @@
#!/usr/bin/env bash
# =============================================================================
# Example: iterate --wlm lsf with PostgreSQL coordinator for gridfm-graphkit HPO
#
# Each Optuna trial is submitted as an LSF job that looks like:
#
# bsub -gpu "num=1:mode=exclusive_process:mps=no:gmodel=NVIDIAA100_SXM4_80GB" \
# -K -o trial<N>.out -e trial<N>.err \
# -R "rusage[ngpus=1, cpu=16, mem=32GB]" \
# -J hpo_trial_<N> \
# "export PATH='/opt/share/cuda-12.8.1/bin:$PATH' && \
# export CUDA_HOME='/opt/share/cuda-12.8.1/' && \
# export LD_LIBRARY_PATH='/opt/share/cuda-12.8.1/lib64:$LD_LIBRARY_PATH' && \
# cd /dccstor/terratorch/users/rkie/gitco/gridfm-graphkit && \
# source /u/rkie/venvs/venv_gridfm-graphkit/bin/activate && \
# gridfm_graphkit train <hpo_params> <static_params>"
#
# Prerequisites
# -------------
# * LSF bsub/bjobs available on PATH
# * gridfm-graphkit installed in the venv below
# * configs/gridfm_graphkit_hpo.yaml present
# * psycopg2-binary installed: pip install 'terratorch-iterate[postgresql]'
# * POSTGRES_URL set (or hard-code it in --optuna-db-path below)
#
# PostgreSQL coordinator
# ----------------------
# PostgreSQL, rather than SQLite / JournalFS, is the recommended backend for
# high-parallelism HPO on a cluster: multiple bsub jobs can safely write trial
# results concurrently without lock contention.
#
# Set the connection URL as an env-var to avoid embedding credentials in scripts
# that may end up in version control:
#
# export POSTGRES_URL="postgresql://user:password@host:5432/optuna_studies"
#
# or pass it inline:
#
# POSTGRES_URL="postgresql://..." bash run_lsf_gridfm_example_postgres.sh
# =============================================================================

set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"

# ---------------------------------------------------------------------------
# Required: PostgreSQL connection URL
# ---------------------------------------------------------------------------
: "${POSTGRES_URL:?Please set POSTGRES_URL=postgresql://user:password@host:port/dbname}"

# ---------------------------------------------------------------------------
# Customisable paths – override via environment variables
# ---------------------------------------------------------------------------
GRIDFM_ROOT="${GRIDFM_ROOT:-/dccstor/terratorch/users/rkie/gitco/gridfm-graphkit}"
GRIDFM_VENV="${GRIDFM_VENV:-/u/rkie/venvs/venv_gridfm-graphkit}"
CUDA_BASE="${CUDA_BASE:-/opt/share/cuda-12.8.1}"
DATA_PATH="${DATA_PATH:-/u/rkie/}"
LOG_DIR="${LOG_DIR:-logs}"

# ---------------------------------------------------------------------------
# LSF GPU resource string
# Adjust gmodel to the GPU type available on your cluster.
# ---------------------------------------------------------------------------
LSF_GPU_CONFIG="${LSF_GPU_CONFIG:-num=1:mode=exclusive_process:mps=no:gmodel=NVIDIAA100_SXM4_80GB}"

# ---------------------------------------------------------------------------
# Pre-run commands executed inside every bsub job before the training script.
# Order matters:
# 1. Export CUDA paths so the GPU driver / toolkit is visible.
# 2. cd into the project root so relative config paths resolve correctly.
# 3. Activate the project venv.
# ---------------------------------------------------------------------------
PRE_RUN="\
export PATH='${CUDA_BASE}/bin:\$PATH' && \
export CUDA_HOME='${CUDA_BASE}' && \
export LD_LIBRARY_PATH='${CUDA_BASE}/lib64:\$LD_LIBRARY_PATH' && \
cd '${GRIDFM_ROOT}' && \
source '${GRIDFM_VENV}/bin/activate'"

# ---------------------------------------------------------------------------
# Static training arguments (not part of the HPO search space).
# These are appended verbatim after the sampled hyperparameters.
# ---------------------------------------------------------------------------
STATIC_ARGS_JSON='{
  "log_dir": "'"${LOG_DIR}"'",
  "report-performance": true
}'
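# Assuming boolean static args render the way flag params do in the HPO YAML
# (true → bare flag), the JSON above should expand to roughly:
#   --log_dir logs --report-performance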

# ---------------------------------------------------------------------------
# Launch iterate
# To pin the GPU model, add:  --lsf-gpu-config-string "${LSF_GPU_CONFIG}"
# ---------------------------------------------------------------------------
iterate \
--script "gridfm_graphkit train" \
--interpreter "" \
--root-dir "${GRIDFM_ROOT}" \
--wlm lsf \
--pre-run-commands "${PRE_RUN}" \
--no-underscore-to-hyphen \
--gpu-count 1 \
--cpu-count 16 \
--mem-gb 32 \
--optuna-study-name gridfm_lsf_postgres_hpo \
--optuna-db-path "${POSTGRES_URL}" \
--parallelism 4 \
--optuna-n-trials 20 \
--hpo-yaml "${REPO_ROOT}/configs/gridfm_graphkit_hpo.yaml" \
--static-args-json "${STATIC_ARGS_JSON}"
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -103,6 +103,12 @@ utility = [
# If you want to collect NVIDIA GPU metrics, you also need to install pynvml:
nvidia = ["pynvml"]

# PostgreSQL coordinator plugin — installs the psycopg2 driver.
# Use psycopg2-binary for a self-contained wheel (no libpq build dependency).
# For production deployments that compile against a system libpq, replace with
# psycopg2 (without -binary).
postgresql = ["psycopg2-binary>=2.9"]

# If you are using AMD/HIP GPUs, install pyrsmi
amd = ["pyrsmi"]

4 changes: 4 additions & 0 deletions terratorch_iterate/iterate2/__init__.py
@@ -0,0 +1,4 @@
# terratorch_iterate.iterate2 package
# Re-export main so that `from terratorch_iterate.iterate2 import main` keeps
# working after iterate2.py was turned into a package directory.
from terratorch_iterate.iterate2._iterate2 import main # noqa: F401
terratorch_iterate/iterate2/_iterate2.py
@@ -14,8 +14,11 @@
from typing import Dict, Any, Optional, Literal, List

import optuna
from optuna.storages import JournalStorage, JournalFileStorage
import yaml
from terratorch_iterate.iterate2.plugin.coordinator import load_builtin_plugins, resolve_storage

# Load built-in coordinator plugins (sqlite, journalfs, postgresql)
load_builtin_plugins()

logger = logging.getLogger("iterate2")

@@ -803,14 +806,7 @@ def objective(trial):
directions = ["maximize"] * len(metric_list)
logger.info("Creating Optuna study (directions: %s)", directions)

if args.optuna_db_path.startswith("js:///"):
    journal_path = args.optuna_db_path[len("js:///"):]
    logger.info("Using JournalStorage at '%s'", journal_path)
    storage = JournalStorage(JournalFileStorage(journal_path))
elif "sqlite" in args.optuna_db_path:
    storage = args.optuna_db_path
else:
    storage = f"sqlite:///{args.optuna_db_path}"
storage = resolve_storage(args.optuna_db_path)
logger.debug("Optuna storage: %s", storage)
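# Accepted --optuna-db-path forms, assuming the built-in plugins keep the old
# conventions: "study.db" / "sqlite:///study.db", "js:///path/to/journal.log",
# "postgresql://user:password@host:5432/dbname".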

study = optuna.create_study(
@@ -821,10 +817,35 @@
)
logger.info("Study '%s' ready (existing trials: %d)", args.optuna_study_name, len(study.trials))

# ── Re-queue failed trials (25% retry / 75% new) ──────────────────────
failed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.FAIL]
n_total = args.optuna_n_trials
if failed_trials:
    n_retry = max(1, round(0.25 * n_total))
    n_retry = min(n_retry, len(failed_trials))  # can't retry more than we have
    n_new = n_total - n_retry
    # enqueue the most recent failed trials first
    trials_to_retry = failed_trials[-n_retry:]
    logger.info(
        "Found %d failed trial(s). Re-queuing %d and running %d new.",
        len(failed_trials), n_retry, n_new,
    )
    for ft in trials_to_retry:
        if ft.params:  # skip trials that recorded no params at all
            study.enqueue_trial(ft.params)
            logger.info("  Enqueued params from failed trial %d: %s", ft.number, ft.params)
        else:
            logger.info("  Skipped failed trial %d (no params recorded).", ft.number)
    # enqueued retries count toward n_trials, so the total stays n_retry + n_new
    n_total = n_new + n_retry
else:
    logger.info("No failed trials found – running %d fresh trials.", n_total)
# ── end retry logic ───────────────────────────────────────────────────
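# Worked example of the split above: with optuna_n_trials=20 and 3 recorded
# failures, n_retry = min(max(1, round(0.25 * 20)), 3) = 3 and n_new = 17,
# i.e. 3 re-queued trials plus 17 freshly sampled ones.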

logger.info("Parallelism: %d worker(s)", args.parallelism)
study.optimize(
    objective,
    n_trials=args.optuna_n_trials,
    n_trials=n_total,
    n_jobs=args.parallelism,
    catch=(Exception,),  # mark trial as FAILED and continue; never crash the study
)
1 change: 1 addition & 0 deletions terratorch_iterate/iterate2/plugin/__init__.py
@@ -0,0 +1 @@
# terratorch_iterate.iterate2.plugin package
93 changes: 93 additions & 0 deletions terratorch_iterate/iterate2/plugin/coordinator/__init__.py
@@ -0,0 +1,93 @@
"""
Lightweight coordinator plugin registry for Optuna storage backends.

Each coordinator plugin lives in its own module inside this package:
- sqlite.py
- journalfs.py
- postgresql.py

Plugins register themselves by calling ``register()`` at import time.
``resolve_storage()`` walks the registry in insertion order and returns the
first matching plugin's storage object.

Usage
-----
>>> from terratorch_iterate.iterate2.plugin.coordinator import resolve_storage
>>> storage = resolve_storage("sqlite:///my_study.db")
"""

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import Any

logger = logging.getLogger("iterate2.coordinator")

# ---------------------------------------------------------------------------
# Base class
# ---------------------------------------------------------------------------

class CoordinatorPlugin(ABC):
    """Abstract base for Optuna storage coordinator plugins."""

    #: Human-readable name shown in log messages.
    name: str = "base"

    @abstractmethod
    def matches(self, db_path: str) -> bool:
        """Return ``True`` when this plugin should handle *db_path*."""

    @abstractmethod
    def get_storage(self, db_path: str) -> Any:
        """Return an Optuna-compatible storage object (or URL string) for *db_path*."""


# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------

_registry: list[CoordinatorPlugin] = []


def register(plugin: CoordinatorPlugin) -> None:
    """Register a coordinator plugin. Later registrations take lower priority."""
    _registry.append(plugin)
    logger.debug("Registered coordinator plugin: %s", plugin.name)


def resolve_storage(db_path: str) -> Any:
    """Walk the registry and return the storage for *db_path*.

    Raises
    ------
    ValueError
        When no registered plugin matches *db_path*.
    """
    for plugin in _registry:
        if plugin.matches(db_path):
            logger.info("Coordinator plugin '%s' handling db_path '%s'", plugin.name, db_path)
            return plugin.get_storage(db_path)
    raise ValueError(
        f"No coordinator plugin matched db_path={db_path!r}. "
        "Make sure the appropriate plugin module is imported before calling resolve_storage()."
    )


# ---------------------------------------------------------------------------
# Auto-load built-in plugins
# ---------------------------------------------------------------------------

def load_builtin_plugins() -> None:
    """Import all built-in coordinator plugins so they self-register."""
    import importlib

    _builtins = [
        "terratorch_iterate.iterate2.plugin.coordinator.sqlite",
        "terratorch_iterate.iterate2.plugin.coordinator.journalfs",
        "terratorch_iterate.iterate2.plugin.coordinator.postgresql",
    ]
    for mod in _builtins:
        try:
            importlib.import_module(mod)
        except ImportError as exc:
            logger.warning("Could not load coordinator plugin '%s': %s", mod, exc)
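
A hedged sketch of what one of these self-registering plugin modules might look like (none of the built-in plugin files appear in this diff; SqlitePlugin and its matching rules are assumptions modeled on the storage-resolution code removed from _iterate2.py above):

# Hypothetical plugin/coordinator/sqlite.py – illustration only, not part of this PR.
from terratorch_iterate.iterate2.plugin.coordinator import CoordinatorPlugin, register


class SqlitePlugin(CoordinatorPlugin):
    name = "sqlite"

    def matches(self, db_path: str) -> bool:
        # Claim explicit sqlite URLs and bare filesystem paths; other schemes
        # (js:///, postgresql://) fall through to the remaining plugins.
        return db_path.startswith("sqlite:///") or "://" not in db_path

    def get_storage(self, db_path: str) -> str:
        # Optuna accepts RDB URL strings directly; prefix bare paths the way
        # the old fallback in _iterate2.py did.
        return db_path if db_path.startswith("sqlite:///") else f"sqlite:///{db_path}"


# Self-register at import time so resolve_storage() can find this plugin.
register(SqlitePlugin())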