Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Test files
info.txt
run_tests.log
24 changes: 24 additions & 0 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
image: 696577169457.dkr.ecr.eu-central-1.amazonaws.com/baseimages/python:3.7

stages:
- release

.release: &release
stage: release
script:
- pip3 install -U twine
- python3 setup.py bdist_wheel
- twine upload --repository-url https://nexus.ccl/nexus/repository/$NEXUS_REPOSITORY_NAME/ dist/*

release-prod:
extends: .release
variables:
NEXUS_REPOSITORY_NAME: 'pypi-hosted'
only:
- tags

release-snapshot:
extends: .release
variables:
NEXUS_REPOSITORY_NAME: 'pypi-hosted-snapshots'
when: manual
5 changes: 0 additions & 5 deletions databricks_notebooks/git/Import_Run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@
# MAGIC #### Setup
# MAGIC * See Setup in [README]($./_README).

# COMMAND ----------

dbutils.widgets.removeAll()


# COMMAND ----------

# MAGIC %md ### Setup
Expand Down
2 changes: 2 additions & 0 deletions mlflow_export_import/click_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

model_stages = "Stages to export (comma seperated). Default is all stages. Values are Production, Staging, Archived and None."

model_versions = "Versions to export (comma seperated). Default is all versions. Values are valid integer numbers."

delete_model = "If the model exists, first delete the model and all its versions."

use_threads = "Process export/import in parallel using threads."
11 changes: 6 additions & 5 deletions mlflow_export_import/common/dump_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
INDENT = " "
MAX_LEVEL = 1
TS_FORMAT = "%Y-%m-%d_%H:%M:%S"
client = mlflow.tracking.MlflowClient()
print("MLflow Tracking URI:", mlflow.get_tracking_uri())

def dump_run(run, max_level=1, indent=""):
dump_run_info(run.info,indent)
Expand All @@ -27,17 +25,19 @@ def dump_run(run, max_level=1, indent=""):
num_bytes, num_artifacts = dump_artifacts(run.info.run_id, "", 0, max_level, indent+INDENT)
print(f"{indent}Total: bytes: {num_bytes} artifacts: {num_artifacts}")
return run, num_bytes, num_artifacts

def dump_run_id(run_id, max_level=1, indent=""):
client = mlflow.tracking.MlflowClient()
run = client.get_run(run_id)
return dump_run(run,max_level,indent)

def dump_run_info(info, indent=""):
print("{}RunInfo:".format(indent))
client = mlflow.tracking.MlflowClient()
exp = client.get_experiment(info.experiment_id)
if exp is None:
print(f"ERROR: Cannot find experiment ID '{info.experiment_id}'")
return
return
print("{} name: {}".format(indent,exp.name))
for k,v in sorted(info.__dict__.items()):
if not k.endswith("_time"):
Expand All @@ -58,8 +58,9 @@ def _dump_time(info, k, indent=""):
return v

def dump_artifacts(run_id, path, level, max_level, indent):
if level+1 > max_level:
if level+1 > max_level:
return 0,0
client = mlflow.tracking.MlflowClient()
artifacts = client.list_artifacts(run_id,path)
num_bytes, num_artifacts = (0,0)
for j,art in enumerate(artifacts):
Expand Down
9 changes: 4 additions & 5 deletions mlflow_export_import/common/find_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,15 @@
import click
import mlflow

client = mlflow.tracking.MlflowClient()
print("MLflow Tracking URI:", mlflow.get_tracking_uri())

def find_artifacts(run_id, path, target, max_level=sys.maxsize):
return _find_artifacts(run_id, path, target, max_level, 0, [])

def _find_artifacts(run_id, path, target, max_level, level, matches):
if level+1 > max_level:
if level+1 > max_level:
return matches
artifacts = client.list_artifacts(run_id,path)
client = mlflow.tracking.MlflowClient()
artifacts = client.list_artifacts(run_id, path)
for art in artifacts:
#print(f"art_path: {art.path}")
filename = os.path.basename(art.path)
Expand All @@ -40,5 +39,5 @@ def main(run_id, path, target, max_level): # pragma: no cover
for x in matches:
print(" ",x)

if __name__ == "__main__":
if __name__ == "__main__":
main()
52 changes: 37 additions & 15 deletions mlflow_export_import/model/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,21 @@
from mlflow_export_import import utils, click_doc

class ModelExporter():
def __init__(self, mlflow_client, export_source_tags=False, notebook_formats=None, stages=None, export_run=True):
def __init__(self, mlflow_client, export_source_tags=False, notebook_formats=None, stages=None, versions=None, export_run=True, host=None):
"""
:param mlflow_client: MLflow client or if None create default client.
:param export_source_tags: Export source run metadata tags.
:param notebook_formats: List of notebook formats to export. Values are SOURCE, HTML, JUPYTER or DBC.
:param stages: Stages to export. Default is all stages. Values are Production, Staging, Archived and None.
:param versions: Versions to export. Default is all versions. Values are valid integer numbers.
:param export_run: Export the run that generated a registered model's version.
:param host: Pass host to the MlflowHttpClient and RunExporter.
"""
self.mlflow_client = mlflow_client
self.http_client = MlflowHttpClient()
self.run_exporter = RunExporter(self.mlflow_client, export_source_tags=export_source_tags, notebook_formats=notebook_formats)
self.http_client = MlflowHttpClient(host=host)
self.run_exporter = RunExporter(self.mlflow_client, export_source_tags=export_source_tags, notebook_formats=notebook_formats, host=host)
self.stages = self._normalize_stages(stages)
self.versions = self._normalize_versions(versions)
self.export_run = export_run

def export_model(self, model_name, output_dir):
Expand Down Expand Up @@ -50,6 +53,8 @@ def _export_model(self, model_name, output_dir):
for vr in versions:
if len(self.stages) > 0 and not vr.current_stage.lower() in self.stages:
continue
if len(self.versions) > 0 and vr.version not in self.versions:
continue
run_id = vr.run_id
opath = os.path.join(output_dir,run_id)
opath = opath.replace("dbfs:", "/dbfs")
Expand All @@ -62,7 +67,7 @@ def _export_model(self, model_name, output_dir):
run = self.mlflow_client.get_run(run_id)
dct = dict(vr)
dct["_run_artifact_uri"] = run.info.artifact_uri
experiment = mlflow.get_experiment(run.info.experiment_id)
experiment = self.mlflow_client.get_experiment(run.info.experiment_id)
dct["_experiment_name"] = experiment.name
model["registered_model"]["latest_versions"].append(dct)
exported_versions += 1
Expand All @@ -89,14 +94,26 @@ def _normalize_stages(self, stages):
print(f"WARNING: stage '{stage}' must be one of: {model_version_stages.ALL_STAGES}")
return stages

def _normalize_versions(self, versions):
if versions is None:
return []
if isinstance(versions, str):
versions = versions.split(",")
for version in versions:
try:
int(version)
except ValueError:
print(f"WARNING: version '{version}' must be a valid number")
return versions

@click.command()
@click.option("--model",
help="Registered model name.",
@click.option("--model",
help="Registered model name.",
type=str,
required=True
)
@click.option("--output-dir",
help="Output directory.",
@click.option("--output-dir",
help="Output directory.",
type=str,
required=True
)
Expand All @@ -106,24 +123,29 @@ def _normalize_stages(self, stages):
default=False,
show_default=True
)
@click.option("--notebook-formats",
help=click_doc.notebook_formats,
@click.option("--notebook-formats",
help=click_doc.notebook_formats,
type=str,
default="",
default="",
show_default=True
)
@click.option("--stages",
help=click_doc.model_stages,
@click.option("--stages",
help=click_doc.model_stages,
type=str,
required=False
)
@click.option("--versions",
help=click_doc.model_versions,
type=str,
required=False
)

def main(model, output_dir, export_source_tags, notebook_formats, stages):
def main(model, output_dir, export_source_tags, notebook_formats, stages, versions):
print("Options:")
for k,v in locals().items():
print(f" {k}: {v}")
client = mlflow.tracking.MlflowClient()
exporter = ModelExporter(client, export_source_tags=export_source_tags, notebook_formats=utils.string_to_list(notebook_formats), stages=stages)
exporter = ModelExporter(client, export_source_tags=export_source_tags, notebook_formats=utils.string_to_list(notebook_formats), stages=stages, versions=versions)
exporter.export_model(model, output_dir)

if __name__ == "__main__":
Expand Down
66 changes: 34 additions & 32 deletions mlflow_export_import/model/import_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@

class BaseModelImporter():
""" Base class of ModelImporter subclasses. """
def __init__(self, mlflow_client, run_importer=None, await_creation_for=None):
def __init__(self, mlflow_client, run_importer=None, await_creation_for=None, host=None):
"""
:param mlflow_client: MLflow client or if None create default client.
:param run_importer: RunImporter instance.
:param await_creation_for: Seconds to wait for model version crreation.
:param host: Pass host to the RunImporter.
"""
self.mlflow_client = mlflow_client
self.run_importer = run_importer if run_importer else RunImporter(self.mlflow_client, mlmodel_fix=True)
self.await_creation_for = await_creation_for
self.mlflow_client = mlflow_client
self.run_importer = run_importer if run_importer else RunImporter(self.mlflow_client, mlmodel_fix=True, host=host)
self.await_creation_for = await_creation_for

def _import_version(self, model_name, src_vr, dst_run_id, dst_source, sleep_time):
"""
Expand All @@ -33,10 +34,10 @@ def _import_version(self, model_name, src_vr, dst_run_id, dst_source, sleep_time
"""
src_current_stage = src_vr["current_stage"]
dst_source = dst_source.replace("file://","") # OSS MLflow
if not dst_source.startswith("dbfs:") and not os.path.exists(dst_source):
if not dst_source.startswith("dbfs:") and not dst_source.startswith("s3:") and not os.path.exists(dst_source):
raise MlflowExportImportException(f"'source' argument for MLflowClient.create_model_version does not exist: {dst_source}")
kwargs = {"await_creation_for": self.await_creation_for } if self.await_creation_for else {}
version = self.mlflow_client.create_model_version(model_name, dst_source, dst_run_id, **kwargs)
version = self.mlflow_client.create_model_version(model_name, dst_source, dst_run_id, src_vr["tags"], **kwargs)
model_utils.wait_until_version_is_ready(self.mlflow_client, model_name, version, sleep_time=sleep_time)
if src_current_stage != "None":
self.mlflow_client.transition_model_version_stage(model_name, version.version, src_current_stage)
Expand Down Expand Up @@ -77,27 +78,29 @@ def _import_model(self, model_name, input_dir, delete_model=False, verbose=False

class ModelImporter(BaseModelImporter):
""" Low-level 'point' model importer """
def __init__(self, mlflow_client, run_importer=None, await_creation_for=None):
super().__init__(mlflow_client, run_importer, await_creation_for=await_creation_for)
def __init__(self, mlflow_client, run_importer=None, await_creation_for=None, host=None):
super().__init__(mlflow_client, run_importer, await_creation_for=await_creation_for, host=host)

def import_model(self, model_name, input_dir, experiment_name, delete_model=False, verbose=False, sleep_time=30):
"""
:param model_name: Model name.
:param input_dir: Input directory.
:param experiment_name: The name of the experiment
:param experiment_name: The name of the experiment.
:param delete_model: Delete current model before importing versions.
:param verbose: Verbose.
:param sleep_time: Seconds to wait for model version crreation.
:return: Model import manifest.
"""
model_dct = self._import_model(model_name, input_dir, delete_model, verbose, sleep_time)
mlflow.set_experiment(experiment_name)
print("Importing versions:")
imported_run_ids = []
for vr in model_dct["latest_versions"]:
run_id = self._import_run(input_dir, experiment_name, vr)
imported_run_ids.append(run_id)
self.import_version(model_name, vr, run_id, sleep_time)
if verbose:
model_utils.dump_model_versions(self.mlflow_client, model_name)
return imported_run_ids

def _import_run(self, input_dir, experiment_name, vr):
run_id = vr["run_id"]
Expand Down Expand Up @@ -149,7 +152,6 @@ def import_model(self, model_name, input_dir, delete_model=False, verbose=False,
for vr in model_dct["latest_versions"]:
src_run_id = vr["run_id"]
dst_run_id = self.run_info_map[src_run_id].run_id
mlflow.set_experiment(vr["_experiment_name"])
self.import_version(model_name, vr, dst_run_id, sleep_time)
if verbose:
model_utils.dump_model_versions(self.mlflow_client, model_name)
Expand All @@ -173,46 +175,46 @@ def _path_join(x,y):
""" Account for DOS backslash """
path = os.path.join(x,y)
if path.startswith("dbfs:"):
path = path.replace("\\","/")
path = path.replace("\\","/")
return path

@click.command()
@click.option("--input-dir",
help="Input directory produced by export_model.py.",
@click.option("--input-dir",
help="Input directory produced by export_model.py.",
type=str,
required=True
)
@click.option("--model",
help="New registered model name.",
@click.option("--model",
help="New registered model name.",
type=str,
required=True,
required=True,
)
@click.option("--experiment-name",
help="Destination experiment name - will be created if it does not exist.",
@click.option("--experiment-name",
help="Destination experiment name - will be created if it does not exist.",
type=str,
required=True
)
@click.option("--delete-model",
help=click_doc.delete_model,
@click.option("--delete-model",
help=click_doc.delete_model,
type=bool,
default=False,
default=False,
show_default=True
)
@click.option("--await-creation-for",
help="Await creation for specified seconds.",
type=int,
default=None,
@click.option("--await-creation-for",
help="Await creation for specified seconds.",
type=int,
default=None,
show_default=True
)
@click.option("--sleep-time",
help="Sleep time for polling until version.status==READY.",
@click.option("--sleep-time",
help="Sleep time for polling until version.status==READY.",
type=int,
default=5,
)
@click.option("--verbose",
help="Verbose.",
type=bool,
default=False,
@click.option("--verbose",
help="Verbose.",
type=bool,
default=False,
show_default=True
)

Expand Down
Loading