Skip to content
This repository was archived by the owner on Aug 2, 2023. It is now read-only.
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/477.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a background task that updates the progress reporter from `KernelPullProgressEvent` events until image pulling is done.
65 changes: 64 additions & 1 deletion src/ai/backend/manager/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import trafaret as t
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
from ..background import ProgressReporter

from ai.backend.common import redis, validators as tx
from ai.backend.common.docker import ImageRef
Expand All @@ -72,6 +73,7 @@
SessionStartedEvent,
SessionSuccessEvent,
SessionTerminatedEvent,
KernelPullProgressEvent,
)
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.utils import cancel_tasks, str_to_timedelta
Expand Down Expand Up @@ -513,7 +515,62 @@ async def _create(request: web.Request, params: Any) -> web.Response:
resp['servicePorts'] = []
resp['created'] = True

if not params['enqueue_only']:
async def monitor_kernel_preparation(reporter: ProgressReporter) -> None:
Copy link
Member

@achimnol achimnol Oct 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we extract this as a separate function for better testability? (Let's avoid too many nested functions!)

progress = [0, 0]

async def _get_status(kernel_id):
    """
    Look up the ``id`` and ``status`` columns of the kernel row whose
    ``id`` equals *kernel_id*, using a read-only DB connection.

    Returns the first row of the query result; the caller treats ``None``
    as "no such row yet".
    """
    stmt = (
        sa.select([
            kernels.c.id,
            kernels.c.status,
        ])
        .select_from(kernels)
        .where(kernels.c.id == kernel_id)
    )
    async with root_ctx.db.begin_readonly() as db_conn:
        rows = await db_conn.execute(stmt)
        row = rows.first()

    return row

async def _update_progress(
    app: web.Application,
    source: AgentId,
    event: KernelPullProgressEvent,
) -> None:
    """
    Event callback that records the pull progress carried by *event*.

    The integer-converted values are stored in the enclosing ``progress``
    list: index 0 holds the current progress, index 1 the total.
    *app* and *source* are accepted but not used here.
    """
    # Keep the two writes sequential so both slots are refreshed per event.
    current = int(event.current_progress)
    progress[0] = current
    total = int(event.total_progress)
    progress[1] = total

progress_handler = root_ctx.event_dispatcher.subscribe(
KernelPullProgressEvent,
request.app,
_update_progress
)
kernel_id = resp['sessionId']
try:
while True:
result = await _get_status(kernel_id)
if result is None:
continue
if result['status'] == KernelStatus.PREPARING:
await reporter.update(0)
if result['status'] == KernelStatus.RUNNING:
break
reporter.current_progress = progress[0]
reporter.total_progress = progress[1]
await reporter.update(0)
await asyncio.sleep(0.5)
finally:
root_ctx.event_dispatcher.unsubscribe(progress_handler)

if params['enqueue_only']:
task_id = await root_ctx.background_task_manager.start(
monitor_kernel_preparation,
name='monitor-kernel-preparation',
)
resp['background_task'] = str(task_id)
return web.json_response(resp, status=201)
else:
app_ctx.pending_waits.add(current_task)
max_wait = params['max_wait_seconds']
try:
Expand All @@ -523,7 +580,13 @@ async def _create(request: web.Request, params: Any) -> web.Response:
else:
await start_event.wait()
except asyncio.TimeoutError:
task_id = await root_ctx.background_task_manager.start(
monitor_kernel_preparation,
name='monitor-kernel-preparation',
)
resp['background_task'] = str(task_id)
resp['status'] = 'TIMEOUT'
return web.json_response(resp, status=201)
else:
await asyncio.sleep(0.5)
async with root_ctx.db.begin_readonly() as conn:
Expand Down