From ed3995b0dd27b74937288b53ef77df91d00c288d Mon Sep 17 00:00:00 2001
From: Vishakh Pillai
Date: Fri, 17 Oct 2025 17:35:37 -0400
Subject: [PATCH] Mitigate TEMPORARILY_UNAVAILABLE MLflow API errors by logging metrics in a single batch with retry

---
 src/edvise/modeling/bias_detection.py | 95 ++++++++++++++++++++++++---
 1 file changed, 87 insertions(+), 8 deletions(-)

diff --git a/src/edvise/modeling/bias_detection.py b/src/edvise/modeling/bias_detection.py
index ecab33b3..47e78640 100644
--- a/src/edvise/modeling/bias_detection.py
+++ b/src/edvise/modeling/bias_detection.py
@@ -1,9 +1,11 @@
 import logging
+import time
+import random
+
 import typing as t
 
 import matplotlib.figure
 import matplotlib.pyplot as plt
-import mlflow
 import numpy as np
 from collections import Counter
 import pandas as pd
@@ -11,6 +13,10 @@
 import seaborn as sns
 import sklearn.metrics
 
+import mlflow
+from mlflow.entities import Metric
+from mlflow.exceptions import RestException
+
 from . import evaluation
 
 LOGGER = logging.getLogger(__name__)
@@ -627,18 +633,33 @@ def log_subgroup_metrics_to_mlflow(
     group_col: str,
 ) -> None:
     """
-    Logs individual subgroup-level metrics to MLflow.
+    Logs subgroup-level bias and performance metrics to MLflow in a single batch.
+
+    This function aggregates subgroup metrics into a single payload and logs them
+    to MLflow using the `_log_metrics_batch_with_retry()` helper, which batches
+    metrics into one request and automatically retries transient failures. This
+    reduces API call volume and mitigates 'TEMPORARILY_UNAVAILABLE' errors caused
+    by rapid per-metric logging.
 
     Args:
-        subgroup_metrics: Dictionary of subgroup bias metrics.
-        split_name: Name of the data split (e.g., "train", "test", "validation").
-        group_col: Column name representing the group for bias evaluation.
+        subgroup_metrics (dict): Dictionary of subgroup bias or performance metrics.
+        split_name (str): Name of the data split (e.g., "train", "test", "validation").
+        group_col (str): Column name representing the demographic or grouping variable
+            used for bias evaluation.
     """
+    payload = {}
     for metric, value in subgroup_metrics.items():
         if metric not in {"Subgroup", "Number of Samples"}:
-            mlflow.log_metric(
-                f"{split_name}_{group_col}_metrics/{metric}_subgroup", value
-            )
+            key = f"{split_name}_{group_col}_metrics/{metric}_subgroup"
+            payload[key] = value
+
+    active_run = mlflow.active_run()
+    run_id = active_run.info.run_id if active_run else None
+    if run_id is None:
+        with mlflow.start_run(nested=True) as r:
+            _log_metrics_batch_with_retry(r.info.run_id, payload)
+    else:
+        _log_metrics_batch_with_retry(run_id, payload)
 
 
 def plot_fnr_group(fnr_data: list) -> matplotlib.figure.Figure:
@@ -692,3 +713,61 @@
     plt.tight_layout()
 
     return fig
+
+
+def _log_metrics_batch_with_retry(
+    run_id, metrics_dict, step=0, max_tries=6, base_delay=0.5
+):
+    """
+    Log multiple MLflow metrics in a single batch with retry and backoff.
+
+    This function batches metrics into one REST call using MlflowClient.log_batch()
+    to reduce API call volume and avoid transient 'TEMPORARILY_UNAVAILABLE' errors
+    from the MLflow tracking server. It automatically retries failed requests using
+    exponential backoff with jitter and, on repeated failure, falls back to logging
+    the metrics as a CSV artifact to preserve data.
+
+    Args:
+        run_id (str): Active MLflow run ID.
+        metrics_dict (dict): Mapping of metric names to numeric values.
+        step (int, optional): Metric step value. Defaults to 0.
+        max_tries (int, optional): Maximum number of retry attempts. Defaults to 6.
+        base_delay (float, optional): Base delay (in seconds) for exponential backoff. Defaults to 0.5.
+
+    Raises:
+        mlflow.exceptions.RestException: If all retry attempts fail.
+    """
+    client = mlflow.tracking.MlflowClient()
+    ts = int(time.time() * 1000)
+    batch = [
+        Metric(key=k, value=float(v), timestamp=ts, step=step)
+        for k, v in metrics_dict.items()
+    ]
+    last_err = None
+    for attempt in range(1, max_tries + 1):
+        try:
+            client.log_batch(run_id, metrics=batch)
+            return
+        except RestException as e:
+            last_err = e
+            # Retry only on transient cases
+            if (
+                "TEMPORARILY_UNAVAILABLE" in str(e)
+                or "rate limit" in str(e).lower()
+                or "temporarily unavailable" in str(e).lower()
+            ):
+                sleep_s = base_delay * (2 ** (attempt - 1)) + random.uniform(0, 0.2)
+                time.sleep(min(sleep_s, 8.0))
+                continue
+            raise
+    # If we get here, retries failed — fall back to artifact so the run isn’t lost
+    try:
+        import io
+
+        buf = io.StringIO()
+        for k, v in metrics_dict.items():
+            buf.write(f"{ts},{step},{k},{v}\n")
+        mlflow.log_text(buf.getvalue(), "fallback_metrics.csv")
+    except Exception:
+        pass
+    raise last_err
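
Usage sketch: a minimal example of how the batched logging path above is exercised, assuming the module is importable as `edvise.modeling.bias_detection`, that an MLflow tracking backend is configured, and with illustrative placeholder metric names and values (not data from the repository).

    import mlflow

    from edvise.modeling import bias_detection

    # Hypothetical subgroup metrics: "Subgroup" and "Number of Samples" are
    # skipped by log_subgroup_metrics_to_mlflow; the remaining keys are sent
    # in one log_batch() call and retried on TEMPORARILY_UNAVAILABLE errors.
    subgroup_metrics = {
        "Subgroup": "first_gen",
        "Number of Samples": 120,
        "FNR": 0.18,
        "Accuracy": 0.83,
    }

    with mlflow.start_run():
        bias_detection.log_subgroup_metrics_to_mlflow(
            subgroup_metrics,
            split_name="test",
            group_col="first_gen",
        )

Because a run is already active here, the helper logs against that run's ID; if no run were active, log_subgroup_metrics_to_mlflow would start one itself before delegating to _log_metrics_batch_with_retry.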