Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlflow_export_import/common/mlflow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def set_experiment(mlflow_client, dbx_client, exp_name, tags=None):
:return: Experiment ID
"""
from mlflow_export_import.common import utils
if utils.importing_into_databricks():
if utils.get_import_target_implementation() == utils.MLFlowImplementation.DATABRICKS:
create_workspace_dir(dbx_client, os.path.dirname(exp_name))
try:
if not tags: tags = {}
Expand Down
32 changes: 26 additions & 6 deletions mlflow_export_import/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import pandas as pd
from tabulate import tabulate
import mlflow
from enum import Enum, auto

class MLFlowImplementation(Enum):
    """Flavor of MLflow tracking server targeted by an import operation."""

    # Explicit values are exactly what auto() would assign (1, 2, 3),
    # so member identity and .value stay unchanged.
    DATABRICKS = 1
    AZURE_ML = 2
    OSS = 3

# Databricks tags that cannot or should not be set
_DATABRICKS_SKIP_TAGS = set([
Expand All @@ -11,15 +16,26 @@
"mlflow.experiment.sourceType", "mlflow.experiment.sourceId"
])

_AZURE_ML_SKIP_TAGS = set([
"mlflow.user",
"mlflow.source.git.commit"
])


def create_mlflow_tags_for_databricks_import(tags):
    """
    Filter run tags so they are legal for the import target.

    Databricks and Azure ML reserve or reject certain ``mlflow.*`` tags
    (listed in ``_DATABRICKS_SKIP_TAGS`` / ``_AZURE_ML_SKIP_TAGS``), so those
    are dropped. OSS MLflow accepts everything, so the dict is returned as-is.

    :param tags: Dict of tag name -> value from the exported run.
    :return: New dict with disallowed tags removed, or the original dict
             for an OSS target.
    :raises ValueError: If the detected implementation is not handled here.
    """
    environment = get_import_target_implementation()
    if environment == MLFlowImplementation.DATABRICKS:
        return { k:v for k,v in tags.items() if k not in _DATABRICKS_SKIP_TAGS }
    if environment == MLFlowImplementation.AZURE_ML:
        return { k:v for k,v in tags.items() if k not in _AZURE_ML_SKIP_TAGS }
    if environment == MLFlowImplementation.OSS:
        return tags
    # ValueError is more precise than bare Exception and is still caught by
    # any caller handling Exception, so this stays backward compatible.
    raise ValueError(f"Unsupported MLflow implementation: {environment}")


def set_dst_user_id(tags, user_id, use_src_user_id):
if importing_into_databricks():
if get_import_target_implementation() in (MLFlowImplementation.DATABRICKS,
MLFlowImplementation.AZURE_ML):
return
from mlflow.entities import RunTag
from mlflow.utils.mlflow_tags import MLFLOW_USER
Expand Down Expand Up @@ -59,8 +75,12 @@ def nested_tags(dst_client, run_ids_mapping):
dst_client.set_tag(dst_run_id, "mlflow.parentRunId", dst_parent_run_id)


def importing_into_databricks():
return mlflow.tracking.get_tracking_uri().startswith("databricks")
def get_import_target_implementation() -> MLFlowImplementation:
    """
    Detect which MLflow implementation the current tracking URI points at.

    URI schemes starting with "databricks" map to Databricks, "azureml" to
    Azure ML; anything else is treated as open-source MLflow.

    :return: The ``MLFlowImplementation`` member for the import target.
    """
    # Fetch the tracking URI once instead of once per startswith() check.
    tracking_uri = mlflow.tracking.get_tracking_uri()
    if tracking_uri.startswith("databricks"):
        return MLFlowImplementation.DATABRICKS
    if tracking_uri.startswith("azureml"):
        return MLFlowImplementation.AZURE_ML
    return MLFlowImplementation.OSS


def show_table(title, lst, columns):
Expand Down
4 changes: 2 additions & 2 deletions mlflow_export_import/run/import_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self,
self.dbx_client = DatabricksHttpClient()
self.import_source_tags = import_source_tags
print(f"in_databricks: {self.in_databricks}")
print(f"importing_into_databricks: {utils.importing_into_databricks()}")
print(f"importing_into_environment: {utils.get_import_target_implementation().name}")


def import_run(self, exp_name, input_dir, dst_notebook_dir=None):
Expand Down Expand Up @@ -93,7 +93,7 @@ def _import_run(self, dst_exp_name, input_dir, dst_notebook_dir):
import traceback
traceback.print_exc()
raise MlflowExportImportException(e, f"Importing run {run_id} of experiment '{exp.name}' failed")
if utils.importing_into_databricks() and dst_notebook_dir:
if utils.get_import_target_implementation() == utils.MLFlowImplementation.DATABRICKS and dst_notebook_dir:
ndir = os.path.join(dst_notebook_dir, run_id) if self.dst_notebook_dir_add_run_id else dst_notebook_dir
self._upload_databricks_notebook(input_dir, src_run_dct, ndir)

Expand Down
2 changes: 1 addition & 1 deletion tests/compare_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def compare_versions(mlflow_client_src, mlflow_client_dst, vr_src, vr_dst, outpu
assert vr_src.status_message == vr_dst.status_message
if mlflow_client_src != mlflow_client_src:
assert vr_src.name == vr_dst.name
if not utils.importing_into_databricks():
if utils.get_import_target_implementation() != utils.MLFlowImplementation.DATABRICKS:
assert vr_src.user_id == vr_dst.user_id

tags_dst = { k:v for k,v in vr_dst.tags.items() if not k.startswith(ExportTags.PREFIX_ROOT) }
Expand Down