diff --git a/.github/workflows/deploy-test-studio-kind.yml b/.github/workflows/deploy-test-studio-kind.yml index afb89640..339d7816 100644 --- a/.github/workflows/deploy-test-studio-kind.yml +++ b/.github/workflows/deploy-test-studio-kind.yml @@ -41,7 +41,7 @@ jobs: # populate-studio, or this workflow file. # Explicitly skip docs-only changes. if echo "$CHANGED_FILES" | grep -qE \ - "^(operators/|deployment-scripts/|populate-studio/|geospatial-studio/|deploy_studio_k8s\.sh|requirements\.txt|\.github/workflows/deploy-test-studio\.yml)"; then + "^(operators/|deployment-scripts/|populate-studio/|geospatial-studio/|deploy_studio_k8s\.sh|requirements\.txt|\.github/workflows/deploy-test-studio(-kind)?\.yml)"; then echo "deploy=true" >> $GITHUB_OUTPUT echo "Deployment-relevant files changed – deployment will proceed" else @@ -370,6 +370,30 @@ jobs: exit 1 fi + - name: Run Integration Tests + env: + GATEWAY_TLS_VERIFY: "0" + BASE_GATEWAY_URL: "https://localhost:4181" + run: | + echo "=== Running Integration Tests ===" + + # Extract API key + if [[ -f ".studio-api-key" ]]; then + source .studio-api-key + export API_KEY=$STUDIO_API_KEY + echo "✅ Loaded API key" + else + echo "❌ Error: .studio-api-key file not found" + exit 1 + fi + + # Install test dependencies if not already installed + pip install -r requirements-dev.txt + + # Run integration tests + python -m pytest -q -m integration --no-cov --log-file=run.log --log-file-level=INFO tests/integration/test_inference_models.py + + - name: Run Workshop Labs if: success() env: diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 00000000..b831c9e2 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,23 @@ +# © Copyright IBM Corporation 2025 +# SPDX-License-Identifier: Apache-2.0 + + +# For Python 3.11 +-r requirements.txt + +black +flake8 +hunter==3.7.0 +IPython>=8.18,<9.0.0 +rich>=14.0.0 + +# Needed by conftest +fastapi +httpx +sqlalchemy +sqlalchemy-utils + +# Test +pytest +pytest-cov +pre-commit diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..195567e0 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,163 @@ +import os +import shlex +import subprocess +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional + +import pytest +from dotenv import find_dotenv, load_dotenv +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy_utils import create_database, database_exists, drop_database + +from tests.integration.gateway import GatewayApiClient + +from .integration.utils import make_timestamped_name + + +# ------------------------------- +# Common configurations for Integration and Unit tests +# ------------------------------- +@pytest.hookimpl(tryfirst=True) +def pytest_configure(config): + # # --- load env and set test DB (unit tests support) --- + # load_dotenv(".env", override=True) + # db_url = os.environ.get("DATABASE_URI", str(settings.DATABASE_URI)) + "_test" + # settings.DATABASE_URI = db_url + # settings.AUTH_ENABLED = False + # + # --- register markers (integration tests support) --- + config.addinivalue_line( + "markers", "integration: marks tests that hit live external services" + ) + + +# ------------------------------- +# Unit Tests Support +# ------------------------------- +def _db_session(): + db_url = str(settings.DATABASE_URI) + if not 
database_exists(db_url): + create_database(db_url) + + engine = create_engine(db_url) + TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + return db_url, TestingSessionLocal, engine + + +@pytest.fixture(scope="session") +def db(): + """Fixture sets up and tears down PostgreSQL test database.""" + db_url, TestingSessionLocal, engine = _db_session() + Base.metadata.create_all(bind=engine) + + session = TestingSessionLocal() + yield session + + session.close_all() + drop_database(db_url) + + +def override_get_db(): + try: + _, TestingSessionLocal, _ = _db_session() + db = TestingSessionLocal() + yield db + finally: + db.close() + + +@pytest.fixture(scope="module") +def client(db): + """Sets up FastAPI test client for sending HTTP requests during testing.""" + app.dependency_overrides[get_db] = override_get_db + client = TestClient(app) + return client + + +@pytest.fixture(scope="session") +def repo_root() -> Path: + toplevel = ( + subprocess.check_output(shlex.split("git rev-parse --show-toplevel")) + .decode() + .strip() + ) + return Path(toplevel) + + +def pytest_addoption(parser): + # https://docs.python.org/3/library/argparse.html#argparse.ArgumentParser.add_argument + parser.addoption( + "--tune-id", + action="store", + default=None, + help="A ID of a geotune" "from the Fine Tune Service", + ) + + +@pytest.fixture(scope="session") +def tune_id(pytestconfig): + tune_id = pytestconfig.getoption("tune_id") + if not tune_id: + pytest.skip("[unit] skipped due to missing tune-id") + return tune_id + + +@pytest.fixture(scope="session") +def token() -> Optional[str]: + """ + Returns a valid option for IBM Verify to authenticate. Only used for interactive + integration tests. This value can be provided by .envrc or .env + """ + token = os.environ.get("TOKEN") + if not token: + pytest.skip("[unit] skipped due to missing authentication TOKEN in environment") + return token + + +# ------------------------------- +# Integration Tests Support (APIs) +# ------------------------------- +def _int_env(name: str, default: str | None = None, required: bool = False) -> str: + v = os.getenv(name, default) + if required and not v: + pytest.skip(f"[integration] Missing env var {name}; skipping") + return v or "" + + +@pytest.fixture(scope="session") +def gateway() -> GatewayApiClient: + """ + External Gateway client. 
+ Requires: + BASE_GATEWAY_URL and API_KEY in .env file + """ + + # Make sure .env is found whether you run from repo root or a subfolder + load_dotenv(find_dotenv(usecwd=True), override=True) + + base_url = os.getenv("BASE_GATEWAY_URL") + api_key = os.getenv("API_KEY") + + if not base_url or not api_key: + pytest.skip("[integration] Missing BASE_GATEWAY_URL or API_KEY; skipping") + + # Normalize possible CRLF or stray quotes from copy/paste + api_key = api_key.strip().strip('"').strip("'") + if api_key.endswith("\r"): + api_key = api_key[:-1] + + return GatewayApiClient(base_url=base_url, api_key=api_key) + + +@pytest.fixture(scope="function") +def name_factory(): + """One fixed timestamp per test so related names match.""" + fixed_now = datetime.now(timezone.utc) + + def _make(base: str, *, prefix: str | None = None, ext: str = "") -> str: + return make_timestamped_name(base, prefix=prefix, ext=ext, now=fixed_now) + + return _make diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/data/__init__.py b/tests/integration/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/data/api_inference_models_and_inference.py b/tests/integration/data/api_inference_models_and_inference.py new file mode 100644 index 00000000..964e0a31 --- /dev/null +++ b/tests/integration/data/api_inference_models_and_inference.py @@ -0,0 +1,50 @@ +""" +Example payloads for integration tests. +Each payload is a Python dict, instead of JSON files. +""" + +# Base valid model payload +SANDBOX_MODEL = { + "display_name": "integration-test-sandbox-model", + "description": "[Integration Test 175933_01we_10oct_25] Early-access test model made available for demonstration or limited user evaluation. 
These models may include incomplete features or evolving performance characteristics and are intended for feedback and experimentation before full deployment.", + "pipeline_steps": [ + {"status": "READY", "process_id": "url-connector", "step_number": 0}, + {"status": "WAITING", "process_id": "push-to-geoserver", "step_number": 1}, + ], + "geoserver_push": [], + "model_input_data_spec": [ + { + "bands": [], + "connector": "sentinelhub", + "collection": "hls_s30", + "file_suffix": "S2Hand", + } + ], + "postprocessing_options": {}, + "sharable": False, + "model_onboarding_config": { + "fine_tuned_model_id": "", + "model_configs_url": "", + "model_checkpoint_url": "", + }, + "latest": True, + "version": 1.0, +} + +DEFAULT_ONBOARD_INFERENCE_MODEL = { + "model_framework": "terratorch", + "model_id": "string", + "model_name": "string", + "model_configs_url": "https://example.com/", + "model_checkpoint_url": "https://example.com/", + "deployment_type": "gpu", + "resources": { + "requests": {"cpu": "6", "memory": "16G"}, + "limits": {"cpu": "12", "memory": "32G"}, + }, + "gpu_resources": { + "requests": {"nvidia.com/gpu": "1"}, + "limits": {"nvidia.com/gpu": "1"}, + }, + "inference_container_image": "", +} diff --git a/tests/integration/gateway.py b/tests/integration/gateway.py new file mode 100644 index 00000000..458b705d --- /dev/null +++ b/tests/integration/gateway.py @@ -0,0 +1,126 @@ +import os +from typing import Any, Dict, Optional +from urllib.parse import urljoin, urlparse + +import requests +from dotenv import find_dotenv, load_dotenv +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + + +class GatewayApiClient: + def __init__( + self, + base_url: str, + api_key: Optional[str] = None, + timeout: int = 60, + verify: Optional[bool | str] = False, + proxies: Optional[Dict[str, str]] = None, + ): + self.base_url = base_url.rstrip("/") + "/" + self.timeout = timeout + self.verify = verify + self.proxies = proxies + + self.session = requests.Session() + retry = Retry( + total=5, + backoff_factor=0.3, + status_forcelist=(429, 500, 502, 503, 504), + allowed_methods=frozenset(("GET", "POST", "PATCH", "DELETE")), + ) + adapter = HTTPAdapter(max_retries=retry) + self.session.mount("https://", adapter) + self.session.mount("http://", adapter) + + self.session.headers.update({"Accept": "application/json"}) + + if api_key: + key = api_key.strip().strip('"').strip("'") + if key.endswith("\r"): + key = key[:-1] + self.session.headers.update({"X-Api-Key": key}) + + def _url(self, path: str) -> str: + return urljoin(self.base_url, path.lstrip("/")) + + # Gateway methods + def get(self, path: str, **kwargs) -> requests.Response: + return self.session.get( + self._url(path), + timeout=self.timeout, + verify=self.verify, + proxies=self.proxies, + **kwargs, + ) + + def post(self, path: str, json: Any | None = None, **kwargs) -> requests.Response: + headers = {**self.session.headers, "Content-Type": "application/json"} + return self.session.post( + self._url(path), + headers=headers, + json=json or {}, + timeout=self.timeout, + verify=self.verify, + proxies=self.proxies, + **kwargs, + ) + + def put(self, path: str, json: Any | None = None, **kwargs) -> requests.Response: + user_headers = kwargs.pop("headers", None) or {} + files = kwargs.pop("files", None) + + headers = {**self.session.headers, **user_headers} + + if files is not None: + # multipart/form-data PUT + return self.session.put( + self._url(path), + headers=headers, + files=files, + data=kwargs.pop("data", {}), + 
timeout=self.timeout, + verify=self.verify, + proxies=self.proxies, + **kwargs, + ) + else: + # JSON PUT + headers.setdefault("Content-Type", "application/json") + return self.session.put( + self._url(path), + headers=headers, + json=json or {}, + timeout=self.timeout, + verify=self.verify, + proxies=self.proxies, + **kwargs, + ) + + def patch(self, path: str, json: Any | None = None, **kwargs) -> requests.Response: + headers = {**self.session.headers, "Content-Type": "application/json"} + return self.session.patch( + self._url(path), + headers=headers, + json=json or {}, + timeout=self.timeout, + verify=self.verify, + proxies=self.proxies, + **kwargs, + ) + + def delete(self, path: str, **kwargs) -> requests.Response: + return self.session.delete( + self._url(path), + timeout=self.timeout, + verify=self.verify, + proxies=self.proxies, + **kwargs, + ) + + @classmethod + def from_env(cls) -> "GatewayApiClient": + load_dotenv(find_dotenv(usecwd=True)) + base = os.environ["BASE_GATEWAY_URL"] + api_key = os.environ["API_KEY"] # uses the key *value* (pak-...), not the id + return cls(base_url=base, api_key=api_key) diff --git a/tests/integration/test_inference_models.py b/tests/integration/test_inference_models.py new file mode 100644 index 00000000..269dc6ae --- /dev/null +++ b/tests/integration/test_inference_models.py @@ -0,0 +1,579 @@ +import logging +import os + +import pytest + +from .data import api_inference_models_and_inference as payloads +from .utils import redacted_response_text + +log = logging.getLogger("gateway_tests") +log.setLevel(logging.INFO) + +pytestmark = pytest.mark.integration + + +# ====================== Helper Functions ======================== +def env_eval(name: str, default: str = "") -> bool: + val = os.getenv(name, default).strip().lower() + return val in {"1", "true", "yes", "y", "on"} + + +# ====================== Fixtures ================================ +@pytest.fixture() +def amo_tests_activate(): + if not env_eval("ALLOW_AMO_TESTS"): + msg = "\nSet ALLOW_AMO_TESTS=1 to run AMO tasks" + log.info(msg) + pytest.skip(msg) + + +@pytest.fixture() +def create_model(gateway): + """ + Call /v2/datasets and return the parsed body. + Each call will hit the endpoint afresh. + """ + + def _create_model(model_payload): + + # PAYLOAD + payload = model_payload + + # QUERY + r = gateway.post("/v2/models", json=model_payload) + log.info( + "POST /v2/models \nPayload:\n%s \nResponse (redacted)(%s):\n%s", + payload, + r.status_code, + redacted_response_text(r), + ) + assert r.status_code == 201 + body = r.json() + model_id = body.get("id") + assert model_id + return body + + return _create_model + + +@pytest.fixture() +def list_models(gateway): + """ + Call /v2/datasets and return the parsed body. + Each call will hit the endpoint afresh. + """ + + def _list_models(**overrides): + + # PARAMS + params = {"limit": 25, "skip": 0, **overrides} + + # QUERY + r = gateway.get("/v2/models", params=params) + log.info( + "GET /v2/models -> (%s)\nPARAMS\n%s \nResponse (redacted)(%s)", + r.status_code, + params, + redacted_response_text(r), + ) + assert r.status_code == 200 + body = r.json() + assert isinstance(body, dict), type(body) + assert "results" in body and isinstance(body["results"], list) + return body + + return _list_models + + +@pytest.fixture() +def deploy_model_with_amo(gateway): + """ + Call /v2/datasets and return the parsed body. + Each call will hit the endpoint afresh. 
+ """ + + def _deploy_model_with_amo(model_id, **overrides): + + # PARAMS + # model_id + params = { + "fine_tuned_model_id": model_id, + "model_configs_url": "", # Need url + "model_checkpoint_url": "", # Need url + **overrides, + } + + # QUERY + r = gateway.post(f"/v2/models/{model_id}/deploy", params=params) + log.info( + "POST /v2/models/%s/deploy -> (%s)\nPARAMS\n%s \nResponse (redacted)(%s)", + model_id, + r.status_code, + params, + redacted_response_text(r), + ) + + body = r.json() + assert isinstance(body, dict), type(body) + return body, r + + return _deploy_model_with_amo + + +@pytest.fixture() +def update_model(gateway): + """ + Call /v2/datasets and return the parsed body. + Each call will hit the endpoint afresh. + """ + + def _update_model(model_id: str, *, replace: bool = False, **overrides): + + # PARAMS + # model_id + + # PAYLOAD + defaults = { + "display_name": "", + "description": "string", + "model_url": "https://example.com/", + "pipeline_steps": [{"additionalProp1": {}}], + "geoserver_push": [{"additionalProp1": {}}], + "model_input_data_spec": [{"additionalProp1": {}}], + "postprocessing_options": {"additionalProp1": {}}, + "sharable": True, + "model_onboarding_config": { + "fine_tuned_model_id": "string", + "model_configs_url": "string", + "model_checkpoint_url": "string", + }, + "latest": True, + **overrides, + } + payload = overrides if replace else {**defaults, **overrides} + + # QUERY + r = gateway.patch(f"/v2/models/{model_id}", json=payload) + log.info( + "PATCH /v2/models/%s -> (%s)\nPayload\n%s \nResponse (redacted)(%s)", + model_id, + r.status_code, + payload, + redacted_response_text(r), + ) + + body = r.json() + assert isinstance(body, dict), type(body) + return body, r + + return _update_model + + +@pytest.fixture() +def get_model(gateway): + """ + Call /v2/datasets and return the parsed body. + Each call will hit the endpoint afresh. + """ + + def _get_model(model_id): + + # PARAMS + # model_id + + # QUERY + r = gateway.get(f"/v2/models/{model_id}") + log.info( + "GET /v2/models/%s -> (%s)\nResponse (redacted)(%s)", + model_id, + r.status_code, + redacted_response_text(r), + ) + + body = r.json() + assert isinstance(body, dict), type(body) + return body, r + + return _get_model + + +@pytest.fixture() +def delete_model(gateway): + """ + Call /v2/datasets and return the parsed body. + Each call will hit the endpoint afresh. + """ + + def _delete_model(model_id): + + # PARAMS + # model_id + + # QUERY + r = gateway.delete(f"/v2/models/{model_id}") + log.info( + "\nDELETE /v2/models/%s -> (%s)\nResponse (redacted)(%s)", + model_id, + r.status_code, + redacted_response_text(r), + ) + return r + + return _delete_model + + +@pytest.fixture() +def retrieve_amo_task(gateway): + """ + Call /v2/datasets and return the parsed body. + Each call will hit the endpoint afresh. + """ + + def _retrieve_amo_task(model_id): + + # PARAMS + # model_id + + # QUERY + r = gateway.get(f"/v2/amo-tasks/{model_id}") + log.info( + "\nGET /v2/amo-tasks/%s -> (%s)\nResponse (redacted)(%s)", + model_id, + r.status_code, + redacted_response_text(r), + ) + body = r.json() + + return body, r + + return _retrieve_amo_task + + +@pytest.fixture() +def offboard_inference_model(gateway): + """ + Call /v2/datasets and return the parsed body. + Each call will hit the endpoint afresh. 
+ """ + + def _offboard_inference_model(model_id): + + # PARAMS + # model_id + + # QUERY + r = gateway.delete(f"/v2/amo-tasks/{model_id}") + log.info( + "\nDELETE /v2/amo-tasks/%s -> (%s)\nResponse (redacted)(%s)", + model_id, + r.status_code, + redacted_response_text(r), + ) + + body = r.json() + assert isinstance(body, dict), type(body) + return body, r + + return _offboard_inference_model + + +@pytest.fixture() +def onboard_inference_model(gateway): + """ + Call /v2/datasets and return the parsed body. + Each call will hit the endpoint afresh. + """ + + def _onboard_inference_model(*, replace: bool = False, **overrides): + + # PAYLOAD + DEFAULT_ONBOARD_INFERENCE_MODEL = { + "model_framework": "terratorch", + "model_id": "string", + "model_name": "string", + "model_configs_url": "https://example.com/", + "model_checkpoint_url": "https://example.com/", + "deployment_type": "gpu", + "resources": { + "requests": {"cpu": "6", "memory": "16G"}, + "limits": {"cpu": "12", "memory": "32G"}, + }, + "gpu_resources": { + "requests": {"nvidia.com/gpu": "1"}, + "limits": {"nvidia.com/gpu": "1"}, + }, + "inference_container_image": "", + **overrides, + } + payload = ( + overrides if replace else {**DEFAULT_ONBOARD_INFERENCE_MODEL, **overrides} + ) + + # QUERY + r = gateway.post("/v2/amo-tasks", json=payload) + log.info( + "\nPOST /v2/amo-tasks -> (%s)\nPayload\n%s \nResponse (redacted)(%s)", + r.status_code, + payload, + redacted_response_text(r), + ) + + body = r.json() + assert isinstance(body, dict), type(body) + return body, r + + return _onboard_inference_model + + +# ======================== Tests ================================= +# --------[create_model] Tests ----------------------------------- +def test_create_model_fixture(create_model, caplog): + caplog.set_level(logging.INFO, logger="gateway_tests") + + body = create_model(payloads.SANDBOX_MODEL) + assert "PENDING" == body["status"] + + +# --------[list_models] Tests ------------------------------------ +def test_list_models_fixture(list_models, caplog): + caplog.set_level(logging.INFO, logger="gateway_tests") + body = list_models(limit=1) + assert "total_records" in body + assert isinstance(body["results"], list) + + +# --------[deploy_model_with_amo] Tests -------------------------- +def test_deploy_model_with_amo_fixture_no_urls( + deploy_model_with_amo, list_models, caplog +): + # PARAMS + list_model_body = list_models() + model_id = list_model_body["results"][0][ + "id" + ] # [TODO] Check different models for errors + + params = { + "fine_tuned_model_id": model_id, + "model_configs_url": "", # Intentional missing url for testing + "model_checkpoint_url": "", # Intentional missing url for testing + } + + caplog.set_level(logging.INFO, logger="gateway_tests") + body, r = deploy_model_with_amo(model_id, **params) + assert isinstance(body, dict) + assert r.status_code == 422 + assert "detail" in body and isinstance(body["detail"], list) + assert "both model_checkpoint_url" in body["detail"][0]["msg"].lower() + assert "model_config" in body["detail"][0]["msg"].lower() + assert "should be" in body["detail"][0]["msg"].lower() + + +def test_deploy_model_with_amo_fixture_no_urls_and_missing_model_id( + deploy_model_with_amo, caplog +): + # PARAMS + # Sample nonexistent model_id + model_id = "e436969b-24bf-46e5-ac6f-7d0653da0f14" + + params = { + "fine_tuned_model_id": model_id, + "model_configs_url": "", # Intentional missing url for testing + "model_checkpoint_url": "", # Intentional missing url for testing + } + + 
caplog.set_level(logging.INFO, logger="gateway_tests") + body, r = deploy_model_with_amo(model_id, **params) + assert isinstance(body, dict) + assert r.status_code == 404 + assert "Model not found" in body["detail"] + + +# --------[update_model] Tests ----------------------------------- +def test_update_model_display_name(name_factory, update_model, list_models, caplog): + # PARAMS + list_model_body = list_models() + model_id = list_model_body["results"][0]["id"] + + caplog.set_level(logging.INFO, logger="gateway_tests") + + # Payload + display_name = name_factory(base="integration-test") + + payload = {"display_name": display_name} + + body, r = update_model(model_id, replace=True, **payload) + assert isinstance(body, dict) + assert r.status_code == 201 + assert display_name == body["display_name"] + + +def test_update_model_display_name_for_missing_model_id( + name_factory, update_model, caplog +): + # PARAMS + # Sample nonexistent model_id + model_id = "e436969b-24bf-46e5-ac6f-7d0653da0f14" + + caplog.set_level(logging.INFO, logger="gateway_tests") + + # Payload + display_name = name_factory(base="integration-test") + + payload = {"display_name": display_name} + + body, r = update_model(model_id, replace=True, **payload) + assert isinstance(body, dict) + assert r.status_code == 404 + assert "Model not found" == body["detail"] + + +# --------[get_model] Tests -------------------------------------- +def test_get_model_(get_model, list_models, caplog): + caplog.set_level(logging.INFO, logger="gateway_tests") + + # PARAMS + list_model_body = list_models() + model_id = list_model_body["results"][0]["id"] + + body, r = get_model(model_id) + assert isinstance(body, dict) + assert r.status_code == 200 + + +def test_get_model_with_nonexisting_model_id(get_model, caplog): + caplog.set_level(logging.INFO, logger="gateway_tests") + + # PARAMS + # Sample nonexistent model_id + model_id = "e436969b-24bf-46e5-ac6f-7d0653da0f14" + + body, r = get_model(model_id) + assert isinstance(body, dict) + assert r.status_code == 404 + assert "Model not found" == body["detail"] + + +# --------[delete_model] Tests ----------------------------------- +def test_delete_model_create_then_delete( + name_factory, create_model, delete_model, caplog +): + caplog.set_level(logging.INFO, logger="gateway_tests") + + create_model_body = create_model(payloads.SANDBOX_MODEL) + model_id = create_model_body["id"] + delete_model_r = delete_model(model_id) + assert delete_model_r.status_code == 204 + + +# --------[retrieve_amo_task](2) Tests ------------------------------ +def test_retrieve_amo_task_error( + amo_tests_activate, retrieve_amo_task, list_models, caplog +): + caplog.set_level(logging.INFO, logger="gateway_tests") + + # PARAMS + list_model_body = list_models() + model_id = list_model_body["results"][0]["id"] + + body, r = retrieve_amo_task(model_id) + assert r.status_code == 422 + assert "Model ID must not exceed 30 characters." 
== body["detail"] + + +def test_retrieve_amo_task_onboard_inference_model_then_retrieve_amo_task( + amo_tests_activate, name_factory, onboard_inference_model, retrieve_amo_task, caplog +): + caplog.set_level(logging.INFO, logger="gateway_tests") + + # SETUP + # ---payload + model_name = name_factory(base="integration-test") + model_id = name_factory(base="test") + model_id_amo_compatible = model_id.replace("_", "-")[:30] + payload_overrides = { + "model_name": model_name, + "model_id": model_id_amo_compatible, # confusing as model_id has 2 meanings + } + # --- fixture query + body, r = onboard_inference_model(**payload_overrides) + + # FIXTURE QUERY + body, r = retrieve_amo_task(model_id_amo_compatible) + + # TEST + assert r.status_code == 200 + assert f"amo-{model_id_amo_compatible}" == body["model_id"] + + +# --------[offboard_inference_model] Tests ----------------------- +def test_onboard_inference_model_onboard_then_offboard_inference_model( + amo_tests_activate, + name_factory, + onboard_inference_model, + offboard_inference_model, + caplog, +): + caplog.set_level(logging.INFO, logger="gateway_tests") + + # SETUP + # ---payload + model_name = name_factory(base="integration-test") + model_id = name_factory(base="test") + model_id_amo_compatible = model_id.replace("_", "-")[:30] + payload_overrides = { + "model_name": model_name, + "model_id": model_id_amo_compatible, # confusing as model_id has 2 meanings + } + # --- fixture query + body, r = onboard_inference_model(**payload_overrides) + + # FIXTURE QUERY + offboard_inference_model_body, offboard_inference_model_r = ( + offboard_inference_model(model_id_amo_compatible) + ) + + # TEST + assert offboard_inference_model_r.status_code == 200 + assert ( + "Model offboarding request submitted" + == offboard_inference_model_body["message"] + ) + + +# --------[onboard_inference_model] Tests ------------------------ +def test_onboard_inference_model_( + amo_tests_activate, + name_factory, + onboard_inference_model, + offboard_inference_model, + caplog, +): + caplog.set_level(logging.INFO, logger="gateway_tests") + + # PAYLOAD + model_name = name_factory(base="integration-test") + model_id = name_factory(base="test") + model_id_amo_compatible = model_id.replace("_", "-")[:30] + payload_overrides = { + "model_name": model_name, + "model_id": model_id_amo_compatible, # confusing as model_id has 2 meanings + } + + # FIXTURE QUERY + body, r = onboard_inference_model(**payload_overrides) + + # TEST + assert r.status_code == 200 + + # TEARDOWN + # FIXTURE QUERY + offboard_inference_model_body, offboard_inference_model_r = ( + offboard_inference_model(model_id_amo_compatible) + ) + + # TEST + assert offboard_inference_model_r.status_code == 200 + assert ( + "Model offboarding request submitted" + == offboard_inference_model_body["message"] + ) diff --git a/tests/integration/utils.py b/tests/integration/utils.py new file mode 100644 index 00000000..7ee403e8 --- /dev/null +++ b/tests/integration/utils.py @@ -0,0 +1,106 @@ +import json +import re +from datetime import datetime, timezone +from typing import Any, Dict + +__all__ = [ + "REDACT_KEYS", + "mask_secret_string", + "redact_obj", + "redacted_response_text", +] + +# redact these common secret fields if present +REDACT_KEYS = { + "value", + "token", + "access_token", + "api_token", + "apiKey", + "api_key", + "secret", + "password", +} + + +def mask_secret_string(s: str) -> str: + """Redact keys and replace api_keys with pak-****""" + if not isinstance(s, str) or not s: + return "" + # mask api_keys 
with pak-**** + if s.startswith("pak-"): + return "pak-****" + # mask other auth tokens (JWT-ish) + if re.fullmatch(r"eyJ[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+", s): + return "***TOKEN***" + return "****" + + +def redact_obj(obj: Any) -> Any: + """Redact dict/list/str recursively, preserving structure.""" + if isinstance(obj, dict): + out: Dict[str, Any] = {} + for k, v in obj.items(): + if k in REDACT_KEYS and isinstance(v, str): + out[k] = mask_secret_string(v) + else: + out[k] = redact_obj(v) + return out + if isinstance(obj, list): + return [redact_obj(x) for x in obj] + if isinstance(obj, str): + s = obj + # mask api_keys with pak-**** + s = re.sub(r"\bpak-[A-Za-z0-9]+\b", "pak-****", s) + # mask other auth tokens (JWT-ish) + s = re.sub( + r"\beyJ[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\b", + "***TOKEN***", + s, + ) + return s + return obj + + +def redacted_response_text(r) -> str: + """Render a response with secrets redacted. + Tries JSON first; falls back to plaintext masking. + """ + try: + data = r.json() + return json.dumps(redact_obj(data), indent=2, ensure_ascii=False) + except Exception: + txt = getattr(r, "text", "") or "" + txt = re.sub(r"\bpak-[A-Za-z0-9]+\b", "pak-****", txt) + txt = re.sub( + r"\beyJ[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\b", + "***TOKEN***", + txt, + ) + return txt + + +def make_timestamped_name( + base: str, + *, + prefix: str | None = None, + ext: str = "", + now: datetime | None = None, +) -> str: + """ + Build: --__ + """ + now = now or datetime.now(timezone.utc) + epoch6 = str(int(now.timestamp()))[:6] + yyyymmdd = now.strftime("%Y%m%d") + hhmmss = now.strftime("%H%M%S") + + def norm(s: str) -> str: + return s.strip().replace(" ", "_") + + parts = [] + if prefix: + parts.append(norm(prefix)) + parts.append(norm(base)) + stem = "-".join(parts) + f"-{epoch6}_{yyyymmdd}_{hhmmss}" + return stem + (ext or "")
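
Reviewer note: a minimal sketch for reproducing the new "Run Integration Tests" workflow step locally. It assumes the studio gateway is already reachable at https://localhost:4181 and that a .studio-api-key file exporting STUDIO_API_KEY sits in the repo root, mirroring the workflow step above; adjust the URL and key handling for other environments.

    # install test dependencies (requirements-dev.txt pulls in requirements.txt via -r)
    pip install -r requirements-dev.txt

    # environment the workflow step sets; the gateway fixture in tests/conftest.py reads BASE_GATEWAY_URL and API_KEY
    export GATEWAY_TLS_VERIFY=0
    export BASE_GATEWAY_URL="https://localhost:4181"
    source .studio-api-key && export API_KEY=$STUDIO_API_KEY

    # run only the integration-marked tests (the AMO onboarding/offboarding tests additionally require ALLOW_AMO_TESTS=1)
    python -m pytest -q -m integration --no-cov tests/integration/test_inference_models.py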