From 43db2c7a70ff06c9c1d59132eae699ae5808cbf7 Mon Sep 17 00:00:00 2001
From: Youzhi Luo
Date: Mon, 6 Apr 2026 22:29:08 +0000
Subject: [PATCH 1/2] fix(client): use a unified rate limiter for all ACR API calls

---
 src/agentcore_rl_toolkit/client.py | 133 +++++++++++++++++++++--------
 1 file changed, 99 insertions(+), 34 deletions(-)

diff --git a/src/agentcore_rl_toolkit/client.py b/src/agentcore_rl_toolkit/client.py
index d20df6d..cbd3e20 100644
--- a/src/agentcore_rl_toolkit/client.py
+++ b/src/agentcore_rl_toolkit/client.py
@@ -21,6 +21,60 @@ def _format_exception(exc: BaseException) -> str:
     return "".join(traceback.format_exception(exc))
 
 
+class ACRRateLimiter:
+    """Unified rate limiter for all ACR API calls (invoke + stop).
+
+    Since invoke_agent_runtime and stop_runtime_session share a single
+    per-runtime-ARN rate limit on the ACR service, all calls must go
+    through the same limiter to avoid throttling.
+
+    Provides both sync and async interfaces. The async interface uses
+    an asyncio.Lock to serialize timing checks without blocking the
+    event loop during the sleep interval.
+    """
+
+    def __init__(self, tps_limit: int = 25):
+        self.tps_limit = tps_limit
+        self._min_interval = 1.0 / tps_limit
+        self._last_call_time = 0.0
+        # Async lock and its event loop (lazily created)
+        self._async_lock = None
+        self._async_lock_loop = None
+
+    def _get_async_lock(self) -> asyncio.Lock:
+        """Lazily create and return the async rate-limiting lock.
+
+        Detects when the running event loop has changed (e.g., due to a new
+        ``asyncio.run()`` call) and recreates the lock for the current loop.
+        """
+        loop = asyncio.get_running_loop()
+        if self._async_lock is None or self._async_lock_loop is not loop:
+            self._async_lock = asyncio.Lock()
+            self._async_lock_loop = loop
+        return self._async_lock
+
+    def wait_sync(self):
+        """Block until the next call is allowed under the TPS limit."""
+        now = time.time()
+        elapsed = now - self._last_call_time
+        if elapsed < self._min_interval:
+            time.sleep(self._min_interval - elapsed)
+        self._last_call_time = time.time()
+
+    async def wait_async(self):
+        """Async wait until the next call is allowed under the TPS limit.
+
+        Uses a lock to serialize timing checks. The lock is held only during
+        the timing check and sleep, so concurrent callers queue up properly.
+        """
+        async with self._get_async_lock():
+            now = time.time()
+            elapsed = now - self._last_call_time
+            if elapsed < self._min_interval:
+                await asyncio.sleep(self._min_interval - elapsed)
+            self._last_call_time = time.time()
+
+
 @dataclass
 class BatchItem:
     """Result wrapper for batch execution, distinguishing success from error.
@@ -55,6 +109,7 @@ def __init__(
         input_id: str = None,
         agentcore_client=None,
         agent_runtime_arn: str = None,
+        rate_limiter: ACRRateLimiter = None,
     ):
         self.s3_client = s3_client
         self.s3_bucket = s3_bucket
@@ -63,6 +118,7 @@ def __init__(
         self.input_id = input_id
         self.agentcore_client = agentcore_client
         self.agent_runtime_arn = agent_runtime_arn
+        self._rate_limiter = rate_limiter
         self._result = None
         self._done = False
         self._cancelled = False
@@ -111,7 +167,7 @@ def elapsed(self) -> float:
         return time.time() - self._start_time
 
     def cancel(self) -> bool:
-        """Cancel the underlying ACR session (best-effort).
+        """Cancel the underlying ACR session (best-effort, rate-limited).
 
         Sets cancelled to True once called, even if the API call fails or the
         client/session_id are unavailable. Use ``cancelled`` to check
@@ -124,6 +180,8 @@ def cancel(self) -> bool:
         if not self.agentcore_client or not self.session_id:
             return False
         try:
+            if self._rate_limiter:
+                self._rate_limiter.wait_sync()
             self.agentcore_client.stop_runtime_session(
                 agentRuntimeArn=self.agent_runtime_arn,
                 runtimeSessionId=self.session_id,
@@ -134,6 +192,31 @@ def cancel(self) -> bool:
             logger.warning(f"Failed to stop session {self.session_id[:8]}...: {e}")
             return False
 
+    async def cancel_async(self) -> bool:
+        """Async version of ``cancel()`` with rate limiting.
+
+        Uses the shared ACR rate limiter to avoid bursting stop calls
+        that compete with invoke calls for the same service-side rate limit.
+        """
+        if self._cancelled:
+            return False
+        self._cancelled = True
+        if not self.agentcore_client or not self.session_id:
+            return False
+        try:
+            if self._rate_limiter:
+                await self._rate_limiter.wait_async()
+            await asyncio.to_thread(
+                self.agentcore_client.stop_runtime_session,
+                agentRuntimeArn=self.agent_runtime_arn,
+                runtimeSessionId=self.session_id,
+            )
+            logger.info(f"Stopped session {self.session_id[:8]}...")
+            return True
+        except Exception as e:
+            logger.warning(f"Failed to stop session {self.session_id[:8]}...: {e}")
+            return False
+
     @property
     def cancelled(self) -> bool:
         """True if cancellation was attempted (may not have succeeded)."""
@@ -176,7 +259,7 @@ async def _async_poll(self) -> dict:
         while True:
             if await self.done_async():
                 self._result = await asyncio.to_thread(self._fetch_result)
-                await asyncio.to_thread(self.cancel)
+                await self.cancel_async()
                 return self._result
             await asyncio.sleep(self._poll_interval)
 
@@ -204,7 +287,7 @@ async def result_async(self, timeout: float = None) -> dict:
                 return await asyncio.wait_for(self._async_poll(), timeout=timeout)
             return await self
         except (TimeoutError, asyncio.TimeoutError):
-            await asyncio.to_thread(self.cancel)
+            await self.cancel_async()
             raise
 
     def result(self, timeout: float = None) -> dict:
@@ -327,9 +410,10 @@ def __init__(
         self.agentcore_client = boto3.client("bedrock-agentcore", region_name=self.region, config=config)
         self.s3_client = boto3.client("s3", region_name=self.region, config=config)
 
-        # Rate limiting state
-        self._last_invoke_time = 0.0
-        self._min_invoke_interval = 1.0 / tps_limit
+        # Unified rate limiter for all ACR API calls (invoke + stop).
+        # invoke_agent_runtime and stop_runtime_session share a single
+        # per-runtime-ARN rate limit on the ACR service.
+        self._rate_limiter = ACRRateLimiter(tps_limit)
 
     def _parse_response(self, response: dict) -> dict:
         """Parse ACR invocation response."""
@@ -355,12 +439,7 @@ def _build_full_payload(self, payload: dict, input_id: str, **overrides) -> dict
 
     def _rate_limited_invoke(self, payload: dict, session_id: str, input_id: str, **overrides) -> RolloutFuture:
         """Invoke with TPS rate limiting."""
-        # Enforce TPS limit
-        now = time.time()
-        elapsed = now - self._last_invoke_time
-        if elapsed < self._min_invoke_interval:
-            time.sleep(self._min_invoke_interval - elapsed)
-        self._last_invoke_time = time.time()
+        self._rate_limiter.wait_sync()
 
         full_payload = self._build_full_payload(payload, input_id, **overrides)
 
@@ -383,34 +462,19 @@ def _rate_limited_invoke(self, payload: dict, session_id: str, input_id: str, **
             input_id=input_id,
             agentcore_client=self.agentcore_client,
             agent_runtime_arn=self.agent_runtime_arn,
+            rate_limiter=self._rate_limiter,
         )
 
-    def _get_async_lock(self) -> asyncio.Lock:
-        """Lazily create and return the async rate-limiting lock.
-
-        Detects when the running event loop has changed (e.g., due to a new
-        ``asyncio.run()`` call) and recreates the lock for the current loop.
-        """
-        loop = asyncio.get_running_loop()
-        if not hasattr(self, "_async_lock") or self._async_lock_loop is not loop:
-            self._async_lock = asyncio.Lock()
-            self._async_lock_loop = loop
-        return self._async_lock
-
     async def _async_rate_limited_invoke(
         self, payload: dict, session_id: str, input_id: str, **overrides
     ) -> RolloutFuture:
         """Invoke with async TPS rate limiting.
 
-        The lock is held only during the timing check and released before the
-        HTTP call, so cold starts on one request don't block submission of others.
+        The rate limiter lock is held only during the timing check and released
+        before the HTTP call, so cold starts on one request don't block
+        submission of others.
         """
-        async with self._get_async_lock():
-            now = time.time()
-            elapsed = now - self._last_invoke_time
-            if elapsed < self._min_invoke_interval:
-                await asyncio.sleep(self._min_invoke_interval - elapsed)
-            self._last_invoke_time = time.time()
+        await self._rate_limiter.wait_async()
 
         full_payload = self._build_full_payload(payload, input_id, **overrides)
 
@@ -435,6 +499,7 @@ def _invoke_and_parse():
             input_id=input_id,
             agentcore_client=self.agentcore_client,
             agent_runtime_arn=self.agent_runtime_arn,
+            rate_limiter=self._rate_limiter,
         )
 
     def invoke(self, payload: dict, session_id: str = None, input_id: str = None, **rollout_overrides) -> RolloutFuture:
@@ -748,7 +813,7 @@ async def _run(self):
             for key, (idx, future) in active_futures.items():
                 if future.elapsed() > self.timeout:
                     completed_keys.append(key)
-                    await asyncio.to_thread(future.cancel)
+                    await future.cancel_async()
                     yield BatchItem(
                         success=False,
                         error=f"Timeout after {self.timeout}s",
@@ -763,7 +828,7 @@ async def _run(self):
                     except Exception as e:
                         yield BatchItem(success=False, error=_format_exception(e), index=idx, elapsed=future.elapsed())
                     finally:
-                        await asyncio.to_thread(future.cancel)
+                        await future.cancel_async()
 
             for key in completed_keys:
                 del active_futures[key]

From 4618e85a2d0a433efc9cb94bc88ecd8020a1c683 Mon Sep 17 00:00:00 2001
From: Youzhi Luo
Date: Mon, 6 Apr 2026 22:35:57 +0000
Subject: [PATCH 2/2] fix(client): fix client unit test error

---
 tests/test_client.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_client.py b/tests/test_client.py
index 2ba4751..c738b07 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -481,7 +481,7 @@ def test_get_async_lock_recreated_for_new_event_loop(self):
         locks = []
 
         async def capture_lock():
-            locks.append(client._get_async_lock())
+            locks.append(client._rate_limiter._get_async_lock())
 
         asyncio.run(capture_lock())
         asyncio.run(capture_lock())  # new event loop