Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions mlflow_export_import/logged_model/import_logged_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from mlflow_export_import.logged_model.logged_model_utils import update_logged_model_mlmodel_data
from mlflow_export_import.logged_model.logged_model_importer import _import_inputs, _log_metrics
from mlflow_export_import.common.version_utils import has_logged_model_support
from mlflow_export_import.client.client_utils import create_mlflow_client, create_dbx_client

_logger = utils.getLogger(__name__)

Expand Down Expand Up @@ -48,7 +49,8 @@ def import_logged_model(

_logger.info(f"Importing logged model from '{input_dir}'")

exp = mlflow_utils.set_experiment(mlflow_client, None, experiment_name)
dbx_client = create_dbx_client(mlflow_client)
exp = mlflow_utils.set_experiment(mlflow_client, dbx_client, experiment_name)
src_logged_model_path = os.path.join(input_dir, "logged_model.json")
src_logged_model_dct = io_utils.read_file_mlflow(src_logged_model_path)
logged_model = None
Expand Down Expand Up @@ -127,4 +129,4 @@ def main(input_dir,


if __name__ == "__main__":
main()
main()
6 changes: 4 additions & 2 deletions mlflow_export_import/trace/import_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from mlflow_export_import.trace.trace_utils import _try_parse_json, _get_span_attributes
from mlflow_export_import.common.version_utils import has_trace_support
from mlflow_export_import.client.client_utils import create_mlflow_client, create_dbx_client

_logger = utils.getLogger(__name__)

Expand All @@ -43,7 +44,8 @@ def import_trace(

mlflow_client = mlflow_client or create_mlflow_client()

exp = mlflow_utils.set_experiment(mlflow_client, None, experiment_name)
dbx_client = create_dbx_client(mlflow_client)
exp = mlflow_utils.set_experiment(mlflow_client, dbx_client, experiment_name)
src_trace_path = os.path.join(input_dir, "trace.json")
src_trace_dct = io_utils.read_file_mlflow(src_trace_path)

Expand Down Expand Up @@ -151,4 +153,4 @@ def main(experiment_name, input_dir):
)

if __name__ == "__main__":
main()
main()