Skip to content
This repository was archived by the owner on Aug 2, 2023. It is now read-only.
1 change: 1 addition & 0 deletions changes/442.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add API to get available resources
207 changes: 126 additions & 81 deletions src/ai/backend/manager/api/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,11 @@
from ai.backend.common.types import DefaultForUnspecified, ResourceSlot

from ..models import (
agents, resource_presets,
resource_presets,
domains, groups, kernels, users,
AgentStatus,
association_groups_users,
query_allowed_sgroups,
AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES,
get_groups_info_by_row,
check_scaling_group_resource,
check_group_resource,
RESOURCE_USAGE_KERNEL_STATUSES, LIVE_STATUS,
)
from .auth import auth_required, superadmin_required
Expand All @@ -64,7 +63,7 @@ async def list_presets(request: web.Request) -> web.Response:
"""
Returns the list of all resource presets.
"""
log.info('LIST_PRESETS (ak:{})', request['keypair']['access_key'])
log.info('RESOURCE.LIST_PRESETS (ak:{})', request['keypair']['access_key'])
root_ctx: RootContext = request.app['_root.context']
await root_ctx.shared_config.get_resource_slots()
async with root_ctx.db.begin_readonly() as conn:
Expand Down Expand Up @@ -120,7 +119,7 @@ async def check_presets(request: web.Request, params: Any) -> web.Response:
'scaling_groups': None,
'presets': [],
}
log.info('CHECK_PRESETS (ak:{}, g:{}, sg:{})',
log.info('RESOURCE.CHECK_PRESETS (ak:{}, g:{}, sg:{})',
request['keypair']['access_key'], params['group'], params['scaling_group'])

async with root_ctx.db.begin_readonly() as conn:
Expand All @@ -130,21 +129,7 @@ async def check_presets(request: web.Request, params: Any) -> web.Response:
keypair_remaining = keypair_limits - keypair_occupied

# Check group resource limit and get group_id.
j = sa.join(
groups, association_groups_users,
association_groups_users.c.group_id == groups.c.id,
)
query = (
sa.select([groups.c.id, groups.c.total_resource_slots])
.select_from(j)
.where(
(association_groups_users.c.user_id == request['user']['uuid']) &
(groups.c.name == params['group']) &
(domains.c.name == domain_name)
)
)
result = await conn.execute(query)
row = result.first()
row = await get_groups_info_by_row(conn, request, params, domain_name)
group_id = row['id']
group_resource_slots = row['total_resource_slots']
if group_id is None:
Expand Down Expand Up @@ -178,50 +163,12 @@ async def check_presets(request: web.Request, params: Any) -> web.Response:
domain_remaining[slot],
)

# Prepare per scaling group resource.
sgroups = await query_allowed_sgroups(conn, domain_name, group_id, access_key)
sgroup_names = [sg.name for sg in sgroups]
if params['scaling_group'] is not None:
if params['scaling_group'] not in sgroup_names:
raise InvalidAPIParameters('Unknown scaling group')
sgroup_names = [params['scaling_group']]
per_sgroup = {
sgname: {
'using': ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()}),
'remaining': ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()}),
} for sgname in sgroup_names
}

# Per scaling group resource using from resource occupying kernels.
query = (
sa.select([kernels.c.occupied_slots, kernels.c.scaling_group])
.select_from(kernels)
.where(
(kernels.c.user_uuid == request['user']['uuid']) &
(kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)) &
(kernels.c.scaling_group.in_(sgroup_names))
)
)
async for row in (await conn.stream(query)):
per_sgroup[row['scaling_group']]['using'] += row['occupied_slots']

# Per scaling group resource remaining from agents stats.
sgroup_remaining = ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()})
query = (
sa.select([agents.c.available_slots, agents.c.occupied_slots, agents.c.scaling_group])
.select_from(agents)
.where(
(agents.c.status == AgentStatus.ALIVE) &
(agents.c.scaling_group.in_(sgroup_names))
)
# Take resources per sgroup.
per_sgroup, sgroup_remaining, agent_slots = await check_scaling_group_resource(
conn, request, params,
domain_name, group_id,
access_key, known_slot_types
)
agent_slots = []
async for row in (await conn.stream(query)):
remaining = row['available_slots'] - row['occupied_slots']
remaining += ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()})
sgroup_remaining += remaining
agent_slots.append(remaining)
per_sgroup[row['scaling_group']]['remaining'] += remaining

# Take maximum allocatable resources per sgroup.
for sgname, sgfields in per_sgroup.items():
Expand Down Expand Up @@ -255,13 +202,9 @@ async def check_presets(request: web.Request, params: Any) -> web.Response:
})

# Return group resource status as NaN if not allowed.
group_resource_visibility = \
await root_ctx.shared_config.get_raw('config/api/resources/group_resource_visibility')
group_resource_visibility = t.ToBool().check(group_resource_visibility)
if not group_resource_visibility:
group_limits = ResourceSlot({k: Decimal('NaN') for k in known_slot_types.keys()})
group_occupied = ResourceSlot({k: Decimal('NaN') for k in known_slot_types.keys()})
group_remaining = ResourceSlot({k: Decimal('NaN') for k in known_slot_types.keys()})
group_limits, group_occupied, group_remaining = await check_group_resource(
root_ctx, t, known_slot_types
)

resp['keypair_limits'] = keypair_limits.to_json()
resp['keypair_using'] = keypair_occupied.to_json()
Expand All @@ -274,6 +217,106 @@ async def check_presets(request: web.Request, params: Any) -> web.Response:
return web.json_response(resp, status=200)


@atomic
@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(
t.Dict({
tx.AliasedKey(['group', 'name'], default='default'): t.String,
}))
async def check_group(request: web.Request, params: Any) -> web.Response:
"""
Returns the list of specific group's available resources.

:param scaling_group: If not None, get available resources of specific scaling group.
:param group: Get available resources of specific group (project) and enumerate them.
"""
root_ctx: RootContext = request.app['_root.context']
try:
access_key = request['keypair']['access_key']
domain_name = request['user']['domain_name']
# TODO: uncomment when we implement scaling group.
# scaling_group = request.query.get('scaling_group')
# assert scaling_group is not None, 'scaling_group parameter is missing.'
except (json.decoder.JSONDecodeError, AssertionError) as e:
raise InvalidAPIParameters(extra_msg=str(e.args[0]))
known_slot_types = await root_ctx.shared_config.get_resource_slots()
resp: MutableMapping[str, Any] = {
'scaling_group_remaining': None,
'scaling_groups': None,
}
log.info("RESOURCE.CHECK_GROUP(ak:{}, g:{})",
request['keypair']['access_key'], params['group'])

async with root_ctx.db.begin_readonly() as conn:
# Check group resource limit and get group_id.
row = await get_groups_info_by_row(conn, request, params, domain_name)
group_id = row['id']
group_resource_visibility = \
await root_ctx.shared_config.get_raw('config/api/resources/group_resource_visibility')
group_resource_visibility = t.ToBool().check(group_resource_visibility)
group_limits, group_occupied, group_remaining = await check_group_resource(
conn, row, known_slot_types,
)

# 8<---- TODO: work-in-progress from here

sgroups = await query_allowed_sgroups(conn, domain_name, group_id, access_key)
for sgroup in sgroups:
sgroup_capacity, sgroup_remaining = await check_scaling_group_resource(
conn, sgroup['name'], known_slot_types,
)
per_sgroup = {
sgroup['name']: {
'using': ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()}),
'remaining': ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()}),
} for sgroup in sgroups
}

# Take maximum allocatable resources per sgroup.
for sgname, sgfields in per_sgroup.items():
for rtype, slots in sgfields.items():
per_sgroup[sgname][rtype] = slots.to_json() # type: ignore # it's serialization

# 8<---- TODO: work-in-progress until here

resp['limits'] = group_limits.to_json()
resp['occupied'] = group_occupied.to_json()
resp['remaining'] = group_remaining.to_json()
resp['scaling_groups'] = per_sgroup
return web.json_response(resp, status=200)


@atomic
@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(
t.Dict({
tx.AliasedKey(['scaling_group', 'name'], default='default'): t.String,
}))
async def check_scaling_group(request: web.Request, params: Any) -> web.Response:
"""
Returns the list of specific group's available resources.

:param scaling_group: If not None, get available resources of specific scaling group.
:param group: Get available resources of specific group (project) and enumerate them.
"""
root_ctx: RootContext = request.app['_root.context']
known_slot_types = await root_ctx.shared_config.get_resource_slots()
log.info("RESOURCE.CHECK_SCALING_GROUP (ak:{}, sg:{})",
request['keypair']['access_key'], params['scaling_group'])
async with root_ctx.db.begin_readonly() as conn:
# Check group resource limit and get group_id.
sgroup_capacity, sgroup_remaining = await check_scaling_group_resource(
conn, params['scaling_group'], known_slot_types
)
# TODO: include capacity in the response (when queried by super-admin)?
resp = {
"remaining": sgroup_remaining.to_json(),
}
return web.json_response(resp, status=200)


@server_status_required(READ_ALLOWED)
@superadmin_required
@atomic
Expand All @@ -284,7 +327,7 @@ async def recalculate_usage(request: web.Request) -> web.Response:
Those two values are sometimes out of sync. In that case, calling this API
re-calculates the values for running containers and updates them in DB.
"""
log.info('RECALCULATE_USAGE ()')
log.info('RESOURCE.RECALCULATE_USAGE ()')
root_ctx: RootContext = request.app['_root.context']
await root_ctx.registry.recalc_resource_usage()
return web.json_response({}, status=200)
Expand Down Expand Up @@ -436,7 +479,7 @@ async def usage_per_month(request: web.Request, params: Any) -> web.Response:
:param group_ids: If not None, query containers only in those groups.
:param month: The year-month to query usage statistics. ex) "202006" to query for Jun 2020
"""
log.info('USAGE_PER_MONTH (g:[{}], month:{})',
log.info('RESOURCE.USAGE_PER_MONTH (g:[{}], month:{})',
','.join(params['group_ids']), params['month'])
root_ctx: RootContext = request.app['_root.context']
local_tz = root_ctx.shared_config['system']['timezone']
Expand Down Expand Up @@ -483,7 +526,7 @@ async def usage_per_period(request: web.Request, params: Any) -> web.Response:
raise InvalidAPIParameters(extra_msg='Invalid date values')
if end_date <= start_date:
raise InvalidAPIParameters(extra_msg='end_date must be later than start_date.')
log.info('USAGE_PER_MONTH (g:{}, start_date:{}, end_date:{})',
log.info('RESOURCE.USAGE_PER_MONTH (g:{}, start_date:{}, end_date:{})',
group_id, start_date, end_date)
group_ids = [group_id] if group_id is not None else None
resp = await get_container_stats_for_period(request, start_date, end_date, group_ids=group_ids)
Expand Down Expand Up @@ -608,7 +651,7 @@ async def user_month_stats(request: web.Request) -> web.Response:
"""
access_key = request['keypair']['access_key']
user_uuid = request['user']['uuid']
log.info('USER_LAST_MONTH_STATS (ak:{}, u:{})', access_key, user_uuid)
log.info('RESOURCE.USER_LAST_MONTH_STATS (ak:{}, u:{})', access_key, user_uuid)
stats = await get_time_binned_monthly_stats(request, user_uuid=user_uuid)
return web.json_response(stats, status=200)

Expand All @@ -620,7 +663,7 @@ async def admin_month_stats(request: web.Request) -> web.Response:
Return time-binned (15 min) stats for all terminated sessions
over last 30 days.
"""
log.info('ADMIN_LAST_MONTH_STATS ()')
log.info('RESOURCE.ADMIN_LAST_MONTH_STATS ()')
stats = await get_time_binned_monthly_stats(request, user_uuid=None)
return web.json_response(stats, status=200)

Expand Down Expand Up @@ -657,7 +700,7 @@ async def get_watcher_info(request: web.Request, agent_id: str) -> dict:
tx.AliasedKey(['agent_id', 'agent']): t.String,
}))
async def get_watcher_status(request: web.Request, params: Any) -> web.Response:
log.info('GET_WATCHER_STATUS ()')
log.info('RESOURCE.WATCHER.GET_STATUS (ag:{})', params['agent_id'])
watcher_info = await get_watcher_info(request, params['agent_id'])
connector = aiohttp.TCPConnector()
async with aiohttp.ClientSession(connector=connector) as sess:
Expand All @@ -679,7 +722,7 @@ async def get_watcher_status(request: web.Request, params: Any) -> web.Response:
tx.AliasedKey(['agent_id', 'agent']): t.String,
}))
async def watcher_agent_start(request: web.Request, params: Any) -> web.Response:
log.info('WATCHER_AGENT_START ()')
log.info('RESOURCE.WATCHER.AGENT.START (ag:{})', params['agent_id'])
watcher_info = await get_watcher_info(request, params['agent_id'])
connector = aiohttp.TCPConnector()
async with aiohttp.ClientSession(connector=connector) as sess:
Expand All @@ -702,7 +745,7 @@ async def watcher_agent_start(request: web.Request, params: Any) -> web.Response
tx.AliasedKey(['agent_id', 'agent']): t.String,
}))
async def watcher_agent_stop(request: web.Request, params: Any) -> web.Response:
log.info('WATCHER_AGENT_STOP ()')
log.info('RESOURCE.WATCHER.AGENT.STOP (ag:{})', params['agent_id'])
watcher_info = await get_watcher_info(request, params['agent_id'])
connector = aiohttp.TCPConnector()
async with aiohttp.ClientSession(connector=connector) as sess:
Expand All @@ -725,7 +768,7 @@ async def watcher_agent_stop(request: web.Request, params: Any) -> web.Response:
tx.AliasedKey(['agent_id', 'agent']): t.String,
}))
async def watcher_agent_restart(request: web.Request, params: Any) -> web.Response:
log.info('WATCHER_AGENT_RESTART ()')
log.info('RESOURCE.WATCHER.AGENT.RESTART (ag:{})', params['agent_id'])
watcher_info = await get_watcher_info(request, params['agent_id'])
connector = aiohttp.TCPConnector()
async with aiohttp.ClientSession(connector=connector) as sess:
Expand All @@ -748,6 +791,8 @@ def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iter
add_route = app.router.add_route
cors.add(add_route('GET', '/presets', list_presets))
cors.add(add_route('POST', '/check-presets', check_presets))
cors.add(add_route('GET', '/group', check_group))
cors.add(add_route('GET', '/scaling-group', check_scaling_group))
cors.add(add_route('POST', '/recalculate-usage', recalculate_usage))
cors.add(add_route('GET', '/usage/month', usage_per_month))
cors.add(add_route('GET', '/usage/period', usage_per_period))
Expand Down
Loading