From 223c66f8b048cf8fc49655979b58b1b3694f445a Mon Sep 17 00:00:00 2001
From: Harrison Cook
Date: Tue, 29 Oct 2024 09:11:40 +0000
Subject: [PATCH] Initial commit to provide failover

- Wraps experiment to capture connection errors
- Logs to disk if server not found
---
 .../training/diagnostics/mlflow/logger.py     | 74 ++++++++++++++++++-
 1 file changed, 72 insertions(+), 2 deletions(-)

diff --git a/src/anemoi/training/diagnostics/mlflow/logger.py b/src/anemoi/training/diagnostics/mlflow/logger.py
index 183e7a0d..f153ff9e 100644
--- a/src/anemoi/training/diagnostics/mlflow/logger.py
+++ b/src/anemoi/training/diagnostics/mlflow/logger.py
@@ -10,6 +10,7 @@
 
 from __future__ import annotations
 
+import functools
 import io
 import logging
 import os
@@ -23,7 +24,10 @@
 from typing import Literal
 from weakref import WeakValueDictionary
 
+import mlflow
+from mlflow import MlflowClient
 from packaging.version import Version
+from pytorch_lightning.loggers.mlflow import LOCAL_FILE_URI_PREFIX
 from pytorch_lightning.loggers.mlflow import MLFlowLogger
 from pytorch_lightning.loggers.mlflow import _convert_params
 from pytorch_lightning.loggers.mlflow import _flatten_dict
@@ -36,7 +40,6 @@
 if TYPE_CHECKING:
     from argparse import Namespace
 
-    import mlflow
 
 LOGGER = logging.getLogger(__name__)
 
@@ -244,6 +247,8 @@ def _remove_csi(line: bytes) -> bytes:
 class AnemoiMLflowLogger(MLFlowLogger):
     """A custom MLflow logger that logs terminal output."""
 
+    _failed_to_offline = False
+
     def __init__(
         self,
         experiment_name: str = "lightning_logs",
@@ -323,6 +328,7 @@ def __init__(
             tracking_uri=tracking_uri,
             on_resume_create_child=on_resume_create_child,
         )
+        self._save_dir = f"{LOCAL_FILE_URI_PREFIX}{save_dir}"
         # Before creating the run we need to overwrite the tracking_uri and save_dir if offline
         if offline:
             # OFFLINE - When we run offline we can pass a save_dir pointing to a local path
@@ -426,7 +432,71 @@ def _get_mlflow_run_params(
     def experiment(self) -> MLFlowLogger.experiment:
         if rank_zero_only.rank == 0:
             self.auth.authenticate()
-        return super().experiment
+
+        if self._failed_to_offline:  # reconnect first so the client fetched below is current
+            self.return_to_online()
+
+        parent_obj = super()
+        logger_obj = self
+        experiment = super().experiment
+
+        class WrappedExperiment:
+            """Proxy the MLflow client and handle connection errors in logging calls.
+
+            Fails over to offline logging if an error occurs.
+            """
+
+            def __getattr__(self, key: str) -> Any:
+                # Only wrap logging calls; everything else is delegated to the client untouched
+                if not hasattr(experiment, key) or not key.startswith("log"):
+                    return getattr(experiment, key)
+                if not callable(getattr(experiment, key)):
+                    return getattr(experiment, key)
+
+                @functools.wraps(getattr(experiment, key))
+                def wrapped_method(*args, **kwargs) -> Any:
+                    nonlocal experiment  # rebound below to the offline client on failover
+                    try:
+                        return getattr(experiment, key)(*args, **kwargs)
+                    except mlflow.MlflowException as e:
+                        LOGGER.warning("An error occurred when calling %s: %s", key, e)
+                        logger_obj.failover_to_offline()
+                        experiment = parent_obj.experiment
+                        return getattr(experiment, key)(*args, **kwargs)
+
+                return wrapped_method
+
+        return WrappedExperiment()
+
+    def failover_to_offline(self) -> None:
+        """Failover to offline mode by pointing the client at the local save directory."""
+        self._initialized = False
+        self._mlflow_client = MlflowClient(tracking_uri=self._save_dir)
+        self._failed_to_offline = True
+
+    def return_to_online(self) -> None:
+        """Return to online mode and sync the locally stored run to the tracking server."""
+        try:
+            health_check(self._tracking_uri)
+        except ConnectionError:
+            return  # health check failed, stay offline
+
+        self._initialized = False
+        self._mlflow_client = MlflowClient(tracking_uri=self._tracking_uri)
+        self._failed_to_offline = False
+
+        from anemoi.training.utils.mlflow_sync import MlFlowSync
+
+        try:
+            MlFlowSync(
+                self._save_dir,
+                self._tracking_uri,
+                self._run_id,
+                self._experiment_name,
+            ).sync()
+        except Exception as e:  # noqa: BLE001
+            LOGGER.warning("Failed to sync to online server: %s", e)
+            self.failover_to_offline()
 
     @rank_zero_only
     def log_system_metrics(self) -> None: