Skip to content
This repository was archived by the owner on Dec 20, 2024. It is now read-only.

MLFlow Offline Failover #114

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 72 additions & 2 deletions src/anemoi/training/diagnostics/mlflow/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from __future__ import annotations

import functools
import io
import logging
import os
Expand All @@ -23,7 +24,10 @@
from typing import Literal
from weakref import WeakValueDictionary

import mlflow
from mlflow import MlflowClient
from packaging.version import Version
from pytorch_lightning.loggers.mlflow import LOCAL_FILE_URI_PREFIX
from pytorch_lightning.loggers.mlflow import MLFlowLogger
from pytorch_lightning.loggers.mlflow import _convert_params
from pytorch_lightning.loggers.mlflow import _flatten_dict
Expand All @@ -36,7 +40,6 @@
if TYPE_CHECKING:
from argparse import Namespace

import mlflow

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -244,6 +247,8 @@ def _remove_csi(line: bytes) -> bytes:
class AnemoiMLflowLogger(MLFlowLogger):
"""A custom MLflow logger that logs terminal output."""

_failed_to_offline = False

def __init__(
self,
experiment_name: str = "lightning_logs",
Expand Down Expand Up @@ -323,6 +328,7 @@ def __init__(
tracking_uri=tracking_uri,
on_resume_create_child=on_resume_create_child,
)
self._save_dir = f"{LOCAL_FILE_URI_PREFIX}{save_dir}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the LOCAL_FILE_URI_PREFIX doing here?

# Before creating the run we need to overwrite the tracking_uri and save_dir if offline
if offline:
# OFFLINE - When we run offline we can pass a save_dir pointing to a local path
Expand Down Expand Up @@ -426,7 +432,71 @@ def _get_mlflow_run_params(
def experiment(self) -> MLFlowLogger.experiment:
    """Return the MLflow client, transparently failing over to offline logging.

    Every ``log*`` method on the returned object is wrapped so that an
    ``mlflow.MlflowException`` (e.g. the tracking server became unreachable)
    switches this logger to the local file store and retries the call there.
    If a previous call already failed over, an attempt is first made to
    return to the online server (see ``return_to_online``).
    """
    if rank_zero_only.rank == 0:
        self.auth.authenticate()

    # Bound super proxy: lets the closure below re-resolve the parent's
    # experiment after an online/offline switch.
    parent_obj = super()
    logger_obj = self

    if self._failed_to_offline:
        # Try to reconnect (and sync the offline portion of the run)
        # before handing out a client.
        self.return_to_online()

    # Resolve the underlying client *after* any online/offline switch above.
    experiment = super().experiment

    class WrappedExperiment:
        """Proxy that fails over to offline logging when a log call errors."""

        def __getattr__(self, key: str) -> Any:
            # Delegate anything that is not a logging method unchanged.
            # NOTE(review): the draft returned super().__getattr__(key) here,
            # which always raises — object defines no __getattr__.
            if not hasattr(experiment, key) or not key.startswith("log"):
                return getattr(experiment, key)
            attr = getattr(experiment, key)
            if not callable(attr):
                return attr

            @functools.wraps(attr)
            def wrapped_method(*args, **kwargs) -> Any:
                nonlocal experiment  # draft used `global`, which names a nonexistent module global
                try:
                    return getattr(experiment, key)(*args, **kwargs)
                except mlflow.MlflowException as e:
                    LOGGER.warning("An error occurred when calling %s: %s", key, e)
                    logger_obj.failover_to_offline()
                    # Re-resolve the client, now pointing at the local store,
                    # and retry the failed call once.
                    experiment = parent_obj.experiment
                    return getattr(experiment, key)(*args, **kwargs)

            return wrapped_method

    # Return an *instance* — the draft returned the class itself, whose
    # attribute access bypasses instance __getattr__, so the wrapping
    # never triggered.
    return WrappedExperiment()

def failover_to_offline(self) -> None:
    """Failover to offline mode.

    Points the MLflow client at the local file store (``self._save_dir``)
    and flags the logger so a later access can attempt to return online.
    """
    # Force the parent logger to re-create its experiment/run on next access.
    self._initialized = False
    self._mlflow_client = MlflowClient(tracking_uri=self._save_dir)
    self._failed_to_offline = True

def return_to_online(self) -> None:
    """Return to online mode if the tracking server is reachable again.

    Checks server health; if the server is still unreachable the logger
    stays offline.  Otherwise the client is pointed back at the remote
    tracking URI and the locally logged portion of the run is synced up.
    If syncing fails, the logger falls back to offline mode again.
    """
    try:
        health_check(self._tracking_uri)
    except ConnectionError:
        # Server still unreachable: remain in offline mode.
        return

    # Force the parent logger to re-create its experiment/run on next access.
    self._initialized = False
    self._mlflow_client = MlflowClient(tracking_uri=self._tracking_uri)
    self._failed_to_offline = False

    # Function-scope import — presumably to avoid a circular import at
    # module load time; confirm against the package layout.
    from anemoi.training.utils.mlflow_sync import MlFlowSync

    try:
        # Push what was logged offline up to the server before resuming
        # online logging.
        MlFlowSync(
            self._save_dir,
            self._tracking_uri,
            self._run_id,
            self._experiment_name,
        ).sync()
    except Exception as e:  # noqa: BLE001  best-effort: any sync failure sends us back offline
        LOGGER.warning("Failed to sync to online server: %s", e)
        self.failover_to_offline()

@rank_zero_only
def log_system_metrics(self) -> None:
Expand Down
Loading