Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 99 additions & 34 deletions src/agentcore_rl_toolkit/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,60 @@ def _format_exception(exc: BaseException) -> str:
return "".join(traceback.format_exception(exc))


class ACRRateLimiter:
    """Unified rate limiter for all ACR API calls (invoke + stop).

    Since invoke_agent_runtime and stop_runtime_session share a single
    per-runtime-ARN rate limit on the ACR service, all calls must go
    through the same limiter to avoid throttling.

    Provides both sync and async interfaces. The async interface uses
    an asyncio.Lock to serialize timing checks without blocking the
    event loop during the sleep interval.

    Timing uses ``time.monotonic()`` rather than ``time.time()``: the
    wall clock can jump (NTP sync, DST) which would cause spurious long
    sleeps or bursts that violate the limit, while the monotonic clock
    only ever moves forward at a steady rate.
    """

    def __init__(self, tps_limit: int = 25):
        """Create a limiter allowing at most ``tps_limit`` calls per second.

        Raises:
            ValueError: if ``tps_limit`` is not a positive number.
        """
        if tps_limit <= 0:
            raise ValueError(f"tps_limit must be positive, got {tps_limit}")
        self.tps_limit = tps_limit
        self._min_interval = 1.0 / tps_limit
        # Monotonic timestamp of the last permitted call; 0.0 means "no
        # call yet" (monotonic values are always >= 0, so the first call
        # never sleeps).
        self._last_call_time = 0.0
        # Async lock and its event loop (lazily created, because an
        # asyncio.Lock must be bound to the loop that uses it)
        self._async_lock = None
        self._async_lock_loop = None

    def _get_async_lock(self) -> asyncio.Lock:
        """Lazily create and return the async rate-limiting lock.

        Detects when the running event loop has changed (e.g., due to a new
        ``asyncio.run()`` call) and recreates the lock for the current loop.
        """
        loop = asyncio.get_running_loop()
        if self._async_lock is None or self._async_lock_loop is not loop:
            self._async_lock = asyncio.Lock()
            self._async_lock_loop = loop
        return self._async_lock

    def wait_sync(self):
        """Block until the next call is allowed under the TPS limit."""
        now = time.monotonic()
        elapsed = now - self._last_call_time
        if elapsed < self._min_interval:
            time.sleep(self._min_interval - elapsed)
        self._last_call_time = time.monotonic()

    async def wait_async(self):
        """Async wait until the next call is allowed under the TPS limit.

        Uses a lock to serialize timing checks. The lock is held only during
        the timing check and sleep, so concurrent callers queue up properly.
        """
        async with self._get_async_lock():
            now = time.monotonic()
            elapsed = now - self._last_call_time
            if elapsed < self._min_interval:
                await asyncio.sleep(self._min_interval - elapsed)
            self._last_call_time = time.monotonic()


@dataclass
class BatchItem:
"""Result wrapper for batch execution, distinguishing success from error.
Expand Down Expand Up @@ -55,6 +109,7 @@ def __init__(
input_id: str = None,
agentcore_client=None,
agent_runtime_arn: str = None,
rate_limiter: ACRRateLimiter = None,
):
self.s3_client = s3_client
self.s3_bucket = s3_bucket
Expand All @@ -63,6 +118,7 @@ def __init__(
self.input_id = input_id
self.agentcore_client = agentcore_client
self.agent_runtime_arn = agent_runtime_arn
self._rate_limiter = rate_limiter
self._result = None
self._done = False
self._cancelled = False
Expand Down Expand Up @@ -111,7 +167,7 @@ def elapsed(self) -> float:
return time.time() - self._start_time

def cancel(self) -> bool:
"""Cancel the underlying ACR session (best-effort).
"""Cancel the underlying ACR session (best-effort, rate-limited).

Sets cancelled to True once called, even if the API call fails or
the client/session_id are unavailable. Use ``cancelled`` to check
Expand All @@ -124,6 +180,8 @@ def cancel(self) -> bool:
if not self.agentcore_client or not self.session_id:
return False
try:
if self._rate_limiter:
self._rate_limiter.wait_sync()
self.agentcore_client.stop_runtime_session(
agentRuntimeArn=self.agent_runtime_arn,
runtimeSessionId=self.session_id,
Expand All @@ -134,6 +192,31 @@ def cancel(self) -> bool:
logger.warning(f"Failed to stop session {self.session_id[:8]}...: {e}")
return False

async def cancel_async(self) -> bool:
    """Async version of ``cancel()`` with rate limiting.

    Uses the shared ACR rate limiter to avoid bursting stop calls
    that compete with invoke calls for the same service-side rate limit.

    Returns:
        True when the stop API call succeeded; False when cancellation
        was already attempted, when the client or session id is missing,
        or when the stop call failed.
    """
    already_requested = self._cancelled
    # Mark cancellation attempted unconditionally, even if we bail out below.
    self._cancelled = True
    if already_requested:
        return False
    if not (self.agentcore_client and self.session_id):
        return False
    try:
        if self._rate_limiter:
            await self._rate_limiter.wait_async()
        # boto3 clients are blocking; run the stop call off the event loop.
        await asyncio.to_thread(
            self.agentcore_client.stop_runtime_session,
            agentRuntimeArn=self.agent_runtime_arn,
            runtimeSessionId=self.session_id,
        )
        logger.info(f"Stopped session {self.session_id[:8]}...")
        return True
    except Exception as e:
        logger.warning(f"Failed to stop session {self.session_id[:8]}...: {e}")
        return False

@property
def cancelled(self) -> bool:
"""True if cancellation was attempted (may not have succeeded)."""
Expand Down Expand Up @@ -176,7 +259,7 @@ async def _async_poll(self) -> dict:
while True:
if await self.done_async():
self._result = await asyncio.to_thread(self._fetch_result)
await asyncio.to_thread(self.cancel)
await self.cancel_async()
return self._result
await asyncio.sleep(self._poll_interval)

Expand Down Expand Up @@ -204,7 +287,7 @@ async def result_async(self, timeout: float = None) -> dict:
return await asyncio.wait_for(self._async_poll(), timeout=timeout)
return await self
except (TimeoutError, asyncio.TimeoutError):
await asyncio.to_thread(self.cancel)
await self.cancel_async()
raise

def result(self, timeout: float = None) -> dict:
Expand Down Expand Up @@ -327,9 +410,10 @@ def __init__(
self.agentcore_client = boto3.client("bedrock-agentcore", region_name=self.region, config=config)
self.s3_client = boto3.client("s3", region_name=self.region, config=config)

# Rate limiting state
self._last_invoke_time = 0.0
self._min_invoke_interval = 1.0 / tps_limit
# Unified rate limiter for all ACR API calls (invoke + stop).
# invoke_agent_runtime and stop_runtime_session share a single
# per-runtime-ARN rate limit on the ACR service.
self._rate_limiter = ACRRateLimiter(tps_limit)

def _parse_response(self, response: dict) -> dict:
"""Parse ACR invocation response."""
Expand All @@ -355,12 +439,7 @@ def _build_full_payload(self, payload: dict, input_id: str, **overrides) -> dict

def _rate_limited_invoke(self, payload: dict, session_id: str, input_id: str, **overrides) -> RolloutFuture:
"""Invoke with TPS rate limiting."""
# Enforce TPS limit
now = time.time()
elapsed = now - self._last_invoke_time
if elapsed < self._min_invoke_interval:
time.sleep(self._min_invoke_interval - elapsed)
self._last_invoke_time = time.time()
self._rate_limiter.wait_sync()

full_payload = self._build_full_payload(payload, input_id, **overrides)

Expand All @@ -383,34 +462,19 @@ def _rate_limited_invoke(self, payload: dict, session_id: str, input_id: str, **
input_id=input_id,
agentcore_client=self.agentcore_client,
agent_runtime_arn=self.agent_runtime_arn,
rate_limiter=self._rate_limiter,
)

def _get_async_lock(self) -> asyncio.Lock:
"""Lazily create and return the async rate-limiting lock.

Detects when the running event loop has changed (e.g., due to a new
``asyncio.run()`` call) and recreates the lock for the current loop.
"""
loop = asyncio.get_running_loop()
if not hasattr(self, "_async_lock") or self._async_lock_loop is not loop:
self._async_lock = asyncio.Lock()
self._async_lock_loop = loop
return self._async_lock

async def _async_rate_limited_invoke(
self, payload: dict, session_id: str, input_id: str, **overrides
) -> RolloutFuture:
"""Invoke with async TPS rate limiting.

The lock is held only during the timing check and released before the
HTTP call, so cold starts on one request don't block submission of others.
The rate limiter lock is held only during the timing check and released
before the HTTP call, so cold starts on one request don't block
submission of others.
"""
async with self._get_async_lock():
now = time.time()
elapsed = now - self._last_invoke_time
if elapsed < self._min_invoke_interval:
await asyncio.sleep(self._min_invoke_interval - elapsed)
self._last_invoke_time = time.time()
await self._rate_limiter.wait_async()

full_payload = self._build_full_payload(payload, input_id, **overrides)

Expand All @@ -435,6 +499,7 @@ def _invoke_and_parse():
input_id=input_id,
agentcore_client=self.agentcore_client,
agent_runtime_arn=self.agent_runtime_arn,
rate_limiter=self._rate_limiter,
)

def invoke(self, payload: dict, session_id: str = None, input_id: str = None, **rollout_overrides) -> RolloutFuture:
Expand Down Expand Up @@ -748,7 +813,7 @@ async def _run(self):
for key, (idx, future) in active_futures.items():
if future.elapsed() > self.timeout:
completed_keys.append(key)
await asyncio.to_thread(future.cancel)
await future.cancel_async()
yield BatchItem(
success=False,
error=f"Timeout after {self.timeout}s",
Expand All @@ -763,7 +828,7 @@ async def _run(self):
except Exception as e:
yield BatchItem(success=False, error=_format_exception(e), index=idx, elapsed=future.elapsed())
finally:
await asyncio.to_thread(future.cancel)
await future.cancel_async()

for key in completed_keys:
del active_futures[key]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def test_get_async_lock_recreated_for_new_event_loop(self):
locks = []

async def capture_lock():
locks.append(client._get_async_lock())
locks.append(client._rate_limiter._get_async_lock())

asyncio.run(capture_lock())
asyncio.run(capture_lock()) # new event loop
Expand Down
Loading