diff --git a/src/modelplane/runways/responder.py b/src/modelplane/runways/responder.py index 18a5f4f..d5f39f8 100644 --- a/src/modelplane/runways/responder.py +++ b/src/modelplane/runways/responder.py @@ -4,10 +4,15 @@ import tempfile import mlflow - from modelgauge.pipeline_runner import build_runner -from modelgauge.sut_registry import SUTS +from modelgauge.sut_factory import SUT_FACTORY +from modelplane.runways.data import ( + Artifact, + BaseInput, + RunArtifacts, + build_and_log_input, +) from modelplane.runways.utils import ( CACHE_DIR, MODELGAUGE_RUN_TAG_NAME, @@ -17,12 +22,6 @@ is_debug_mode, setup_sut_credentials, ) -from modelplane.runways.data import ( - Artifact, - BaseInput, - RunArtifacts, - build_and_log_input, -) def respond( @@ -37,7 +36,7 @@ def respond( prompt_text_col=None, ) -> RunArtifacts: secrets = setup_sut_credentials(sut_id) - sut = SUTS.make_instance(uid=sut_id, secrets=secrets) + sut = SUT_FACTORY.make_instance(uid=sut_id, secrets=secrets) params = {"num_workers": num_workers} tags = {"sut_id": sut_id, RUN_TYPE_TAG_NAME: RUN_TYPE_RESPONDER} @@ -58,7 +57,7 @@ def respond( input_path=input_data.local_path(), output_dir=pathlib.Path(tmp), cache_dir=None if disable_cache else CACHE_DIR, - suts={sut_id: sut}, + suts={sut.uid: sut}, prompt_uid_col=prompt_uid_col, prompt_text_col=prompt_text_col, ) diff --git a/src/modelplane/runways/utils.py b/src/modelplane/runways/utils.py index 1f9cb70..47d0912 100644 --- a/src/modelplane/runways/utils.py +++ b/src/modelplane/runways/utils.py @@ -2,7 +2,6 @@ from typing import List import mlflow - from modelgauge.annotator_registry import ANNOTATORS from modelgauge.config import ( SECRETS_PATH, @@ -10,7 +9,7 @@ raise_if_missing_from_config, ) from modelgauge.secret_values import RawSecrets -from modelgauge.sut_registry import SUTS +from modelgauge.sut_factory import SUT_FACTORY # Path to the secrets toml file SECRETS_PATH_ENV = "MODEL_SECRETS_PATH" @@ -35,7 +34,7 @@ def is_debug_mode() -> bool: def setup_sut_credentials(uid: str) -> RawSecrets: missing_secrets = [] secrets = safe_load_secrets_from_config() - missing_secrets.extend(SUTS.get_missing_dependencies(uid, secrets=secrets)) + missing_secrets.extend(SUT_FACTORY.get_missing_dependencies(uid, secrets=secrets)) raise_if_missing_from_config(missing_secrets) return secrets