Skip to content

Commit f96ddce

Browse files
authored
fix(client): use a unified rate limiter for all ACR API calls (#51)
* fix(client): use a unified rate limiter for all ACR API calls
* fix(client): fix client unit test error
1 parent 2724ca8 commit f96ddce

File tree

2 files changed

+100
-35
lines changed

2 files changed

+100
-35
lines changed

src/agentcore_rl_toolkit/client.py

Lines changed: 99 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,60 @@ def _format_exception(exc: BaseException) -> str:
2121
return "".join(traceback.format_exception(exc))
2222

2323

24+
class ACRRateLimiter:
    """Unified rate limiter for all ACR API calls (invoke + stop).

    Since invoke_agent_runtime and stop_runtime_session share a single
    per-runtime-ARN rate limit on the ACR service, all calls must go
    through the same limiter to avoid throttling.

    Provides both sync and async interfaces. The async interface uses
    an asyncio.Lock to serialize timing checks; waiting is done with
    ``asyncio.sleep`` so the event loop itself is never blocked.

    Intervals are measured with ``time.monotonic()`` so that wall-clock
    adjustments (NTP steps, manual clock changes) can neither stall callers
    nor let them burst past the limit.

    Note:
        ``wait_sync`` performs no locking of its own; only the async path
        serializes concurrent callers.
    """

    def __init__(self, tps_limit: int = 25):
        """Create a limiter allowing at most ``tps_limit`` calls per second.

        Args:
            tps_limit: Maximum number of calls per second. Must be positive.

        Raises:
            ValueError: If ``tps_limit`` is zero or negative.
        """
        # A zero limit would divide by zero and a negative one would yield a
        # negative interval that silently disables limiting; reject both.
        if tps_limit <= 0:
            raise ValueError(f"tps_limit must be positive, got {tps_limit!r}")
        self.tps_limit = tps_limit
        self._min_interval = 1.0 / tps_limit
        # The reference point of the monotonic clock is undefined, so use -inf
        # to guarantee that the very first call never has to wait.
        self._last_call_time = float("-inf")
        # Async lock and its event loop (lazily created)
        self._async_lock = None
        self._async_lock_loop = None

    def _get_async_lock(self) -> asyncio.Lock:
        """Lazily create and return the async rate-limiting lock.

        Detects when the running event loop has changed (e.g., due to a new
        ``asyncio.run()`` call) and recreates the lock for the current loop.
        """
        loop = asyncio.get_running_loop()
        if self._async_lock is None or self._async_lock_loop is not loop:
            self._async_lock = asyncio.Lock()
            self._async_lock_loop = loop
        return self._async_lock

    def _remaining_delay(self) -> float:
        """Return the seconds left until the next call is allowed.

        The result is zero or negative when no waiting is required.
        """
        elapsed = time.monotonic() - self._last_call_time
        return self._min_interval - elapsed

    def wait_sync(self):
        """Block until the next call is allowed under the TPS limit."""
        delay = self._remaining_delay()
        if delay > 0:
            time.sleep(delay)
        self._last_call_time = time.monotonic()

    async def wait_async(self):
        """Async wait until the next call is allowed under the TPS limit.

        Uses a lock to serialize timing checks. The lock is held only during
        the timing check and sleep, so concurrent callers queue up properly.
        """
        async with self._get_async_lock():
            delay = self._remaining_delay()
            if delay > 0:
                await asyncio.sleep(delay)
            self._last_call_time = time.monotonic()
76+
77+
2478
@dataclass
2579
class BatchItem:
2680
"""Result wrapper for batch execution, distinguishing success from error.
@@ -55,6 +109,7 @@ def __init__(
55109
input_id: str = None,
56110
agentcore_client=None,
57111
agent_runtime_arn: str = None,
112+
rate_limiter: ACRRateLimiter = None,
58113
):
59114
self.s3_client = s3_client
60115
self.s3_bucket = s3_bucket
@@ -63,6 +118,7 @@ def __init__(
63118
self.input_id = input_id
64119
self.agentcore_client = agentcore_client
65120
self.agent_runtime_arn = agent_runtime_arn
121+
self._rate_limiter = rate_limiter
66122
self._result = None
67123
self._done = False
68124
self._cancelled = False
@@ -111,7 +167,7 @@ def elapsed(self) -> float:
111167
return time.time() - self._start_time
112168

113169
def cancel(self) -> bool:
114-
"""Cancel the underlying ACR session (best-effort).
170+
"""Cancel the underlying ACR session (best-effort, rate-limited).
115171
116172
Sets cancelled to True once called, even if the API call fails or
117173
the client/session_id are unavailable. Use ``cancelled`` to check
@@ -124,6 +180,8 @@ def cancel(self) -> bool:
124180
if not self.agentcore_client or not self.session_id:
125181
return False
126182
try:
183+
if self._rate_limiter:
184+
self._rate_limiter.wait_sync()
127185
self.agentcore_client.stop_runtime_session(
128186
agentRuntimeArn=self.agent_runtime_arn,
129187
runtimeSessionId=self.session_id,
@@ -134,6 +192,31 @@ def cancel(self) -> bool:
134192
logger.warning(f"Failed to stop session {self.session_id[:8]}...: {e}")
135193
return False
136194

195+
async def cancel_async(self) -> bool:
    """Stop the underlying ACR session from async code (best-effort).

    Counterpart of ``cancel()`` for coroutines: the rate-limit wait is
    awaited and the blocking ``stop_runtime_session`` call runs in a worker
    thread, so the event loop stays responsive. Going through the shared
    ACR rate limiter keeps stop calls from bursting against the same
    service-side limit that invoke calls use.

    The cancelled flag is set on the first call, before any API call is
    attempted, so later calls return ``False`` immediately.

    Returns:
        True if the stop call completed without raising. False if this
        future was already cancelled, if the client or session id is
        missing, or if the stop call failed (the failure is logged).
    """
    if self._cancelled:
        return False
    self._cancelled = True

    # Nothing to stop without both a client and a session id.
    if not (self.agentcore_client and self.session_id):
        return False

    try:
        limiter = self._rate_limiter
        if limiter:
            await limiter.wait_async()
        stop_session = self.agentcore_client.stop_runtime_session
        stop_kwargs = {
            "agentRuntimeArn": self.agent_runtime_arn,
            "runtimeSessionId": self.session_id,
        }
        # The client call is blocking; run it off the event loop thread.
        await asyncio.to_thread(stop_session, **stop_kwargs)
        logger.info(f"Stopped session {self.session_id[:8]}...")
        return True
    except Exception as exc:
        # Best-effort: report the failure instead of propagating it.
        logger.warning(f"Failed to stop session {self.session_id[:8]}...: {exc}")
        return False
219+
137220
@property
138221
def cancelled(self) -> bool:
139222
"""True if cancellation was attempted (may not have succeeded)."""
@@ -176,7 +259,7 @@ async def _async_poll(self) -> dict:
176259
while True:
177260
if await self.done_async():
178261
self._result = await asyncio.to_thread(self._fetch_result)
179-
await asyncio.to_thread(self.cancel)
262+
await self.cancel_async()
180263
return self._result
181264
await asyncio.sleep(self._poll_interval)
182265

@@ -204,7 +287,7 @@ async def result_async(self, timeout: float = None) -> dict:
204287
return await asyncio.wait_for(self._async_poll(), timeout=timeout)
205288
return await self
206289
except (TimeoutError, asyncio.TimeoutError):
207-
await asyncio.to_thread(self.cancel)
290+
await self.cancel_async()
208291
raise
209292

210293
def result(self, timeout: float = None) -> dict:
@@ -327,9 +410,10 @@ def __init__(
327410
self.agentcore_client = boto3.client("bedrock-agentcore", region_name=self.region, config=config)
328411
self.s3_client = boto3.client("s3", region_name=self.region, config=config)
329412

330-
# Rate limiting state
331-
self._last_invoke_time = 0.0
332-
self._min_invoke_interval = 1.0 / tps_limit
413+
# Unified rate limiter for all ACR API calls (invoke + stop).
414+
# invoke_agent_runtime and stop_runtime_session share a single
415+
# per-runtime-ARN rate limit on the ACR service.
416+
self._rate_limiter = ACRRateLimiter(tps_limit)
333417

334418
def _parse_response(self, response: dict) -> dict:
335419
"""Parse ACR invocation response."""
@@ -355,12 +439,7 @@ def _build_full_payload(self, payload: dict, input_id: str, **overrides) -> dict
355439

356440
def _rate_limited_invoke(self, payload: dict, session_id: str, input_id: str, **overrides) -> RolloutFuture:
357441
"""Invoke with TPS rate limiting."""
358-
# Enforce TPS limit
359-
now = time.time()
360-
elapsed = now - self._last_invoke_time
361-
if elapsed < self._min_invoke_interval:
362-
time.sleep(self._min_invoke_interval - elapsed)
363-
self._last_invoke_time = time.time()
442+
self._rate_limiter.wait_sync()
364443

365444
full_payload = self._build_full_payload(payload, input_id, **overrides)
366445

@@ -383,34 +462,19 @@ def _rate_limited_invoke(self, payload: dict, session_id: str, input_id: str, **
383462
input_id=input_id,
384463
agentcore_client=self.agentcore_client,
385464
agent_runtime_arn=self.agent_runtime_arn,
465+
rate_limiter=self._rate_limiter,
386466
)
387467

388-
def _get_async_lock(self) -> asyncio.Lock:
389-
"""Lazily create and return the async rate-limiting lock.
390-
391-
Detects when the running event loop has changed (e.g., due to a new
392-
``asyncio.run()`` call) and recreates the lock for the current loop.
393-
"""
394-
loop = asyncio.get_running_loop()
395-
if not hasattr(self, "_async_lock") or self._async_lock_loop is not loop:
396-
self._async_lock = asyncio.Lock()
397-
self._async_lock_loop = loop
398-
return self._async_lock
399-
400468
async def _async_rate_limited_invoke(
401469
self, payload: dict, session_id: str, input_id: str, **overrides
402470
) -> RolloutFuture:
403471
"""Invoke with async TPS rate limiting.
404472
405-
The lock is held only during the timing check and released before the
406-
HTTP call, so cold starts on one request don't block submission of others.
473+
The rate limiter lock is held only during the timing check and released
474+
before the HTTP call, so cold starts on one request don't block
475+
submission of others.
407476
"""
408-
async with self._get_async_lock():
409-
now = time.time()
410-
elapsed = now - self._last_invoke_time
411-
if elapsed < self._min_invoke_interval:
412-
await asyncio.sleep(self._min_invoke_interval - elapsed)
413-
self._last_invoke_time = time.time()
477+
await self._rate_limiter.wait_async()
414478

415479
full_payload = self._build_full_payload(payload, input_id, **overrides)
416480

@@ -435,6 +499,7 @@ def _invoke_and_parse():
435499
input_id=input_id,
436500
agentcore_client=self.agentcore_client,
437501
agent_runtime_arn=self.agent_runtime_arn,
502+
rate_limiter=self._rate_limiter,
438503
)
439504

440505
def invoke(self, payload: dict, session_id: str = None, input_id: str = None, **rollout_overrides) -> RolloutFuture:
@@ -748,7 +813,7 @@ async def _run(self):
748813
for key, (idx, future) in active_futures.items():
749814
if future.elapsed() > self.timeout:
750815
completed_keys.append(key)
751-
await asyncio.to_thread(future.cancel)
816+
await future.cancel_async()
752817
yield BatchItem(
753818
success=False,
754819
error=f"Timeout after {self.timeout}s",
@@ -763,7 +828,7 @@ async def _run(self):
763828
except Exception as e:
764829
yield BatchItem(success=False, error=_format_exception(e), index=idx, elapsed=future.elapsed())
765830
finally:
766-
await asyncio.to_thread(future.cancel)
831+
await future.cancel_async()
767832

768833
for key in completed_keys:
769834
del active_futures[key]

tests/test_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ def test_get_async_lock_recreated_for_new_event_loop(self):
481481
locks = []
482482

483483
async def capture_lock():
484-
locks.append(client._get_async_lock())
484+
locks.append(client._rate_limiter._get_async_lock())
485485

486486
asyncio.run(capture_lock())
487487
asyncio.run(capture_lock()) # new event loop

0 commit comments

Comments
 (0)