Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions compose_api/api/client/api/results/get_simulations_status_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from http import HTTPStatus
from typing import Any, Optional, Union, cast

import httpx

from ...client import AuthenticatedClient, Client
from ...types import Response, UNSET
from ... import errors

from ...models.hpc_run import HpcRun
from ...models.http_validation_error import HTTPValidationError
from typing import cast


def _get_kwargs(
*,
body: list[int],
) -> dict[str, Any]:
headers: dict[str, Any] = {}

_kwargs: dict[str, Any] = {
"method": "get",
"url": "/results/simulations/status/batch",
}

_kwargs["json"] = body

headers["Content-Type"] = "application/json"

_kwargs["headers"] = headers
return _kwargs


def _parse_response(
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Optional[Union[HTTPValidationError, list["HpcRun"]]]:
if response.status_code == 200:
response_200 = []
_response_200 = response.json()
for response_200_item_data in _response_200:
response_200_item = HpcRun.from_dict(response_200_item_data)

response_200.append(response_200_item)

return response_200
if response.status_code == 422:
response_422 = HTTPValidationError.from_dict(response.json())

return response_422
if client.raise_on_unexpected_status:
raise errors.UnexpectedStatus(response.status_code, response.content)
else:
return None


def _build_response(
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Response[Union[HTTPValidationError, list["HpcRun"]]]:
return Response(
status_code=HTTPStatus(response.status_code),
content=response.content,
headers=response.headers,
parsed=_parse_response(client=client, response=response),
)


def sync_detailed(
*,
client: Union[AuthenticatedClient, Client],
body: list[int],
) -> Response[Union[HTTPValidationError, list["HpcRun"]]]:
"""Get simulation status records for a list of IDs

Args:
body (list[int]):

Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.

Returns:
Response[Union[HTTPValidationError, list['HpcRun']]]
"""

kwargs = _get_kwargs(
body=body,
)

response = client.get_httpx_client().request(
**kwargs,
)

return _build_response(client=client, response=response)


def sync(
*,
client: Union[AuthenticatedClient, Client],
body: list[int],
) -> Optional[Union[HTTPValidationError, list["HpcRun"]]]:
"""Get simulation status records for a list of IDs

Args:
body (list[int]):

Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.

Returns:
Union[HTTPValidationError, list['HpcRun']]
"""

return sync_detailed(
client=client,
body=body,
).parsed


async def asyncio_detailed(
*,
client: Union[AuthenticatedClient, Client],
body: list[int],
) -> Response[Union[HTTPValidationError, list["HpcRun"]]]:
"""Get simulation status records for a list of IDs

Args:
body (list[int]):

Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.

Returns:
Response[Union[HTTPValidationError, list['HpcRun']]]
"""

kwargs = _get_kwargs(
body=body,
)

response = await client.get_async_httpx_client().request(**kwargs)

return _build_response(client=client, response=response)


async def asyncio(
*,
client: Union[AuthenticatedClient, Client],
body: list[int],
) -> Optional[Union[HTTPValidationError, list["HpcRun"]]]:
"""Get simulation status records for a list of IDs

Args:
body (list[int]):

Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.

Returns:
Union[HTTPValidationError, list['HpcRun']]
"""

return (
await asyncio_detailed(
client=client,
body=body,
)
).parsed
5 changes: 5 additions & 0 deletions compose_api/api/client/models/job_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@


class JobStatus(str, Enum):
CANCELLED = "cancelled"
COMPLETED = "completed"
FAILED = "failed"
OUT_OF_MEMORY = "out_of_memory"
PENDING = "pending"
QUEUED = "queued"
RUNNING = "running"
SUSPENDED = "suspended"
TIMEOUT = "timeout"
UNKNOWN = "unknown"
WAITING = "waiting"

def __str__(self) -> str:
Expand Down
6 changes: 3 additions & 3 deletions compose_api/api/client/models/registered_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from typing import cast

if TYPE_CHECKING:
from ..models.bi_graph_process import BiGraphProcess
from ..models.bi_graph_step import BiGraphStep
from ..models.bi_graph_process import BiGraphProcess


T = TypeVar("T", bound="RegisteredPackage")
Expand All @@ -36,8 +36,8 @@ class RegisteredPackage:
additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)

def to_dict(self) -> dict[str, Any]:
from ..models.bi_graph_process import BiGraphProcess
from ..models.bi_graph_step import BiGraphStep
from ..models.bi_graph_process import BiGraphProcess

database_id = self.database_id

Expand Down Expand Up @@ -69,8 +69,8 @@ def to_dict(self) -> dict[str, Any]:

@classmethod
def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T:
from ..models.bi_graph_process import BiGraphProcess
from ..models.bi_graph_step import BiGraphStep
from ..models.bi_graph_process import BiGraphProcess

d = dict(src_dict)
database_id = d.pop("database_id")
Expand Down
6 changes: 3 additions & 3 deletions compose_api/api/client/models/simulator_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import datetime

if TYPE_CHECKING:
from ..models.registered_package import RegisteredPackage
from ..models.containerization_file_repr import ContainerizationFileRepr
from ..models.registered_package import RegisteredPackage


T = TypeVar("T", bound="SimulatorVersion")
Expand All @@ -40,8 +40,8 @@ class SimulatorVersion:
additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)

def to_dict(self) -> dict[str, Any]:
from ..models.registered_package import RegisteredPackage
from ..models.containerization_file_repr import ContainerizationFileRepr
from ..models.registered_package import RegisteredPackage

singularity_def = self.singularity_def.to_dict()

Expand Down Expand Up @@ -82,8 +82,8 @@ def to_dict(self) -> dict[str, Any]:

@classmethod
def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T:
from ..models.registered_package import RegisteredPackage
from ..models.containerization_file_repr import ContainerizationFileRepr
from ..models.registered_package import RegisteredPackage

d = dict(src_dict)
singularity_def = ContainerizationFileRepr.from_dict(d.pop("singularity_def"))
Expand Down
19 changes: 19 additions & 0 deletions compose_api/api/routers/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,25 @@
config = RouterConfig(router=APIRouter(), prefix="/results", dependencies=[])


@config.router.get(
path="/simulations/status/batch",
response_model=list[HpcRun],
operation_id="get-simulations-status-batch",
tags=["Results"],
dependencies=[Depends(get_database_service)],
summary="Get simulation status records for a list of IDs",
)
async def get_simulations_status_batch(ids: list[int]) -> list[HpcRun]:
db_service = get_database_service()
if db_service is None:
raise HTTPException(status_code=500, detail="Database service is not initialized")
try:
return await db_service.get_hpc_db().get_hpcruns_by_refs(ref_ids=ids, job_type=JobType.SIMULATION)
except Exception as e:
logger.exception(f"Error fetching batch simulation statuses for ids: {ids}.")
raise HTTPException(status_code=500, detail=str(e)) from e


@config.router.get(
path="/simulation/status",
response_model=HpcRun,
Expand Down
38 changes: 37 additions & 1 deletion compose_api/api/spec/openapi_3_1_0_generated.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
openapi: 3.1.0
info:
title: compose-api
version: 0.3.3
version: 0.3.8
paths:
/curated/copasi:
post:
Expand Down Expand Up @@ -124,6 +124,37 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/results/simulations/status/batch:
get:
tags:
- Results
summary: Get simulation status records for a list of IDs
operationId: get-simulations-status-batch
requestBody:
content:
application/json:
schema:
items:
type: integer
type: array
title: Ids
required: true
responses:
'200':
description: Successful Response
content:
application/json:
schema:
items:
$ref: '#/components/schemas/HpcRun'
type: array
title: Response Get-Simulations-Status-Batch
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/results/simulation/status:
get:
tags:
Expand Down Expand Up @@ -449,6 +480,11 @@ components:
- completed
- failed
- pending
- cancelled
- out_of_memory
- suspended
- timeout
- unknown
title: JobStatus
JobType:
type: string
Expand Down
13 changes: 13 additions & 0 deletions compose_api/db/services/hpc_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ async def insert_hpcrun(self, slurmjobid: int, job_type: JobType, ref_id: int, c
async def get_hpcrun_by_ref(self, ref_id: int, job_type: JobType) -> HpcRun | None:
pass

@abstractmethod
async def get_hpcruns_by_refs(self, ref_ids: list[int], job_type: JobType) -> list[HpcRun]:
pass

@abstractmethod
async def get_hpcrun_by_slurmjobid(self, slurmjobid: int) -> HpcRun | None:
pass
Expand Down Expand Up @@ -135,6 +139,15 @@ async def insert_hpcrun(self, slurmjobid: int, job_type: JobType, ref_id: int, c
await session.flush()
return orm_hpc_run.to_hpc_run()

@override
async def get_hpcruns_by_refs(self, ref_ids: list[int], job_type: JobType) -> list[HpcRun]:
async with self.async_session_maker() as session, session.begin():
run_type = self._get_job_type_ref(job_type)
stmt = select(ORMHpcRun).where(run_type.in_(ref_ids))
result: Result[tuple[ORMHpcRun]] = await session.execute(stmt)
orm_hpcruns = result.scalars().all()
return [orm_hpcrun.to_hpc_run() for orm_hpcrun in orm_hpcruns]

@override
async def get_hpcrun_by_slurmjobid(self, slurmjobid: int) -> HpcRun | None:
async with self.async_session_maker() as session, session.begin():
Expand Down
Loading
Loading