@@ -21,6 +21,60 @@ def _format_exception(exc: BaseException) -> str:
2121 return "" .join (traceback .format_exception (exc ))
2222
2323
24+ class ACRRateLimiter :
25+ """Unified rate limiter for all ACR API calls (invoke + stop).
26+
27+ Since invoke_agent_runtime and stop_runtime_session share a single
28+ per-runtime-ARN rate limit on the ACR service, all calls must go
29+ through the same limiter to avoid throttling.
30+
31+ Provides both sync and async interfaces. The async interface uses
32+ an asyncio.Lock to serialize timing checks without blocking the
33+ event loop during the sleep interval.
34+ """
35+
36+ def __init__ (self , tps_limit : int = 25 ):
37+ self .tps_limit = tps_limit
38+ self ._min_interval = 1.0 / tps_limit
39+ self ._last_call_time = 0.0
40+ # Async lock and its event loop (lazily created)
41+ self ._async_lock = None
42+ self ._async_lock_loop = None
43+
44+ def _get_async_lock (self ) -> asyncio .Lock :
45+ """Lazily create and return the async rate-limiting lock.
46+
47+ Detects when the running event loop has changed (e.g., due to a new
48+ ``asyncio.run()`` call) and recreates the lock for the current loop.
49+ """
50+ loop = asyncio .get_running_loop ()
51+ if self ._async_lock is None or self ._async_lock_loop is not loop :
52+ self ._async_lock = asyncio .Lock ()
53+ self ._async_lock_loop = loop
54+ return self ._async_lock
55+
56+ def wait_sync (self ):
57+ """Block until the next call is allowed under the TPS limit."""
58+ now = time .time ()
59+ elapsed = now - self ._last_call_time
60+ if elapsed < self ._min_interval :
61+ time .sleep (self ._min_interval - elapsed )
62+ self ._last_call_time = time .time ()
63+
64+ async def wait_async (self ):
65+ """Async wait until the next call is allowed under the TPS limit.
66+
67+ Uses a lock to serialize timing checks. The lock is held only during
68+ the timing check and sleep, so concurrent callers queue up properly.
69+ """
70+ async with self ._get_async_lock ():
71+ now = time .time ()
72+ elapsed = now - self ._last_call_time
73+ if elapsed < self ._min_interval :
74+ await asyncio .sleep (self ._min_interval - elapsed )
75+ self ._last_call_time = time .time ()
76+
77+
2478@dataclass
2579class BatchItem :
2680 """Result wrapper for batch execution, distinguishing success from error.
@@ -55,6 +109,7 @@ def __init__(
55109 input_id : str = None ,
56110 agentcore_client = None ,
57111 agent_runtime_arn : str = None ,
112+ rate_limiter : ACRRateLimiter = None ,
58113 ):
59114 self .s3_client = s3_client
60115 self .s3_bucket = s3_bucket
@@ -63,6 +118,7 @@ def __init__(
63118 self .input_id = input_id
64119 self .agentcore_client = agentcore_client
65120 self .agent_runtime_arn = agent_runtime_arn
121+ self ._rate_limiter = rate_limiter
66122 self ._result = None
67123 self ._done = False
68124 self ._cancelled = False
@@ -111,7 +167,7 @@ def elapsed(self) -> float:
111167 return time .time () - self ._start_time
112168
113169 def cancel (self ) -> bool :
114- """Cancel the underlying ACR session (best-effort).
170+ """Cancel the underlying ACR session (best-effort, rate-limited ).
115171
116172 Sets cancelled to True once called, even if the API call fails or
117173 the client/session_id are unavailable. Use ``cancelled`` to check
@@ -124,6 +180,8 @@ def cancel(self) -> bool:
124180 if not self .agentcore_client or not self .session_id :
125181 return False
126182 try :
183+ if self ._rate_limiter :
184+ self ._rate_limiter .wait_sync ()
127185 self .agentcore_client .stop_runtime_session (
128186 agentRuntimeArn = self .agent_runtime_arn ,
129187 runtimeSessionId = self .session_id ,
@@ -134,6 +192,31 @@ def cancel(self) -> bool:
134192 logger .warning (f"Failed to stop session { self .session_id [:8 ]} ...: { e } " )
135193 return False
136194
195+ async def cancel_async (self ) -> bool :
196+ """Async version of ``cancel()`` with rate limiting.
197+
198+ Uses the shared ACR rate limiter to avoid bursting stop calls
199+ that compete with invoke calls for the same service-side rate limit.
200+ """
201+ if self ._cancelled :
202+ return False
203+ self ._cancelled = True
204+ if not self .agentcore_client or not self .session_id :
205+ return False
206+ try :
207+ if self ._rate_limiter :
208+ await self ._rate_limiter .wait_async ()
209+ await asyncio .to_thread (
210+ self .agentcore_client .stop_runtime_session ,
211+ agentRuntimeArn = self .agent_runtime_arn ,
212+ runtimeSessionId = self .session_id ,
213+ )
214+ logger .info (f"Stopped session { self .session_id [:8 ]} ..." )
215+ return True
216+ except Exception as e :
217+ logger .warning (f"Failed to stop session { self .session_id [:8 ]} ...: { e } " )
218+ return False
219+
137220 @property
138221 def cancelled (self ) -> bool :
139222 """True if cancellation was attempted (may not have succeeded)."""
@@ -176,7 +259,7 @@ async def _async_poll(self) -> dict:
176259 while True :
177260 if await self .done_async ():
178261 self ._result = await asyncio .to_thread (self ._fetch_result )
179- await asyncio . to_thread ( self .cancel )
262+ await self .cancel_async ( )
180263 return self ._result
181264 await asyncio .sleep (self ._poll_interval )
182265
@@ -204,7 +287,7 @@ async def result_async(self, timeout: float = None) -> dict:
204287 return await asyncio .wait_for (self ._async_poll (), timeout = timeout )
205288 return await self
206289 except (TimeoutError , asyncio .TimeoutError ):
207- await asyncio . to_thread ( self .cancel )
290+ await self .cancel_async ( )
208291 raise
209292
210293 def result (self , timeout : float = None ) -> dict :
@@ -327,9 +410,10 @@ def __init__(
327410 self .agentcore_client = boto3 .client ("bedrock-agentcore" , region_name = self .region , config = config )
328411 self .s3_client = boto3 .client ("s3" , region_name = self .region , config = config )
329412
330- # Rate limiting state
331- self ._last_invoke_time = 0.0
332- self ._min_invoke_interval = 1.0 / tps_limit
413+ # Unified rate limiter for all ACR API calls (invoke + stop).
414+ # invoke_agent_runtime and stop_runtime_session share a single
415+ # per-runtime-ARN rate limit on the ACR service.
416+ self ._rate_limiter = ACRRateLimiter (tps_limit )
333417
334418 def _parse_response (self , response : dict ) -> dict :
335419 """Parse ACR invocation response."""
@@ -355,12 +439,7 @@ def _build_full_payload(self, payload: dict, input_id: str, **overrides) -> dict
355439
356440 def _rate_limited_invoke (self , payload : dict , session_id : str , input_id : str , ** overrides ) -> RolloutFuture :
357441 """Invoke with TPS rate limiting."""
358- # Enforce TPS limit
359- now = time .time ()
360- elapsed = now - self ._last_invoke_time
361- if elapsed < self ._min_invoke_interval :
362- time .sleep (self ._min_invoke_interval - elapsed )
363- self ._last_invoke_time = time .time ()
442+ self ._rate_limiter .wait_sync ()
364443
365444 full_payload = self ._build_full_payload (payload , input_id , ** overrides )
366445
@@ -383,34 +462,19 @@ def _rate_limited_invoke(self, payload: dict, session_id: str, input_id: str, **
383462 input_id = input_id ,
384463 agentcore_client = self .agentcore_client ,
385464 agent_runtime_arn = self .agent_runtime_arn ,
465+ rate_limiter = self ._rate_limiter ,
386466 )
387467
388- def _get_async_lock (self ) -> asyncio .Lock :
389- """Lazily create and return the async rate-limiting lock.
390-
391- Detects when the running event loop has changed (e.g., due to a new
392- ``asyncio.run()`` call) and recreates the lock for the current loop.
393- """
394- loop = asyncio .get_running_loop ()
395- if not hasattr (self , "_async_lock" ) or self ._async_lock_loop is not loop :
396- self ._async_lock = asyncio .Lock ()
397- self ._async_lock_loop = loop
398- return self ._async_lock
399-
400468 async def _async_rate_limited_invoke (
401469 self , payload : dict , session_id : str , input_id : str , ** overrides
402470 ) -> RolloutFuture :
403471 """Invoke with async TPS rate limiting.
404472
405- The lock is held only during the timing check and released before the
406- HTTP call, so cold starts on one request don't block submission of others.
473+ The rate limiter lock is held only during the timing check and released
474+ before the HTTP call, so cold starts on one request don't block
475+ submission of others.
407476 """
408- async with self ._get_async_lock ():
409- now = time .time ()
410- elapsed = now - self ._last_invoke_time
411- if elapsed < self ._min_invoke_interval :
412- await asyncio .sleep (self ._min_invoke_interval - elapsed )
413- self ._last_invoke_time = time .time ()
477+ await self ._rate_limiter .wait_async ()
414478
415479 full_payload = self ._build_full_payload (payload , input_id , ** overrides )
416480
@@ -435,6 +499,7 @@ def _invoke_and_parse():
435499 input_id = input_id ,
436500 agentcore_client = self .agentcore_client ,
437501 agent_runtime_arn = self .agent_runtime_arn ,
502+ rate_limiter = self ._rate_limiter ,
438503 )
439504
440505 def invoke (self , payload : dict , session_id : str = None , input_id : str = None , ** rollout_overrides ) -> RolloutFuture :
@@ -748,7 +813,7 @@ async def _run(self):
748813 for key , (idx , future ) in active_futures .items ():
749814 if future .elapsed () > self .timeout :
750815 completed_keys .append (key )
751- await asyncio . to_thread ( future .cancel )
816+ await future .cancel_async ( )
752817 yield BatchItem (
753818 success = False ,
754819 error = f"Timeout after { self .timeout } s" ,
@@ -763,7 +828,7 @@ async def _run(self):
763828 except Exception as e :
764829 yield BatchItem (success = False , error = _format_exception (e ), index = idx , elapsed = future .elapsed ())
765830 finally :
766- await asyncio . to_thread ( future .cancel )
831+ await future .cancel_async ( )
767832
768833 for key in completed_keys :
769834 del active_futures [key ]
0 commit comments