diff --git a/mlflow_export_import/common/mlflow_utils.py b/mlflow_export_import/common/mlflow_utils.py index 44f82e90..95485ae2 100644 --- a/mlflow_export_import/common/mlflow_utils.py +++ b/mlflow_export_import/common/mlflow_utils.py @@ -43,7 +43,7 @@ def set_experiment(mlflow_client, dbx_client, exp_name, tags=None): :return: Experiment ID """ from mlflow_export_import.common import utils - if utils.importing_into_databricks(): + if utils.get_import_target_implementation() == utils.MLFlowImplementation.DATABRICKS: create_workspace_dir(dbx_client, os.path.dirname(exp_name)) try: if not tags: tags = {} diff --git a/mlflow_export_import/common/utils.py b/mlflow_export_import/common/utils.py index aa40395c..5c525120 100644 --- a/mlflow_export_import/common/utils.py +++ b/mlflow_export_import/common/utils.py @@ -1,7 +1,12 @@ import pandas as pd from tabulate import tabulate import mlflow +from enum import Enum, auto +class MLFlowImplementation(Enum): + DATABRICKS = auto() + AZURE_ML = auto() + OSS = auto() # Databricks tags that cannot or should not be set _DATABRICKS_SKIP_TAGS = set([ @@ -11,15 +16,26 @@ "mlflow.experiment.sourceType", "mlflow.experiment.sourceId" ]) +_AZURE_ML_SKIP_TAGS = set([ + "mlflow.user", + "mlflow.source.git.commit" + ]) + def create_mlflow_tags_for_databricks_import(tags): - if importing_into_databricks(): - tags = { k:v for k,v in tags.items() if not k in _DATABRICKS_SKIP_TAGS } - return tags + environment = get_import_target_implementation() + if environment == MLFlowImplementation.DATABRICKS: + return { k:v for k,v in tags.items() if k not in _DATABRICKS_SKIP_TAGS } + if environment == MLFlowImplementation.AZURE_ML: + return { k:v for k,v in tags.items() if k not in _AZURE_ML_SKIP_TAGS } + if environment == MLFlowImplementation.OSS: + return tags + raise Exception(f"Unsupported environment: {environment}") def set_dst_user_id(tags, user_id, use_src_user_id): - if importing_into_databricks(): + if get_import_target_implementation() in 
(MLFlowImplementation.DATABRICKS, + MLFlowImplementation.AZURE_ML): return from mlflow.entities import RunTag from mlflow.utils.mlflow_tags import MLFLOW_USER @@ -59,8 +75,12 @@ def nested_tags(dst_client, run_ids_mapping): dst_client.set_tag(dst_run_id, "mlflow.parentRunId", dst_parent_run_id) -def importing_into_databricks(): - return mlflow.tracking.get_tracking_uri().startswith("databricks") +def get_import_target_implementation() -> MLFlowImplementation: + if mlflow.tracking.get_tracking_uri().startswith("databricks"): + return MLFlowImplementation.DATABRICKS + if mlflow.tracking.get_tracking_uri().startswith("azureml"): + return MLFlowImplementation.AZURE_ML + return MLFlowImplementation.OSS def show_table(title, lst, columns): diff --git a/mlflow_export_import/run/import_run.py b/mlflow_export_import/run/import_run.py index bb9f12ab..93d57e49 100644 --- a/mlflow_export_import/run/import_run.py +++ b/mlflow_export_import/run/import_run.py @@ -54,7 +54,7 @@ def __init__(self, self.dbx_client = DatabricksHttpClient() self.import_source_tags = import_source_tags print(f"in_databricks: {self.in_databricks}") - print(f"importing_into_databricks: {utils.importing_into_databricks()}") + print(f"importing_into_environment: {utils.get_import_target_implementation().name}") def import_run(self, exp_name, input_dir, dst_notebook_dir=None): @@ -93,7 +93,7 @@ def _import_run(self, dst_exp_name, input_dir, dst_notebook_dir): import traceback traceback.print_exc() raise MlflowExportImportException(e, f"Importing run {run_id} of experiment '{exp.name}' failed") - if utils.importing_into_databricks() and dst_notebook_dir: + if utils.get_import_target_implementation() == utils.MLFlowImplementation.DATABRICKS and dst_notebook_dir: ndir = os.path.join(dst_notebook_dir, run_id) if self.dst_notebook_dir_add_run_id else dst_notebook_dir self._upload_databricks_notebook(input_dir, src_run_dct, ndir) diff --git a/tests/compare_utils.py b/tests/compare_utils.py index 
720a31c4..b2a6d2c9 100644 --- a/tests/compare_utils.py +++ b/tests/compare_utils.py @@ -93,7 +93,7 @@ def compare_versions(mlflow_client_src, mlflow_client_dst, vr_src, vr_dst, outpu assert vr_src.status_message == vr_dst.status_message - if mlflow_client_src != mlflow_client_src: + if mlflow_client_src != mlflow_client_dst: assert vr_src.name == vr_dst.name - if not utils.importing_into_databricks(): + if utils.get_import_target_implementation() != utils.MLFlowImplementation.DATABRICKS: assert vr_src.user_id == vr_dst.user_id tags_dst = { k:v for k,v in vr_dst.tags.items() if not k.startswith(ExportTags.PREFIX_ROOT) }