From 839a0777ea175f5bb4120be7360f5da4ad440143 Mon Sep 17 00:00:00 2001 From: Matt Sharp Date: Thu, 14 Nov 2024 10:41:09 -0700 Subject: [PATCH] Importing into a Sagemaker MLFlow instance --- .gitignore | 1 + mlflow_export_import/bulk/import_models.py | 11 +++++++++++ mlflow_export_import/client/mlflow_auth_utils.py | 5 +---- mlflow_export_import/common/utils.py | 3 +++ .../experiment/import_experiment.py | 2 +- .../model_version/import_model_version.py | 2 +- mlflow_export_import/run/import_run.py | 14 ++++++++------ 7 files changed, 26 insertions(+), 12 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..c18dd8d8 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__/ diff --git a/mlflow_export_import/bulk/import_models.py b/mlflow_export_import/bulk/import_models.py index 0859ddc9..a1c0305a 100644 --- a/mlflow_export_import/bulk/import_models.py +++ b/mlflow_export_import/bulk/import_models.py @@ -2,6 +2,7 @@ Imports models and their experiments and runs. """ +import re import os import time import json @@ -152,6 +153,7 @@ def _import_models(mlflow_client, for model_name in model_names: dir = os.path.join(models_dir, model_name) model_name = rename_utils.rename(model_name, model_renames, "model") + validate_model_name(model_name) executor.submit(all_importer.import_model, model_name = model_name, input_dir = dir, @@ -163,6 +165,15 @@ def _import_models(mlflow_client, return { "models": len(model_names), "duration": duration } +def validate_model_name(model_name): + pattern = r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,56}$" + if not re.match(pattern, model_name): + _logger.warning( + f"The model name does not match the pattern {pattern} and may not import " + "correctly for some environments." 
+ ) + + @click.command() @opt_input_dir @opt_delete_model diff --git a/mlflow_export_import/client/mlflow_auth_utils.py b/mlflow_export_import/client/mlflow_auth_utils.py index 18e5b198..5f4d1dff 100644 --- a/mlflow_export_import/client/mlflow_auth_utils.py +++ b/mlflow_export_import/client/mlflow_auth_utils.py @@ -20,10 +20,7 @@ def get_mlflow_host_token(): uri = mlflow.tracking.get_tracking_uri() if uri: if not uri.startswith("databricks"): - if not uri.startswith("http"): - _raise_exception(uri) - else: - return (uri, None) + return (uri, None) else: _raise_exception(uri) diff --git a/mlflow_export_import/common/utils.py b/mlflow_export_import/common/utils.py index aa410449..78879bdd 100644 --- a/mlflow_export_import/common/utils.py +++ b/mlflow_export_import/common/utils.py @@ -18,6 +18,7 @@ def calling_databricks(dbx_client=None): """ from mlflow_export_import.client.http_client import DatabricksHttpClient from mlflow_export_import.common import MlflowExportImportException + from requests.exceptions import RequestException global _calling_databricks if _calling_databricks is None: @@ -27,6 +28,8 @@ def calling_databricks(dbx_client=None): _calling_databricks = True except MlflowExportImportException: _calling_databricks = False + except RequestException: + _calling_databricks = False _logger.info(f"Calling Databricks: {_calling_databricks}") return _calling_databricks diff --git a/mlflow_export_import/experiment/import_experiment.py b/mlflow_export_import/experiment/import_experiment.py index 72a14c5a..dbe4c341 100644 --- a/mlflow_export_import/experiment/import_experiment.py +++ b/mlflow_export_import/experiment/import_experiment.py @@ -1,5 +1,5 @@ """ -Exports an experiment to a directory. +Imports an experiment to a directory. 
""" import os diff --git a/mlflow_export_import/model_version/import_model_version.py b/mlflow_export_import/model_version/import_model_version.py index 6d720a53..34e5fe6e 100644 --- a/mlflow_export_import/model_version/import_model_version.py +++ b/mlflow_export_import/model_version/import_model_version.py @@ -108,7 +108,7 @@ def _import_model_version( ): start_time = time.time() dst_source = dst_source.replace("file://","") # OSS MLflow - if not dst_source.startswith("dbfs:") and not os.path.exists(dst_source): + if not (dst_source.startswith("dbfs:") or dst_source.startswith("s3:")) and not os.path.exists(dst_source): raise MlflowExportImportException(f"'source' argument for MLflowClient.create_model_version does not exist: {dst_source}", http_status_code=404) tags = src_vr["tags"] diff --git a/mlflow_export_import/run/import_run.py b/mlflow_export_import/run/import_run.py index ad66a1ae..173f7d6c 100644 --- a/mlflow_export_import/run/import_run.py +++ b/mlflow_export_import/run/import_run.py @@ -7,7 +7,7 @@ import base64 from mlflow.entities.lifecycle_stage import LifecycleStage -from mlflow.entities import RunStatus +from mlflow.entities import Dataset, DatasetInput, InputTag, RunStatus from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID from mlflow_export_import.common.click_options import ( @@ -81,7 +81,7 @@ def _mk_ex(src_run_dct, dst_run_id, exp_name): use_src_user_id, in_databricks ) - _import_inputs(http_client, src_run_dct, run_id) + _import_inputs(mlflow_client, src_run_dct, run_id) path = _fs.mk_local_path(os.path.join(input_dir, "artifacts")) if os.path.exists(path): @@ -141,10 +141,12 @@ def _upload_databricks_notebook(dbx_client, input_dir, src_run_dct, dst_notebook _logger.warning(f"Cannot save notebook '{dst_notebook_path}'. 
{e}") -def _import_inputs(http_client, src_run_dct, run_id): - inputs = src_run_dct.get("inputs") - dct = { "run_id": run_id, "datasets": inputs } - http_client.post("runs/log-inputs", dct) +def _import_inputs(mlflow_client, src_run_dct, run_id): + inputs = src_run_dct.get("inputs", []) + if not inputs: + return + dataset_inputs = [DatasetInput(Dataset.from_dictionary(input['dataset']), [InputTag.from_dictionary(tag) for tag in input['tags']]) for input in inputs] + mlflow_client.log_inputs(run_id=run_id, datasets=dataset_inputs) @click.command()