-
Notifications
You must be signed in to change notification settings - Fork 15
MLFlow Offline Failover #114
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
|
||
from __future__ import annotations | ||
|
||
import functools | ||
import io | ||
import logging | ||
import os | ||
|
@@ -23,7 +24,10 @@ | |
from typing import Literal | ||
from weakref import WeakValueDictionary | ||
|
||
import mlflow | ||
from mlflow import MlflowClient | ||
from packaging.version import Version | ||
from pytorch_lightning.loggers.mlflow import LOCAL_FILE_URI_PREFIX | ||
from pytorch_lightning.loggers.mlflow import MLFlowLogger | ||
from pytorch_lightning.loggers.mlflow import _convert_params | ||
from pytorch_lightning.loggers.mlflow import _flatten_dict | ||
|
@@ -36,7 +40,6 @@ | |
if TYPE_CHECKING: | ||
from argparse import Namespace | ||
|
||
import mlflow | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
|
@@ -244,6 +247,8 @@ def _remove_csi(line: bytes) -> bytes: | |
class AnemoiMLflowLogger(MLFlowLogger): | ||
"""A custom MLflow logger that logs terminal output.""" | ||
|
||
_failed_to_offline = False | ||
|
||
def __init__( | ||
self, | ||
experiment_name: str = "lightning_logs", | ||
|
@@ -323,6 +328,7 @@ def __init__( | |
tracking_uri=tracking_uri, | ||
on_resume_create_child=on_resume_create_child, | ||
) | ||
self._save_dir = f"{LOCAL_FILE_URI_PREFIX}{save_dir}" | ||
# Before creating the run we need to overwrite the tracking_uri and save_dir if offline | ||
if offline: | ||
# OFFLINE - When we run offline we can pass a save_dir pointing to a local path | ||
|
@@ -426,7 +432,71 @@ def _get_mlflow_run_params( | |
def experiment(self) -> MLFlowLogger.experiment: | ||
if rank_zero_only.rank == 0: | ||
self.auth.authenticate() | ||
return super().experiment | ||
|
||
parent_obj = super() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just probably me not being familiar with this, could you explain what does |
||
logger_obj = self | ||
experiment = super().experiment | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the 'experiment' method does not return an experiment object rather the mlflow client https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/loggers/mlflow.html#MLFlowLogger so what's the idea behind that |
||
|
||
if self._failed_to_offline: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens if the use is already running in offline mode? That's needed in HPCs like Leonardo and MN5 |
||
self.return_to_online() | ||
|
||
class WrappedExperiment: | ||
"""Wrap the experiment object to handle connection errors. | ||
|
||
Fails over to offline logging if an error occurs. | ||
""" | ||
|
||
def __getattr__(self, key: str) -> Any: | ||
# Only wrap logging calls | ||
if not hasattr(experiment, key) or not key.startswith("log"): | ||
return super().__getattr__(key) | ||
if not callable(getattr(experiment, key)): | ||
return getattr(experiment, key) | ||
|
||
@functools.wraps(getattr(experiment, key)) | ||
def wrapped_method(*args, **kwargs) -> Any: | ||
try: | ||
return getattr(experiment, key)(*args, **kwargs) | ||
except mlflow.MlflowException as e: | ||
LOGGER.warning("An error occurred when calling %s: %s", key, e) | ||
logger_obj.failover_to_offline() | ||
global experiment | ||
experiment = parent_obj.experiment | ||
return getattr(experiment, key)(*args, **kwargs) | ||
|
||
return wrapped_method | ||
|
||
return WrappedExperiment | ||
|
||
def failover_to_offline(self) -> None: | ||
"""Failover to offline mode.""" | ||
self._initialized = False | ||
self._mlflow_client = MlflowClient(tracking_uri=self._save_dir) | ||
self._failed_to_offline = True | ||
|
||
def return_to_online(self) -> None: | ||
"""Return to online mode.""" | ||
try: | ||
health_check(self._tracking_uri) | ||
except ConnectionError: | ||
return | ||
|
||
self._initialized = False | ||
self._mlflow_client = MlflowClient(tracking_uri=self._tracking_uri) | ||
self._failed_to_offline = False | ||
|
||
from anemoi.training.utils.mlflow_sync import MlFlowSync | ||
|
||
try: | ||
MlFlowSync( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This would try to sync the offline part of the run before logging again online? We'd probably need to use the |
||
self._save_dir, | ||
self._tracking_uri, | ||
self._run_id, | ||
self._experiment_name, | ||
).sync() | ||
except Exception as e: # noqa: BLE001 | ||
LOGGER.warning("Failed to sync to online server: %s", e) | ||
self.failover_to_offline() | ||
|
||
@rank_zero_only | ||
def log_system_metrics(self) -> None: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the
LOCAL_FILE_URI_PREFIX
doing here?