diff --git a/mlflow_export_import/logged_model/import_logged_model.py b/mlflow_export_import/logged_model/import_logged_model.py index fea04a0..73a7f36 100644 --- a/mlflow_export_import/logged_model/import_logged_model.py +++ b/mlflow_export_import/logged_model/import_logged_model.py @@ -18,6 +18,7 @@ from mlflow_export_import.logged_model.logged_model_utils import update_logged_model_mlmodel_data from mlflow_export_import.logged_model.logged_model_importer import _import_inputs, _log_metrics from mlflow_export_import.common.version_utils import has_logged_model_support +from mlflow_export_import.client.client_utils import create_mlflow_client, create_dbx_client _logger = utils.getLogger(__name__) @@ -48,7 +49,8 @@ def import_logged_model( _logger.info(f"Importing logged model from '{input_dir}'") - exp = mlflow_utils.set_experiment(mlflow_client, None, experiment_name) + dbx_client = create_dbx_client(mlflow_client) + exp = mlflow_utils.set_experiment(mlflow_client, dbx_client, experiment_name) src_logged_model_path = os.path.join(input_dir, "logged_model.json") src_logged_model_dct = io_utils.read_file_mlflow(src_logged_model_path) logged_model = None @@ -127,4 +129,4 @@ def main(input_dir, if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/mlflow_export_import/trace/import_trace.py b/mlflow_export_import/trace/import_trace.py index 21fa226..f13fd85 100644 --- a/mlflow_export_import/trace/import_trace.py +++ b/mlflow_export_import/trace/import_trace.py @@ -21,6 +21,7 @@ ) from mlflow_export_import.trace.trace_utils import _try_parse_json, _get_span_attributes from mlflow_export_import.common.version_utils import has_trace_support +from mlflow_export_import.client.client_utils import create_mlflow_client, create_dbx_client _logger = utils.getLogger(__name__) @@ -43,7 +44,8 @@ def import_trace( mlflow_client = mlflow_client or create_mlflow_client() - exp = mlflow_utils.set_experiment(mlflow_client, None, experiment_name) + dbx_client = create_dbx_client(mlflow_client) + exp = mlflow_utils.set_experiment(mlflow_client, dbx_client, experiment_name) src_trace_path = os.path.join(input_dir, "trace.json") src_trace_dct = io_utils.read_file_mlflow(src_trace_path) @@ -151,4 +153,4 @@ def main(experiment_name, input_dir): ) if __name__ == "__main__": - main() \ No newline at end of file + main()