from http import HTTPStatus
from typing import Any, Optional, Union, cast

import httpx

from ...client import AuthenticatedClient, Client
from ...types import Response, UNSET
from ... import errors

from ...models.hpc_run import HpcRun
from ...models.http_validation_error import HTTPValidationError


def _get_kwargs(
    *,
    body: list[int],
) -> dict[str, Any]:
    """Build the httpx request kwargs for GET /results/simulations/status/batch.

    Args:
        body: Simulation database ids to query, sent as a JSON array.

    Returns:
        Keyword arguments suitable for ``httpx.Client.request``.
    """
    headers: dict[str, Any] = {}

    # NOTE(review): a GET request carrying a JSON body is unusual and is dropped by
    # some proxies/intermediaries. This mirrors the server route, so any change
    # (e.g. switching to POST) must happen in the API spec first.
    _kwargs: dict[str, Any] = {
        "method": "get",
        "url": "/results/simulations/status/batch",
    }

    _kwargs["json"] = body

    headers["Content-Type"] = "application/json"

    _kwargs["headers"] = headers
    return _kwargs


def _parse_response(
    *, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Optional[Union[HTTPValidationError, list["HpcRun"]]]:
    """Deserialize the documented 200/422 payloads.

    Returns ``None`` for undocumented status codes unless the client is configured
    to raise, in which case ``errors.UnexpectedStatus`` is raised.
    """
    if response.status_code == 200:
        # 200 payload is a JSON array of HpcRun objects.
        response_200 = []
        _response_200 = response.json()
        for response_200_item_data in _response_200:
            response_200_item = HpcRun.from_dict(response_200_item_data)

            response_200.append(response_200_item)

        return response_200
    if response.status_code == 422:
        response_422 = HTTPValidationError.from_dict(response.json())

        return response_422
    if client.raise_on_unexpected_status:
        raise errors.UnexpectedStatus(response.status_code, response.content)
    else:
        return None


def _build_response(
    *, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Response[Union[HTTPValidationError, list["HpcRun"]]]:
    """Wrap the raw httpx response in the generated-client ``Response`` envelope."""
    return Response(
        status_code=HTTPStatus(response.status_code),
        content=response.content,
        headers=response.headers,
        parsed=_parse_response(client=client, response=response),
    )


def sync_detailed(
    *,
    client: Union[AuthenticatedClient, Client],
    body: list[int],
) -> Response[Union[HTTPValidationError, list["HpcRun"]]]:
    """Get simulation status records for a list of IDs

    Args:
        body (list[int]):

    Raises:
        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        Response[Union[HTTPValidationError, list['HpcRun']]]
    """

    kwargs = _get_kwargs(
        body=body,
    )

    response = client.get_httpx_client().request(
        **kwargs,
    )

    return _build_response(client=client, response=response)


def sync(
    *,
    client: Union[AuthenticatedClient, Client],
    body: list[int],
) -> Optional[Union[HTTPValidationError, list["HpcRun"]]]:
    """Get simulation status records for a list of IDs

    Args:
        body (list[int]):

    Raises:
        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        Union[HTTPValidationError, list['HpcRun']]
    """

    return sync_detailed(
        client=client,
        body=body,
    ).parsed


async def asyncio_detailed(
    *,
    client: Union[AuthenticatedClient, Client],
    body: list[int],
) -> Response[Union[HTTPValidationError, list["HpcRun"]]]:
    """Get simulation status records for a list of IDs

    Args:
        body (list[int]):

    Raises:
        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        Response[Union[HTTPValidationError, list['HpcRun']]]
    """

    kwargs = _get_kwargs(
        body=body,
    )

    response = await client.get_async_httpx_client().request(**kwargs)

    return _build_response(client=client, response=response)


async def asyncio(
    *,
    client: Union[AuthenticatedClient, Client],
    body: list[int],
) -> Optional[Union[HTTPValidationError, list["HpcRun"]]]:
    """Get simulation status records for a list of IDs

    Args:
        body (list[int]):

    Raises:
        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        Union[HTTPValidationError, list['HpcRun']]
    """

    return (
        await asyncio_detailed(
            client=client,
            body=body,
        )
    ).parsed
dict[str, Any] = _attrs_field(init=False, factory=dict) def to_dict(self) -> dict[str, Any]: - from ..models.bi_graph_process import BiGraphProcess from ..models.bi_graph_step import BiGraphStep + from ..models.bi_graph_process import BiGraphProcess database_id = self.database_id @@ -69,8 +69,8 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: - from ..models.bi_graph_process import BiGraphProcess from ..models.bi_graph_step import BiGraphStep + from ..models.bi_graph_process import BiGraphProcess d = dict(src_dict) database_id = d.pop("database_id") diff --git a/compose_api/api/client/models/simulator_version.py b/compose_api/api/client/models/simulator_version.py index a8a2e9d..9112b8e 100644 --- a/compose_api/api/client/models/simulator_version.py +++ b/compose_api/api/client/models/simulator_version.py @@ -14,8 +14,8 @@ import datetime if TYPE_CHECKING: - from ..models.registered_package import RegisteredPackage from ..models.containerization_file_repr import ContainerizationFileRepr + from ..models.registered_package import RegisteredPackage T = TypeVar("T", bound="SimulatorVersion") @@ -40,8 +40,8 @@ class SimulatorVersion: additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) def to_dict(self) -> dict[str, Any]: - from ..models.registered_package import RegisteredPackage from ..models.containerization_file_repr import ContainerizationFileRepr + from ..models.registered_package import RegisteredPackage singularity_def = self.singularity_def.to_dict() @@ -82,8 +82,8 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: - from ..models.registered_package import RegisteredPackage from ..models.containerization_file_repr import ContainerizationFileRepr + from ..models.registered_package import RegisteredPackage d = dict(src_dict) singularity_def = ContainerizationFileRepr.from_dict(d.pop("singularity_def")) diff 
@config.router.get(
    path="/simulations/status/batch",
    response_model=list[HpcRun],
    operation_id="get-simulations-status-batch",
    tags=["Results"],
    dependencies=[Depends(get_database_service)],
    summary="Get simulation status records for a list of IDs",
)
async def get_simulations_status_batch(ids: list[int]) -> list[HpcRun]:
    """Return the HpcRun status records for a batch of simulation database ids.

    NOTE(review): declaring ``ids: list[int]`` on a GET route makes FastAPI read it
    from the JSON request body — GET-with-body is dropped by some proxies/clients;
    consider POST or query parameters upstream. TODO confirm intent.
    """
    # Fetched directly rather than via the Depends() in the decorator; the decorator
    # dependency only ensures the service is initialized before the handler runs.
    db_service = get_database_service()
    if db_service is None:
        raise HTTPException(status_code=500, detail="Database service is not initialized")
    try:
        # Batch lookup keyed on the SIMULATION ref column; ids with no HpcRun row
        # are simply absent from the result list.
        return await db_service.get_hpc_db().get_hpcruns_by_refs(ref_ids=ids, job_type=JobType.SIMULATION)
    except Exception as e:
        logger.exception(f"Error fetching batch simulation statuses for ids: {ids}.")
        # NOTE(review): str(e) exposes internal error text to the client — consider
        # a generic detail message; verify against the file's other handlers.
        raise HTTPException(status_code=500, detail=str(e)) from e
    @override
    async def get_hpcruns_by_refs(self, ref_ids: list[int], job_type: JobType) -> list[HpcRun]:
        """Fetch all HpcRun rows whose job-type ref column matches any id in *ref_ids*.

        Args:
            ref_ids: Reference ids (e.g. simulation database ids) to look up.
            job_type: Selects which ORM ref column to filter on via
                ``_get_job_type_ref``.

        Returns:
            Matching runs converted to ``HpcRun`` domain objects; ids without a
            row are silently omitted.
        """
        async with self.async_session_maker() as session, session.begin():
            # _get_job_type_ref maps the JobType to the ORM column holding that ref id.
            run_type = self._get_job_type_ref(job_type)
            # NOTE(review): an empty ref_ids produces an always-false IN clause
            # (SQLAlchemy emits a warning / matches nothing) — confirm desired.
            stmt = select(ORMHpcRun).where(run_type.in_(ref_ids))
            result: Result[tuple[ORMHpcRun]] = await session.execute(stmt)
            orm_hpcruns = result.scalars().all()
            return [orm_hpcrun.to_hpc_run() for orm_hpcrun in orm_hpcruns]
"""Tests for the batch simulation-status endpoint (/results/simulations/status/batch)."""

import asyncio
import os

import pytest

from compose_api.api.client import Client
from compose_api.api.client.api.curated import run_copasi
from compose_api.api.client.api.results import get_simulations_status_batch
from compose_api.api.client.models import BodyRunCopasi, HpcRun, JobStatus, SimulationExperiment
from compose_api.api.client.types import File
from compose_api.config import get_settings
from compose_api.db.database_service import DatabaseService, DatabaseServiceSQL
from compose_api.simulation.data_service import DataService
from compose_api.simulation.job_monitor import JobMonitor
from compose_api.simulation.models import JobType, SimulationRequest, Simulator, SimulatorVersion
from compose_api.simulation.simulation_service import SimulationServiceHpc
from compose_api.simulation.simulation_service import SimulationServiceHpc as _SimulationServiceHpc  # noqa: F401
from tests.simulators.utils import get_results_and_compare_copasi, test_dir


@pytest.mark.asyncio
async def test_get_simulations_status_batch(
    in_memory_api_client: Client,
    database_service: DatabaseService,
    simulation_request: SimulationRequest,
    simulator: SimulatorVersion,
) -> None:
    """Insert two simulations + HpcRun rows directly and verify the batch endpoint
    returns exactly those two records."""
    sim_a = await database_service.get_simulator_db().insert_simulation(
        sim_request=simulation_request,
        experiment_id="test_batch_experiment_a",
        simulator_version=simulator,
    )
    sim_b = await database_service.get_simulator_db().insert_simulation(
        sim_request=simulation_request,
        experiment_id="test_batch_experiment_b",
        simulator_version=simulator,
    )

    hpcrun_a = await database_service.get_hpc_db().insert_hpcrun(
        slurmjobid=1001,
        job_type=JobType.SIMULATION,
        ref_id=sim_a.database_id,
        correlation_id="corr_a",
    )
    hpcrun_b = await database_service.get_hpc_db().insert_hpcrun(
        slurmjobid=1002,
        job_type=JobType.SIMULATION,
        ref_id=sim_b.database_id,
        correlation_id="corr_b",
    )

    try:
        response = await get_simulations_status_batch.asyncio_detailed(
            client=in_memory_api_client, body=[sim_a.database_id, sim_b.database_id]
        )

        assert response.status_code == 200
        results = response.parsed
        assert isinstance(results, list)
        assert len(results) == 2

        returned_sim_ids = {r.sim_id for r in results}
        assert sim_a.database_id in returned_sim_ids
        assert sim_b.database_id in returned_sim_ids

        returned_slurm_ids = {r.slurmjobid for r in results}
        assert hpcrun_a.slurmjobid in returned_slurm_ids
        assert hpcrun_b.slurmjobid in returned_slurm_ids
    finally:
        # Clean up rows even if an assertion above fails.
        await database_service.get_hpc_db().delete_hpcrun(hpcrun_a.database_id)
        await database_service.get_hpc_db().delete_hpcrun(hpcrun_b.database_id)
        await database_service.get_simulator_db().delete_simulation(sim_a.database_id)
        await database_service.get_simulator_db().delete_simulation(sim_b.database_id)


@pytest.mark.skipif(len(get_settings().slurm_submit_key_path) == 0, reason="slurm ssh key file not supplied")
@pytest.mark.asyncio
async def test_batch_status_after_copasi_runs(
    in_memory_api_client: Client,
    database_service: DatabaseServiceSQL,
    simulation_service_slurm: SimulationServiceHpc,
    job_monitor: JobMonitor,
    data_service: DataService,
    simulator: Simulator,
) -> None:
    """Submit two real COPASI runs, poll the batch endpoint until all runs reach a
    terminal state, then verify each completed and produced correct results."""
    copasi_sbml = os.path.join(test_dir, "fixtures/resources/copasi.sbml")
    experiments: list[SimulationExperiment] = []
    with open(copasi_sbml, "rb") as f:
        for _ in range(2):
            f.seek(0)  # re-read the same SBML payload for each submission
            sim_experiment = await run_copasi.asyncio(
                client=in_memory_api_client,
                start_time=0,
                duration=10,
                num_data_points=51,
                body=BodyRunCopasi(sbml=File(file_name="copasi.sbml", payload=f)),
            )
            assert isinstance(sim_experiment, SimulationExperiment)
            experiments.append(sim_experiment)

    sim_ids = [exp.simulation_database_id for exp in experiments]
    terminal_statuses = {
        JobStatus.COMPLETED,
        JobStatus.FAILED,
        JobStatus.CANCELLED,
        JobStatus.TIMEOUT,
        JobStatus.OUT_OF_MEMORY,
    }
    num_loops = 0

    results: list[HpcRun] = []
    while num_loops < 60:  # ~2 minutes max (60 polls x 2s)
        response = await get_simulations_status_batch.asyncio_detailed(
            client=in_memory_api_client,
            body=sim_ids,
        )
        assert response.status_code == 200
        assert isinstance(response.parsed, list)
        results = response.parsed
        # BUGFIX: require a record for every submitted simulation before the
        # terminal-status check — all() over an empty/short list is vacuously
        # True, which previously broke out of the loop before HpcRun rows
        # existed and made the length assertion below fail spuriously.
        if len(results) == len(sim_ids) and all(r.status in terminal_statuses for r in results):
            break

        await asyncio.sleep(2)
        num_loops += 1

    assert len(results) == len(experiments)

    for result in results:
        assert result.sim_id in sim_ids
        assert result.status == JobStatus.COMPLETED, f"Simulation {result.sim_id} ended with status {result.status}"
        await get_results_and_compare_copasi(api_client=in_memory_api_client, sim_id=result.sim_id)
async def get_results_and_compare_copasi(api_client: Client, sim_id: int = 0, file_result: Any = None) -> None:
    """Fetch (or accept pre-fetched) COPASI experiment results and compare them
    against the expected report CSV fixture.

    Args:
        api_client: API client used to download results when *file_result* is None.
        sim_id: Simulation database id to fetch results for; ignored when
            *file_result* is supplied. The default of 0 is a placeholder only —
            always pass a real id when fetching.
        file_result: Optional pre-fetched response-like object exposing
            ``.content`` (the zipped experiment results).
    """
    if file_result is None:
        file_result = await get_simulation_results_file.asyncio_detailed(client=api_client, simulation_id=sim_id)
        # BUGFIX: fail fast with a clear message on an error response instead of
        # writing the error payload to disk and failing later inside the zip parse.
        assert file_result.status_code == 200, (
            f"fetching results for simulation {sim_id} failed with status {file_result.status_code}"
        )

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir_path = Path(temp_dir)
        experiment_results = temp_dir_path / Path("experiment_results.zip")
        with open(experiment_results, "wb") as results_file:
            results_file.write(file_result.content)
        report_csv_file = Path(os.path.join(test_dir, "fixtures/resources/report.csv"))
        assert_test_sim_results(
            archive_results=experiment_results,
            expected_csv_path=report_csv_file,
            temp_dir=temp_dir_path,
            difference_tolerance=1e-4,
        )


# Tests root directory (the parent of this file's directory). BUGFIX: computed
# with os.path.dirname twice instead of splitting on "/" (non-portable on Windows).
test_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))