diff --git a/configs/gridfm_graphkit_hpo.yaml b/configs/gridfm_graphkit_hpo.yaml
index 049466f..d578e19 100644
--- a/configs/gridfm_graphkit_hpo.yaml
+++ b/configs/gridfm_graphkit_hpo.yaml
@@ -1,7 +1,8 @@
 # HPO configuration for gridfm-graphkit HGNS PF case118
 #
 # Hyperparameters:
-#   gpu_num  – number of GPUs to request from the WLM (launcher-level)
+#   compute  – group: co-selects gpu_num, num_workers and batch_size, so the
+#              per-GPU worker and batch counts halve as gpu_num doubles
 #   bfloat16 – boolean flag (presence = --bfloat16, absence = flag omitted)
 #   tf32     – boolean flag (presence = --tf32, absence = flag omitted)
 #   compile  – torch.compile mode; null disables the flag entirely
@@ -14,30 +15,32 @@
 #   extracted from [performance] lines in trial output
 metrics:
-  - case118_ieee/layer_0_residual
+  - Validation loss
   - last epoch time
   - last epoch it/s
 
 hpo:
-  gpu_num:
-    type: categorical
-    choices: [1, 2, 4]
-
-  bfloat16:
-    type: flag  # store_true: true → --bfloat16, false → flag omitted
-
-  tf32:
-    type: flag  # store_true: true → --tf32, false → flag omitted
+  compute:
+    type: group  # one choice co-selects gpu_num, num_workers and batch_size
+    choices:
+      small:   # single GPU – baseline
+        gpu_num: 1
+        num_workers: 32
+        batch_size: 64
+      medium:  # two GPUs – roughly doubled throughput, same global batch
+        gpu_num: 2
+        num_workers: 16
+        batch_size: 32
+      large:   # four GPUs – full-node run
+        gpu_num: 4
+        num_workers: 8
+        batch_size: 16
 
   compile:
     type: categorical
     choices: ["max-autotune", "default", "reduce-overhead", null]
     # null → --compile flag is omitted entirely
 
-  num_workers:
-    type: categorical
-    choices: [8, 16, 24, 32]
-
   dataset:
     type: group  # one choice selects all bundled args together
     choices:
diff --git a/examples/run_lsf_gridfm_example_postgres.sh b/examples/run_lsf_gridfm_example_postgres.sh
new file mode 100755
index 0000000..197998e
--- /dev/null
+++ b/examples/run_lsf_gridfm_example_postgres.sh
@@ -0,0 +1,109 @@
+#!/usr/bin/env bash
+# =============================================================================
+# Example: iterate --wlm lsf with PostgreSQL coordinator for gridfm-graphkit HPO
+#
+# Each Optuna trial is submitted as an LSF job that looks like:
+#
+#   bsub -gpu "num=1:mode=exclusive_process:mps=no:gmodel=NVIDIAA100_SXM4_80GB" \
+#     -K -o trial.out -e trial.err \
+#     -R "rusage[ngpus=1, cpu=16, mem=32GB]" \
+#     -J hpo_trial_ \
+#     "export PATH='/opt/share/cuda-12.8.1/bin:$PATH' && \
+#      export CUDA_HOME='/opt/share/cuda-12.8.1/' && \
+#      export LD_LIBRARY_PATH='/opt/share/cuda-12.8.1/lib64:$LD_LIBRARY_PATH' && \
+#      cd /dccstor/terratorch/users/rkie/gitco/gridfm-graphkit && \
+#      source /u/rkie/venvs/venv_gridfm-graphkit/bin/activate && \
+#      gridfm_graphkit train "
+#
+# Prerequisites
+# -------------
+# * LSF bsub/bjobs available on PATH
+# * gridfm-graphkit installed in the venv below
+# * configs/gridfm_graphkit_hpo.yaml present
+# * psycopg2-binary installed: pip install 'terratorch-iterate[postgresql]'
+# * POSTGRES_URL set (or hard-code it in --optuna-db-path below)
+#
+# PostgreSQL coordinator
+# ----------------------
+# PostgreSQL is the recommended coordinator backend for high-parallelism HPO
+# on a cluster, ahead of SQLite and JournalFS: multiple bsub jobs can safely
+# write trial results concurrently without lock contention.
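+#
+# Because trial state lives in one shared database, any Optuna-compatible
+# tool can watch the study while jobs run, e.g. (assuming the separate
+# optuna-dashboard package is installed; it is not a dependency of this repo):
+#
+#   optuna-dashboard "postgresql://user:password@host:5432/optuna_studies"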
+#
+# Set the connection URL as an environment variable to avoid embedding
+# credentials in scripts that may end up in version control:
+#
+#   export POSTGRES_URL="postgresql://user:password@host:5432/optuna_studies"
+#
+# or pass it inline:
+#
+#   POSTGRES_URL="postgresql://..." bash run_lsf_gridfm_example_postgres.sh
+# =============================================================================

+set -euo pipefail
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
+
+# ---------------------------------------------------------------------------
+# Required: PostgreSQL connection URL
+# ---------------------------------------------------------------------------
+: "${POSTGRES_URL:?Please set POSTGRES_URL=postgresql://user:password@host:port/dbname}"
+
+# ---------------------------------------------------------------------------
+# Customisable paths – override via environment variables
+# ---------------------------------------------------------------------------
+GRIDFM_ROOT="${GRIDFM_ROOT:-/dccstor/terratorch/users/rkie/gitco/gridfm-graphkit}"
+GRIDFM_VENV="${GRIDFM_VENV:-/u/rkie/venvs/venv_gridfm-graphkit}"
+CUDA_BASE="${CUDA_BASE:-/opt/share/cuda-12.8.1}"
+DATA_PATH="${DATA_PATH:-/u/rkie/}"
+LOG_DIR="${LOG_DIR:-logs}"
+
+# ---------------------------------------------------------------------------
+# LSF GPU resource string
+# Adjust gmodel to the GPU type available on your cluster.
+# ---------------------------------------------------------------------------
+LSF_GPU_CONFIG="${LSF_GPU_CONFIG:-num=1:mode=exclusive_process:mps=no:gmodel=NVIDIAA100_SXM4_80GB}"
+
+# ---------------------------------------------------------------------------
+# Pre-run commands executed inside every bsub job before the training script.
+# Order matters:
+#   1. Export CUDA paths so the GPU driver / toolkit is visible.
+#   2. cd into the project root so relative config paths resolve correctly.
+#   3. Activate the project venv.
+# ---------------------------------------------------------------------------
+PRE_RUN="\
+export PATH='${CUDA_BASE}/bin:\$PATH' && \
+export CUDA_HOME='${CUDA_BASE}' && \
+export LD_LIBRARY_PATH='${CUDA_BASE}/lib64:\$LD_LIBRARY_PATH' && \
+cd '${GRIDFM_ROOT}' && \
+source '${GRIDFM_VENV}/bin/activate'"
+
+# ---------------------------------------------------------------------------
+# Static training arguments (not part of the HPO search space).
+# These are appended verbatim after the sampled hyperparameters.
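+# For example, one sampled trial could expand to (illustrative values only;
+# actual flag spelling follows the HPO YAML keys, kept as underscores because
+# of --no-underscore-to-hyphen below):
+#   gridfm_graphkit train --gpu_num 2 --num_workers 16 --batch_size 32 \
+#       --compile default --log_dir logs --report-performance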
+# ---------------------------------------------------------------------------
+STATIC_ARGS_JSON='{
+  "log_dir": "'"${LOG_DIR}"'",
+  "report-performance": true
+}'
+
+# ---------------------------------------------------------------------------
+# Launch iterate
+#
+# To pin the GPU model per trial, splice this argument into the command below:
+#   --lsf-gpu-config-string "${LSF_GPU_CONFIG}" \
+# (it cannot be left as a commented-out line between the backslash-continued
+# arguments: a comment there would terminate the command early).
+# ---------------------------------------------------------------------------
+iterate \
+    --script "gridfm_graphkit train" \
+    --interpreter "" \
+    --root-dir "${GRIDFM_ROOT}" \
+    --wlm lsf \
+    --pre-run-commands "${PRE_RUN}" \
+    --no-underscore-to-hyphen \
+    --gpu-count 1 \
+    --cpu-count 16 \
+    --mem-gb 32 \
+    --optuna-study-name gridfm_lsf_postgres_hpo \
+    --optuna-db-path "${POSTGRES_URL}" \
+    --parallelism 4 \
+    --optuna-n-trials 20 \
+    --hpo-yaml "${REPO_ROOT}/configs/gridfm_graphkit_hpo.yaml" \
+    --static-args-json "${STATIC_ARGS_JSON}"
diff --git a/pyproject.toml b/pyproject.toml
index ddd00c0..d53e340 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -103,6 +103,12 @@ utility = [
 # If you want to catch Nvidia GPU metrics, you also need to install pynvml:
 nvidia = ["pynvml"]
 
+# PostgreSQL coordinator plugin: installs the psycopg2 driver.
+# Use psycopg2-binary for a self-contained wheel (no libpq build dependency).
+# For production deployments that compile against a system libpq, replace with
+# psycopg2 (without -binary).
+postgresql = ["psycopg2-binary>=2.9"]
+
 # If you are using AMD/HIP GPUs, install pyrsmi
 amd = ["pyrsmi"]
diff --git a/terratorch_iterate/iterate2/__init__.py b/terratorch_iterate/iterate2/__init__.py
new file mode 100644
index 0000000..8c5e851
--- /dev/null
+++ b/terratorch_iterate/iterate2/__init__.py
@@ -0,0 +1,4 @@
+# terratorch_iterate.iterate2 package
+# Re-export main so that `from terratorch_iterate.iterate2 import main` keeps
+# working after iterate2.py was turned into a package directory.
+from terratorch_iterate.iterate2._iterate2 import main  # noqa: F401
diff --git a/terratorch_iterate/iterate2.py b/terratorch_iterate/iterate2/_iterate2.py
similarity index 95%
rename from terratorch_iterate/iterate2.py
rename to terratorch_iterate/iterate2/_iterate2.py
index a22ad94..e34e3ed 100644
--- a/terratorch_iterate/iterate2.py
+++ b/terratorch_iterate/iterate2/_iterate2.py
@@ -14,8 +14,11 @@
 from typing import Dict, Any, Optional, Literal, List
 
 import optuna
-from optuna.storages import JournalStorage, JournalFileStorage
 import yaml
+from terratorch_iterate.iterate2.plugin.coordinator import load_builtin_plugins, resolve_storage
+
+# Load built-in coordinator plugins (journalfs, postgresql, sqlite)
+load_builtin_plugins()
 
 logger = logging.getLogger("iterate2")
 
@@ -803,14 +806,7 @@ def objective(trial):
 
     directions = ["maximize"] * len(metric_list)
     logger.info("Creating Optuna study (directions: %s)", directions)
-    if args.optuna_db_path.startswith("js:///"):
-        journal_path = args.optuna_db_path[len("js:///"):]
-        logger.info("Using JournalStorage at '%s'", journal_path)
-        storage = JournalStorage(JournalFileStorage(journal_path))
-    elif "sqlite" in args.optuna_db_path:
-        storage = args.optuna_db_path
-    else:
-        storage = f"sqlite:///{args.optuna_db_path}"
+    storage = resolve_storage(args.optuna_db_path)
     logger.debug("Optuna storage: %s", storage)
 
     study = optuna.create_study(
@@ -821,10 +817,35 @@
     )
     logger.info("Study '%s' ready (existing trials: %d)", args.optuna_study_name, len(study.trials))
 
+    # ── Re-queue failed trials (up to 25 % of the budget as retries) ──────
+    failed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.FAIL]
+    n_total = args.optuna_n_trials
+    if failed_trials:
+        n_retry = max(1, round(0.25 * n_total))
+        n_retry = min(n_retry, len(failed_trials))  # can't retry more than we have
+        n_new = n_total - n_retry
+        # enqueue the most-recent failed trials first
+        trials_to_retry = failed_trials[-n_retry:]
+        logger.info(
+            "Found %d failed trial(s). Re-queuing %d of them and running %d new.",
+            len(failed_trials), n_retry, n_new,
+        )
+        for ft in trials_to_retry:
+            if ft.params:  # skip trials that had no params at all
+                study.enqueue_trial(ft.params)
+                logger.info("  Enqueued params from failed trial %d: %s", ft.number, ft.params)
+            else:
+                logger.info("  Skipped failed trial %d (no params recorded).", ft.number)
+        # enqueued retries consume slots out of n_trials, so the overall
+        # budget stays args.optuna_n_trials: n_retry re-runs plus n_new fresh
+        n_total = n_new + n_retry
+    else:
+        logger.info("No failed trials found – running %d fresh trials.", n_total)
+    # ── end retry logic ───────────────────────────────────────────────────
+
     logger.info("Parallelism: %d worker(s)", args.parallelism)
     study.optimize(
         objective,
-        n_trials=args.optuna_n_trials,
+        n_trials=n_total,
         n_jobs=args.parallelism,
         catch=(Exception,),  # mark trial as FAILED and continue; never crash the study
     )
diff --git a/terratorch_iterate/iterate2/plugin/__init__.py b/terratorch_iterate/iterate2/plugin/__init__.py
new file mode 100644
index 0000000..f89e362
--- /dev/null
+++ b/terratorch_iterate/iterate2/plugin/__init__.py
@@ -0,0 +1 @@
+# terratorch_iterate.iterate2.plugin package
diff --git a/terratorch_iterate/iterate2/plugin/coordinator/__init__.py b/terratorch_iterate/iterate2/plugin/coordinator/__init__.py
new file mode 100644
index 0000000..fd6459a
--- /dev/null
+++ b/terratorch_iterate/iterate2/plugin/coordinator/__init__.py
@@ -0,0 +1,93 @@
+"""
+Lightweight coordinator plugin registry for Optuna storage backends.
+
+Each coordinator plugin lives in its own module inside this package:
+    - sqlite.py
+    - journalfs.py
+    - postgresql.py
+
+Plugins register themselves by calling ``register()`` at import time.
+``resolve_storage()`` walks the registry in insertion order and returns the
+first matching plugin's storage object.
+
+Usage
+-----
+>>> from terratorch_iterate.iterate2.plugin.coordinator import resolve_storage
+>>> storage = resolve_storage("sqlite:///my_study.db")
+"""
+
+from __future__ import annotations
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Any
+
+logger = logging.getLogger("iterate2.coordinator")
+
+# ---------------------------------------------------------------------------
+# Base class
+# ---------------------------------------------------------------------------
+
+class CoordinatorPlugin(ABC):
+    """Abstract base for Optuna storage coordinator plugins."""
+
+    #: Human-readable name shown in log messages.
+    name: str = "base"
+
+    @abstractmethod
+    def matches(self, db_path: str) -> bool:
+        """Return ``True`` when this plugin should handle *db_path*."""
+
+    @abstractmethod
+    def get_storage(self, db_path: str) -> Any:
+        """Return an Optuna-compatible storage object (or URL string) for *db_path*."""
+
+
+# ---------------------------------------------------------------------------
+# Registry
+# ---------------------------------------------------------------------------
+
+_registry: list[CoordinatorPlugin] = []
+
+
+def register(plugin: CoordinatorPlugin) -> None:
+    """Register a coordinator plugin. Later registrations take lower priority."""
+    _registry.append(plugin)
+    logger.debug("Registered coordinator plugin: %s", plugin.name)
+
+
+def resolve_storage(db_path: str) -> Any:
+    """Walk the registry and return the storage for *db_path*.
+
+    Raises
+    ------
+    ValueError
+        When no registered plugin matches *db_path*.
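+
+    Examples
+    --------
+    >>> resolve_storage("js:///tmp/study.log")     # journalfs plugin
+    >>> resolve_storage("postgresql://u:p@h/db")   # postgresql plugin
+    >>> resolve_storage("/tmp/study.db")           # sqlite fallback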
+ """ + for plugin in _registry: + if plugin.matches(db_path): + logger.info("Coordinator plugin '%s' handling db_path '%s'", plugin.name, db_path) + return plugin.get_storage(db_path) + raise ValueError( + f"No coordinator plugin matched db_path={db_path!r}. " + "Make sure the appropriate plugin module is imported before calling resolve_storage()." + ) + + +# --------------------------------------------------------------------------- +# Auto-load built-in plugins +# --------------------------------------------------------------------------- + +def load_builtin_plugins() -> None: + """Import all built-in coordinator plugins so they self-register.""" + import importlib + _builtins = [ + "terratorch_iterate.iterate2.plugin.coordinator.sqlite", + "terratorch_iterate.iterate2.plugin.coordinator.journalfs", + "terratorch_iterate.iterate2.plugin.coordinator.postgresql", + ] + for mod in _builtins: + try: + importlib.import_module(mod) + except ImportError as exc: + logger.warning("Could not load coordinator plugin '%s': %s", mod, exc) diff --git a/terratorch_iterate/iterate2/plugin/coordinator/journalfs.py b/terratorch_iterate/iterate2/plugin/coordinator/journalfs.py new file mode 100644 index 0000000..633afff --- /dev/null +++ b/terratorch_iterate/iterate2/plugin/coordinator/journalfs.py @@ -0,0 +1,56 @@ +""" +JournalFS coordinator plugin for iterate2. + +Matches any db_path that starts with the ``js:///`` prefix. + +The prefix is stripped and the remainder is treated as a local filesystem +path. The storage object returned is an Optuna +``JournalStorage(JournalFileBackend(...))`` instance which is safe for +concurrent, multi-process access and does not require a database server. + +Self-registers at import time. +""" + +from __future__ import annotations + +import logging + +from optuna.storages import JournalStorage + +# Prefer the non-deprecated JournalFileBackend (Optuna ≥4.0); fall back to the +# legacy JournalFileStorage for older installations. 
+try:
+    from optuna.storages.journal import JournalFileBackend as _JournalFileBackend  # type: ignore
+    _USE_BACKEND = True
+except ImportError:
+    from optuna.storages import JournalFileStorage as _JournalFileBackend  # type: ignore
+    _USE_BACKEND = False
+
+from terratorch_iterate.iterate2.plugin.coordinator import CoordinatorPlugin, register
+
+logger = logging.getLogger("iterate2.coordinator.journalfs")
+
+_PREFIX = "js:///"
+
+# ---------------------------------------------------------------------------
+# Plugin implementation
+# ---------------------------------------------------------------------------
+
+class JournalFSCoordinator(CoordinatorPlugin):
+    name = "journalfs"
+
+    def matches(self, db_path: str) -> bool:
+        return db_path.startswith(_PREFIX)
+
+    def get_storage(self, db_path: str) -> JournalStorage:
+        journal_path = db_path[len(_PREFIX):]
+        backend_cls = "JournalFileBackend" if _USE_BACKEND else "JournalFileStorage"
+        logger.info("JournalStorage backend=%s path=%s", backend_cls, journal_path)
+        return JournalStorage(_JournalFileBackend(journal_path))
+
+
+# ---------------------------------------------------------------------------
+# Auto-register
+# ---------------------------------------------------------------------------
+
+register(JournalFSCoordinator())
diff --git a/terratorch_iterate/iterate2/plugin/coordinator/postgresql.py b/terratorch_iterate/iterate2/plugin/coordinator/postgresql.py
new file mode 100644
index 0000000..825477a
--- /dev/null
+++ b/terratorch_iterate/iterate2/plugin/coordinator/postgresql.py
@@ -0,0 +1,117 @@
+"""
+PostgreSQL coordinator plugin for iterate2.
+
+Matches any db_path that starts with ``postgresql://`` or ``postgres://``.
+
+The raw URL is passed directly to Optuna's RDB storage layer (SQLAlchemy
+under the hood). Requires the ``psycopg2`` (or ``psycopg2-binary``) package
+to be installed in the active environment::
+
+    pip install psycopg2-binary
+
+Example db_path values
+----------------------
+``postgresql://user:password@localhost:5432/optuna_studies``
+``postgres://user:password@db-host/mydb``
+
+Self-registers at import time.
+"""
+
+from __future__ import annotations
+
+import logging
+import re
+
+from terratorch_iterate.iterate2.plugin.coordinator import CoordinatorPlugin, register
+
+logger = logging.getLogger("iterate2.coordinator.postgresql")
+
+_SCHEMES = ("postgresql://", "postgres://")
+
+# Default connect_timeout (seconds) injected into every connection so that
+# cloud databases with firewalled ports don't cause silent hangs.
+_DEFAULT_CONNECT_TIMEOUT = 30
+
+# ---------------------------------------------------------------------------
+# Plugin implementation
+# ---------------------------------------------------------------------------
+
+class PostgreSQLCoordinator(CoordinatorPlugin):
+    name = "postgresql"
+
+    def matches(self, db_path: str) -> bool:
+        return any(db_path.startswith(scheme) for scheme in _SCHEMES)
+
+    def get_storage(self, db_path: str):
+        """Return an ``optuna.storages.RDBStorage`` configured for *db_path*.
+
+        The storage object is used (rather than a bare URL string) so that we
+        can inject ``connect_args`` (e.g. ``connect_timeout``, ``sslmode``)
+        without requiring the caller to embed those options in the URL.
+
+        SSL
+        ---
+        If ``sslmode`` is not already present in the URL query string *and* the
+        host is not ``localhost`` / ``127.0.0.1``, ``sslmode=require`` is added
+        automatically. Pass ``?sslmode=disable`` in the URL to suppress this.
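+
+        Example
+        -------
+        ``postgresql://u:p@db.example.com/opt``  → ``sslmode=require`` injected
+        ``postgresql://u:p@localhost:5432/opt``  → no SSL option added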
+ """ + # Normalise legacy "postgres://" → "postgresql://" because SQLAlchemy + # 1.4+ dropped support for the short-form scheme. + if db_path.startswith("postgres://") and not db_path.startswith("postgresql://"): + db_path = "postgresql://" + db_path[len("postgres://"):] + logger.debug("Normalised scheme to: %s", db_path) + + # Verify psycopg2 is available early so the error is clear. + try: + import psycopg2 # noqa: F401 + except ImportError: + raise ImportError( + "psycopg2 is not installed but is required for PostgreSQL storage.\n\n" + "Install options:\n" + " # recommended – pre-built wheel, no compiler needed:\n" + " pip install psycopg2-binary\n\n" + " # or via the project's postgresql extra:\n" + " pip install 'terratorch-iterate[postgresql]'\n\n" + " # production deployments that compile against a system libpq:\n" + " pip install psycopg2\n" + ) from None + + connect_args: dict = {"connect_timeout": _DEFAULT_CONNECT_TIMEOUT} + + # Auto-enable SSL for non-local hosts when sslmode not already set. + if "sslmode" not in db_path: + host = _extract_host(db_path) + if host not in ("localhost", "127.0.0.1", "::1", ""): + connect_args["sslmode"] = "require" + logger.debug("Auto-enabled sslmode=require for host '%s'", host) + + logger.info("PostgreSQL storage URL: %s", _redact(db_path)) + + from optuna.storages import RDBStorage + return RDBStorage( + url=db_path, + engine_kwargs={"connect_args": connect_args}, + ) + + +def _extract_host(url: str) -> str: + """Return the hostname portion of a postgresql:// URL.""" + try: + # url looks like postgresql://user:pass@host:port/db + after_at = url.split("@", 1)[1] + host_port = after_at.split("/")[0] + return host_port.split(":")[0] + except (IndexError, AttributeError): + return "" + + +def _redact(url: str) -> str: + """Replace the password in a DB URL with '***' for safe logging.""" + return re.sub(r"(://[^:]+:)[^@]+(@)", r"\1***\2", url) + + +# --------------------------------------------------------------------------- +# Auto-register +# --------------------------------------------------------------------------- + +register(PostgreSQLCoordinator()) diff --git a/terratorch_iterate/iterate2/plugin/coordinator/sqlite.py b/terratorch_iterate/iterate2/plugin/coordinator/sqlite.py new file mode 100644 index 0000000..b5cc38b --- /dev/null +++ b/terratorch_iterate/iterate2/plugin/coordinator/sqlite.py @@ -0,0 +1,51 @@ +""" +SQLite coordinator plugin for iterate2. + +Matches any db_path that: + - already contains the ``sqlite:///`` scheme, OR + - ends with ``.db`` or ``.sqlite``, OR + - contains the substring ``sqlite`` + +The storage value returned to Optuna is always a fully-qualified +``sqlite:///...`` URL so that SQLAlchemy can open it correctly. + +Self-registers at import time. 
+""" + +from __future__ import annotations + +import logging + +from terratorch_iterate.iterate2.plugin.coordinator import CoordinatorPlugin, register + +logger = logging.getLogger("iterate2.coordinator.sqlite") + +# --------------------------------------------------------------------------- +# Plugin implementation +# --------------------------------------------------------------------------- + +class SQLiteCoordinator(CoordinatorPlugin): + name = "sqlite" + + def matches(self, db_path: str) -> bool: + return ( + db_path.startswith("sqlite:///") + or db_path.endswith(".db") + or db_path.endswith(".sqlite") + or "sqlite" in db_path + ) + + def get_storage(self, db_path: str) -> str: + if db_path.startswith("sqlite:///"): + storage_url = db_path + else: + storage_url = f"sqlite:///{db_path}" + logger.info("SQLite storage URL: %s", storage_url) + return storage_url + + +# --------------------------------------------------------------------------- +# Auto-register +# --------------------------------------------------------------------------- + +register(SQLiteCoordinator()) diff --git a/tests/integration/test_coordinator_plugins.py b/tests/integration/test_coordinator_plugins.py new file mode 100644 index 0000000..2cdc54f --- /dev/null +++ b/tests/integration/test_coordinator_plugins.py @@ -0,0 +1,308 @@ +""" +Integration tests for the iterate2 coordinator plugin system. + +SQLite – always runs (uses a temp file). +JournalFS – always runs (uses a temp file). +PostgreSQL – skipped unless the environment variable POSTGRES_URL is set, e.g.: + + export POSTGRES_URL="postgresql://user:password@localhost:5432/optuna_test" + pytest tests/integration/test_coordinator_plugins.py -v + +The PostgreSQL test creates a study, adds a dummy trial, then removes the study +so it leaves no permanent state in the database. +""" + +from __future__ import annotations + +import os +import tempfile +from typing import Any + +import optuna +import pytest + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +# Silence Optuna's own INFO logs during tests so pytest output stays clean. +optuna.logging.set_verbosity(optuna.logging.WARNING) + + +def _fresh_registry(): + """Return a clean load_builtin_plugins / resolve_storage pair backed by an + isolated registry so tests cannot leak state into each other.""" + # Re-import the coordinator package with a private registry copy. + from terratorch_iterate.iterate2.plugin import coordinator as coord_pkg + import importlib, types + + # Build a fresh module clone with its own empty registry. 
+    fresh = types.ModuleType(coord_pkg.__name__ + "._test_clone")
+    fresh.__dict__.update({k: v for k, v in coord_pkg.__dict__.items()
+                           if k not in ("_registry",)})
+    fresh._registry = []
+
+    def _register(plugin):
+        fresh._registry.append(plugin)
+
+    def _resolve(db_path):
+        for plugin in fresh._registry:
+            if plugin.matches(db_path):
+                return plugin.get_storage(db_path)
+        raise ValueError(f"No coordinator plugin matched db_path={db_path!r}")
+
+    fresh.register = _register
+    fresh.resolve_storage = _resolve
+    return fresh
+
+
+def _create_and_verify_study(storage: Any, study_name: str) -> None:
+    """Create a one-trial study against *storage*, assert it persists."""
+    study = optuna.create_study(study_name=study_name, storage=storage,
+                                load_if_exists=True)
+
+    def objective(trial):
+        x = trial.suggest_float("x", -1.0, 1.0)
+        return x ** 2
+
+    study.optimize(objective, n_trials=1)
+    assert len(study.trials) == 1, "Expected exactly 1 completed trial"
+    assert study.trials[0].state == optuna.trial.TrialState.COMPLETE
+
+
+# ---------------------------------------------------------------------------
+# SQLite plugin
+# ---------------------------------------------------------------------------
+
+class TestSQLiteCoordinator:
+    def _make_storage(self, db_url: str):
+        from terratorch_iterate.iterate2.plugin.coordinator.sqlite import SQLiteCoordinator
+        return SQLiteCoordinator().get_storage(db_url)
+
+    def test_matches_sqlite_scheme(self):
+        from terratorch_iterate.iterate2.plugin.coordinator.sqlite import SQLiteCoordinator
+        p = SQLiteCoordinator()
+        assert p.matches("sqlite:///foo.db")
+
+    def test_matches_dot_db_extension(self):
+        from terratorch_iterate.iterate2.plugin.coordinator.sqlite import SQLiteCoordinator
+        p = SQLiteCoordinator()
+        assert p.matches("/tmp/my_study.db")
+
+    def test_matches_dot_sqlite_extension(self):
+        from terratorch_iterate.iterate2.plugin.coordinator.sqlite import SQLiteCoordinator
+        p = SQLiteCoordinator()
+        assert p.matches("/tmp/my_study.sqlite")
+
+    def test_no_match_journalfs(self):
+        from terratorch_iterate.iterate2.plugin.coordinator.sqlite import SQLiteCoordinator
+        p = SQLiteCoordinator()
+        assert not p.matches("js:///tmp/journal.log")
+
+    def test_normalises_plain_path_to_sqlite_url(self):
+        from terratorch_iterate.iterate2.plugin.coordinator.sqlite import SQLiteCoordinator
+        url = SQLiteCoordinator().get_storage("/tmp/study.db")
+        assert url.startswith("sqlite:///")
+
+    def test_passthrough_existing_sqlite_url(self):
+        from terratorch_iterate.iterate2.plugin.coordinator.sqlite import SQLiteCoordinator
+        url = "sqlite:///existing.db"
+        assert SQLiteCoordinator().get_storage(url) == url
+
+    def test_full_study_lifecycle(self, tmp_path):
+        db_file = tmp_path / "test_study.db"
+        storage = self._make_storage(str(db_file))
+        _create_and_verify_study(storage, "sqlite_integration_test")
+
+    def test_resolve_storage_via_registry(self, tmp_path):
+        """End-to-end: resolve_storage() picks the SQLite plugin."""
+        from terratorch_iterate.iterate2.plugin.coordinator import (
+            load_builtin_plugins, resolve_storage,
+        )
+        load_builtin_plugins()
+        db_file = tmp_path / "registry_test.db"
+        storage = resolve_storage(str(db_file))
+        assert storage.startswith("sqlite:///")
+
+
+# ---------------------------------------------------------------------------
+# JournalFS plugin
+# ---------------------------------------------------------------------------
+
+class TestJournalFSCoordinator:
+    def _make_storage(self, journal_path: str):
+        from terratorch_iterate.iterate2.plugin.coordinator.journalfs import JournalFSCoordinator
+        return JournalFSCoordinator().get_storage(f"js:///{journal_path}")
+
+    def test_matches_js_prefix(self):
+        from terratorch_iterate.iterate2.plugin.coordinator.journalfs import JournalFSCoordinator
+        assert JournalFSCoordinator().matches("js:///tmp/j.log")
+
+    def test_no_match_sqlite(self):
+        from terratorch_iterate.iterate2.plugin.coordinator.journalfs import JournalFSCoordinator
+        assert not JournalFSCoordinator().matches("sqlite:///foo.db")
+
+    def test_no_match_postgresql(self):
+        from terratorch_iterate.iterate2.plugin.coordinator.journalfs import JournalFSCoordinator
+        assert not JournalFSCoordinator().matches("postgresql://u:p@h/db")
+
+    def test_returns_journal_storage_object(self, tmp_path):
+        from optuna.storages import JournalStorage
+        journal_file = tmp_path / "test.log"
+        storage = self._make_storage(str(journal_file))
+        assert isinstance(storage, JournalStorage)
+
+    def test_full_study_lifecycle(self, tmp_path):
+        journal_file = tmp_path / "study.log"
+        storage = self._make_storage(str(journal_file))
+        _create_and_verify_study(storage, "journalfs_integration_test")
+        assert journal_file.exists(), "Journal file should exist after the study"
+
+    def test_concurrent_writers(self, tmp_path):
+        """Two studies sharing the same journal file must not corrupt each other."""
+        import threading
+        journal_file = tmp_path / "shared.log"
+        storage_a = self._make_storage(str(journal_file))
+        storage_b = self._make_storage(str(journal_file))
+
+        errors: list[Exception] = []
+
+        def run(storage, name):
+            try:
+                study = optuna.create_study(study_name=name, storage=storage,
+                                            load_if_exists=True)
+                study.optimize(lambda t: t.suggest_float("x", 0, 1), n_trials=3)
+            except Exception as exc:
+                errors.append(exc)
+
+        t1 = threading.Thread(target=run, args=(storage_a, "worker_a"))
+        t2 = threading.Thread(target=run, args=(storage_b, "worker_b"))
+        t1.start(); t2.start()
+        t1.join(); t2.join()
+
+        assert not errors, f"Concurrent writers raised: {errors}"
+
+    def test_resolve_storage_via_registry(self, tmp_path):
+        from optuna.storages import JournalStorage
+        from terratorch_iterate.iterate2.plugin.coordinator import (
+            load_builtin_plugins, resolve_storage,
+        )
+        load_builtin_plugins()
+        journal_file = tmp_path / "registry.log"
+        storage = resolve_storage(f"js:///{journal_file}")
+        assert isinstance(storage, JournalStorage)
+
+
+# ---------------------------------------------------------------------------
+# PostgreSQL plugin
+# ---------------------------------------------------------------------------
+
+POSTGRES_URL = os.environ.get("POSTGRES_URL", "")
+
+postgres_required = pytest.mark.skipif(
+    not POSTGRES_URL,
+    reason=(
+        "Set POSTGRES_URL=postgresql://user:pass@host:5432/dbname "
+        "to run PostgreSQL coordinator tests"
+    ),
+)
+
+
+class TestPostgreSQLCoordinator:
+    def test_matches_postgresql_scheme(self):
+        from terratorch_iterate.iterate2.plugin.coordinator.postgresql import PostgreSQLCoordinator
+        assert PostgreSQLCoordinator().matches("postgresql://u:p@h/db")
+
+    def test_matches_legacy_postgres_scheme(self):
+        from terratorch_iterate.iterate2.plugin.coordinator.postgresql import PostgreSQLCoordinator
+        assert PostgreSQLCoordinator().matches("postgres://u:p@h/db")
+
+    def test_no_match_sqlite(self):
+        from terratorch_iterate.iterate2.plugin.coordinator.postgresql import PostgreSQLCoordinator
+        assert not PostgreSQLCoordinator().matches("sqlite:///foo.db")
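+
+    def test_redact_hides_password(self):
+        # Example check for the _redact() helper defined in postgresql.py:
+        # only the password portion of the URL should be masked before logging.
+        from terratorch_iterate.iterate2.plugin.coordinator.postgresql import _redact
+        url = "postgresql://user:secret@host:5432/db"
+        assert _redact(url) == "postgresql://user:***@host:5432/db"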
+
+    def test_no_match_journalfs(self):
+        from terratorch_iterate.iterate2.plugin.coordinator.postgresql import PostgreSQLCoordinator
+        assert not PostgreSQLCoordinator().matches("js:///foo.log")
+
+    def test_normalises_legacy_scheme(self):
+        """'postgres://' must be normalised to 'postgresql://' before it reaches SQLAlchemy."""
+        from terratorch_iterate.iterate2.plugin.coordinator.postgresql import _extract_host
+        # Test the helper that parses the host out of the normalised URL.
+        host = _extract_host("postgresql://user:pass@my-host.example.com:5432/db")
+        assert host == "my-host.example.com"
+        # Test that the legacy postgres:// scheme gets normalised (string level, no DB needed).
+        legacy = "postgres://user:pass@host/db"
+        normalised = "postgresql://" + legacy[len("postgres://"):]
+        assert normalised.startswith("postgresql://")
+
+    def test_missing_psycopg2_raises_import_error(self, monkeypatch):
+        """If psycopg2 is absent the plugin must raise a clear ImportError."""
+        import builtins
+        real_import = builtins.__import__
+
+        def mock_import(name, *args, **kwargs):
+            if name == "psycopg2":
+                raise ImportError("mocked missing psycopg2")
+            return real_import(name, *args, **kwargs)
+
+        monkeypatch.setattr(builtins, "__import__", mock_import)
+        from terratorch_iterate.iterate2.plugin.coordinator.postgresql import PostgreSQLCoordinator
+        with pytest.raises(ImportError, match="psycopg2"):
+            PostgreSQLCoordinator().get_storage("postgresql://u:p@h/db")
+
+    @postgres_required
+    def test_full_study_lifecycle(self):
+        import uuid
+        from terratorch_iterate.iterate2.plugin.coordinator.postgresql import PostgreSQLCoordinator
+        storage = PostgreSQLCoordinator().get_storage(POSTGRES_URL)
+        study_name = f"pg_integration_{uuid.uuid4().hex[:8]}"
+        try:
+            _create_and_verify_study(storage, study_name)
+        finally:
+            # Clean up: delete the study so the DB stays tidy.
+            try:
+                optuna.delete_study(study_name=study_name, storage=storage)
+            except Exception:
+                pass
+
+    @postgres_required
+    def test_resolve_storage_via_registry(self):
+        from optuna.storages import RDBStorage
+        from terratorch_iterate.iterate2.plugin.coordinator import (
+            load_builtin_plugins, resolve_storage,
+        )
+        load_builtin_plugins()
+        storage = resolve_storage(POSTGRES_URL)
+        assert isinstance(storage, RDBStorage)
+
+    @postgres_required
+    def test_parallel_trials(self):
+        """Multiple threads sharing a PostgreSQL study must all complete cleanly."""
+        import threading, uuid
+        from terratorch_iterate.iterate2.plugin.coordinator.postgresql import PostgreSQLCoordinator
+        storage = PostgreSQLCoordinator().get_storage(POSTGRES_URL)
+        study_name = f"pg_parallel_{uuid.uuid4().hex[:8]}"
+        study = optuna.create_study(study_name=study_name, storage=storage,
+                                    load_if_exists=True)
+        errors: list[Exception] = []
+
+        def worker():
+            try:
+                study.optimize(lambda t: t.suggest_float("x", 0, 1), n_trials=2)
+            except Exception as exc:
+                errors.append(exc)
+
+        threads = [threading.Thread(target=worker) for _ in range(3)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        try:
+            optuna.delete_study(study_name=study_name, storage=storage)
+        except Exception:
+            pass
+
+        assert not errors, f"Parallel trials raised: {errors}"
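+
+# ---------------------------------------------------------------------------
+# Writing a custom coordinator (sketch)
+# ---------------------------------------------------------------------------
+# A minimal sketch of a third-party backend using only the public
+# CoordinatorPlugin/register API from the coordinator package. The redis://
+# scheme is a hypothetical example, not a backend this repo ships:
+#
+#     from terratorch_iterate.iterate2.plugin.coordinator import (
+#         CoordinatorPlugin, register,
+#     )
+#
+#     class RedisCoordinator(CoordinatorPlugin):
+#         name = "redis"
+#
+#         def matches(self, db_path: str) -> bool:
+#             return db_path.startswith("redis://")
+#
+#         def get_storage(self, db_path: str):
+#             # return any Optuna-compatible storage object here
+#             raise NotImplementedError
+#
+#     register(RedisCoordinator())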