From 9f6eb31380b81220bced5b6a0c6bc024e1644bed Mon Sep 17 00:00:00 2001 From: Danilo Peixoto Date: Thu, 22 Dec 2022 07:22:01 -0300 Subject: [PATCH 1/4] Get registered model from MLflow client --- mlflow_export_import/model/export_model.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mlflow_export_import/model/export_model.py b/mlflow_export_import/model/export_model.py index 001a5989..c8eeeb91 100644 --- a/mlflow_export_import/model/export_model.py +++ b/mlflow_export_import/model/export_model.py @@ -2,12 +2,13 @@ Export a registered model and all the experiment runs associated with each version. """ +import json import os import click import mlflow +from mlflow.utils.proto_json_utils import message_to_json from mlflow_export_import.common import MlflowExportImportException -from mlflow_export_import.common.http_client import MlflowHttpClient from mlflow_export_import.common import filesystem as _filesystem from mlflow_export_import.run.export_run import RunExporter from mlflow_export_import import utils, click_doc @@ -23,7 +24,6 @@ def __init__(self, mlflow_client, export_source_tags=False, notebook_formats=No :param export_run: Export the run that generated a registered model's version. """ self.mlflow_client = mlflow_client - self.http_client = MlflowHttpClient() self.run_exporter = RunExporter(self.mlflow_client, export_source_tags=export_source_tags, notebook_formats=notebook_formats) self.stages = self._normalize_stages(stages) self.export_run = export_run @@ -85,7 +85,10 @@ def _export_model(self, model_name, output_dir): traceback.print_exc() output_versions.sort(key=lambda x: x["version"], reverse=False) - model = self.http_client.get(f"registered-models/get", {"name": model_name}) + model_obj = self.mlflow_client.get_registered_model(model_name) + model_proto = model_obj.to_proto() + model = json.loads(message_to_json(model_proto)) + export_info = { "export_info": { **utils.create_export_info(), **{ "num_target_stages": len(self.stages), @@ -95,7 +98,7 @@ def _export_model(self, model_name, output_dir): } } } - model = { **export_info, **model } + model = {'registered_model': model, **export_info } model["registered_model"]["latest_versions"] = output_versions print(f"Exported {exported_versions}/{len(output_versions)} versions for model '{model_name}'") From 63bb42d8f84df6f5487517140743bd06a6ebd3ff Mon Sep 17 00:00:00 2001 From: Danilo Peixoto Date: Thu, 22 Dec 2022 10:50:21 -0300 Subject: [PATCH 2/4] Make generic import --- mlflow_export_import/model/import_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mlflow_export_import/model/import_model.py b/mlflow_export_import/model/import_model.py index 1e536b14..db71133d 100644 --- a/mlflow_export_import/model/import_model.py +++ b/mlflow_export_import/model/import_model.py @@ -3,6 +3,7 @@ """ import os +from urllib.parse import urlparse import click import mlflow @@ -45,8 +46,8 @@ def _import_version(self, model_name, src_vr, dst_run_id, dst_source, sleep_time :param sleep_time: Seconds to wait for model version crreation. """ src_current_stage = src_vr["current_stage"] - dst_source = dst_source.replace("file://","") # OSS MLflow - if not dst_source.startswith("dbfs:") and not os.path.exists(dst_source): + parsed_dst_source = urlparse(dst_source) + if parsed_dst_source.scheme == "file" and not os.path.exists(dst_source): raise MlflowExportImportException(f"'source' argument for MLflowClient.create_model_version does not exist: {dst_source}") kwargs = {"await_creation_for": self.await_creation_for } if self.await_creation_for else {} tags = src_vr["tags"] From 59ee0f6c0fb70c3d992a98301a8a951f728c0215 Mon Sep 17 00:00:00 2001 From: Danilo Peixoto Date: Fri, 23 Dec 2022 10:32:38 -0300 Subject: [PATCH 3/4] Make thread safe --- mlflow_export_import/model/import_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlflow_export_import/model/import_model.py b/mlflow_export_import/model/import_model.py index db71133d..e0a5ad05 100644 --- a/mlflow_export_import/model/import_model.py +++ b/mlflow_export_import/model/import_model.py @@ -190,7 +190,7 @@ def import_model(self, model_name, input_dir, delete_model=False, verbose=False, for vr in model_dct["latest_versions"]: src_run_id = vr["run_id"] dst_run_id = self.run_info_map[src_run_id].run_id - mlflow.set_experiment(vr["_experiment_name"]) + # mlflow.set_experiment(vr["_experiment_name"]) self.import_version(model_name, vr, dst_run_id, sleep_time) if verbose: model_utils.dump_model_versions(self.mlflow_client, model_name) From d185ae4d309d95ee3a559871608ac823e81a39e0 Mon Sep 17 00:00:00 2001 From: Danilo Peixoto Date: Fri, 23 Dec 2022 10:53:37 -0300 Subject: [PATCH 4/4] Update import_model.py --- mlflow_export_import/model/import_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlflow_export_import/model/import_model.py b/mlflow_export_import/model/import_model.py index e0a5ad05..a057a8c1 100644 --- a/mlflow_export_import/model/import_model.py +++ b/mlflow_export_import/model/import_model.py @@ -190,7 +190,7 @@ def import_model(self, model_name, input_dir, delete_model=False, verbose=False, for vr in model_dct["latest_versions"]: src_run_id = vr["run_id"] dst_run_id = self.run_info_map[src_run_id].run_id - # mlflow.set_experiment(vr["_experiment_name"]) + # mlflow.set_experiment(vr["_experiment_name"]) Is it thread-safe? self.import_version(model_name, vr, dst_run_id, sleep_time) if verbose: model_utils.dump_model_versions(self.mlflow_client, model_name)