66
77import zmq
88
9- from tensorrt_llm ._utils import nvtx_mark_debug
9+ from tensorrt_llm ._utils import (customized_gc_thresholds , nvtx_mark_debug ,
10+ nvtx_range_debug )
1011
11- from ..._utils import nvtx_range_debug
1212from ...llmapi .utils import (AsyncQueue , _SyncQueue , enable_llmapi_debug ,
1313 logger_debug )
1414from ...logger import logger
@@ -110,6 +110,7 @@ def __init__(self,
110110 self ._stop_event = None
111111 self ._loop = None
112112 self ._loop_thread = None
113+ self ._pending_get_task = None # Track pending socket read to avoid cancellation
113114
114115 logger_debug (f"RPC Client initialized. Connected to { self ._address } " )
115116
@@ -196,21 +197,34 @@ def _handle_regular_response(self, response: RPCResponse):
196197 future , target_loop = future_info
197198
198199 if not future .done ():
199- if response .error is None :
200- if enable_llmapi_debug () or logger .level == 'debug' :
200+
201+ def safe_set_result ():
202+ """Safely set result on future, handling race conditions."""
203+ try :
204+ if not future .done ():
205+ if response .error is None :
206+ future .set_result (response .result )
207+ else :
208+ future .set_exception (response .error )
209+ except asyncio .InvalidStateError :
210+ # Future was cancelled or completed between the check and set
211+ # This is expected in high-load scenarios, just log and continue
212+ if enable_llmapi_debug () or logger .level == 'debug' :
213+ logger_debug (
214+ f"Future already done for request_id: { response .request_id } , skipping"
215+ )
216+
217+ if enable_llmapi_debug () or logger .level == 'debug' :
218+ if response .error is None :
201219 logger_debug (
202220 f"Setting result for request_id: { response .request_id } "
203221 )
204- target_loop .call_soon_threadsafe (future .set_result ,
205- response .result )
206- else :
207- # Use the original RPCError from the response
208- if enable_llmapi_debug () or logger .level == 'debug' :
222+ else :
209223 logger_debug (
210224 f"Setting exception for request_id: { response .request_id } , error: { response .error } "
211225 )
212- target_loop . call_soon_threadsafe ( future . set_exception ,
213- response . error )
226+
227+ target_loop . call_soon_threadsafe ( safe_set_result )
214228 else :
215229 if enable_llmapi_debug () or logger .level == 'debug' :
216230 logger_debug (
@@ -229,8 +243,17 @@ async def _handle_reader_exception(self, exception: Exception):
229243 # Propagate exception to all pending futures
230244 for (future , target_loop ) in self ._pending_futures .values ():
231245 if not future .done ():
232- target_loop .call_soon_threadsafe (future .set_exception ,
233- exception )
246+
247+ def safe_set_exception (f = future , exc = exception ):
248+ """Safely set exception on future, handling race conditions."""
249+ try :
250+ if not f .done ():
251+ f .set_exception (exc )
252+ except asyncio .InvalidStateError :
253+ # Future was cancelled or completed, this is fine
254+ pass
255+
256+ target_loop .call_soon_threadsafe (safe_set_exception )
234257
235258 # Also signal error to streaming queues
236259 for queue in self ._streaming_queues .values ():
@@ -242,58 +265,84 @@ async def _wait_for_response(self) -> Optional[RPCResponse]:
242265 Returns:
243266 RPCResponse if available, None if timeout
244267 """
268+ # Reuse pending task or create new one to avoid cancelling ZMQ operations
269+ if self ._pending_get_task is None or self ._pending_get_task .done ():
270+ self ._pending_get_task = asyncio .create_task (
271+ self ._client_socket .get_async ())
272+
245273 try :
246- response : RPCResponse = await asyncio .wait_for (
247- self ._client_socket .get_async (),
248- timeout = 0.1 # Check stop event every 100ms
249- )
250- return response
251- except asyncio .TimeoutError :
252- # Timeout is expected - just check stop event and continue
253- return None
274+ # Use asyncio.wait() with a timeout instead of wait_for(): on timeout, wait() leaves the task pending (to be reused) rather than cancelling it
275+ done , pending = await asyncio .wait (
276+ [self ._pending_get_task ],
277+ timeout = 0.1 , # Check stop event every 100ms
278+ return_when = asyncio .FIRST_COMPLETED )
279+
280+ if done :
281+ # Task completed, get the result
282+ response = await self ._pending_get_task
283+ self ._pending_get_task = None # Clear so we create a new one next time
284+ return response
285+ else :
286+ # Timeout - task is still pending, will be reused next iteration
287+ return None
288+
289+ except Exception as e :
290+ # If there's an error, clear the task
291+ self ._pending_get_task = None
292+ raise e
254293
255294 async def _response_reader (self ):
256295 """Task to read responses from the socket and set results on futures."""
257296 logger_debug ("Response reader started" )
258297
259- while not self ._stop_event .is_set ():
260- with nvtx_range_debug ("response_reader" ,
261- color = "cyan" ,
262- category = "RPC" ):
263- try :
264- response = await self ._wait_for_response ()
265-
266- if response is None :
267- continue
268-
269- nvtx_mark_debug (
270- f"RPC.response.{ 'streaming' if response .is_streaming else 'sync' } " ,
271- color = "black" ,
272- category = "RPC" )
273-
274- # Optimize: Check debug flag before expensive string operations
275- # This avoids holding GIL for f-string evaluation when debug is disabled
276- if enable_llmapi_debug () or logger .level == 'debug' :
277- logger_debug (
278- f"RPC Client received response: request_id={ response .request_id } , "
279- f"is_streaming={ response .is_streaming } , "
280- f"pending_futures={ len (self ._pending_futures )} " )
281-
282- with nvtx_range_debug ("handle_response" ,
283- color = "purple" ,
284- category = "RPC" ):
285- if response .is_streaming :
286- self ._handle_streaming_response (response )
287- else :
288- self ._handle_regular_response (response )
289-
290- except asyncio .CancelledError :
291- # Still handle cancellation for backward compatibility
292- logger_debug ("Response reader cancelled" )
293- break
294- except Exception as e :
295- await self ._handle_reader_exception (e )
296- break
298+ with customized_gc_thresholds (10000 ):
299+ while not self ._stop_event .is_set ():
300+ with nvtx_range_debug ("response_reader" ,
301+ color = "cyan" ,
302+ category = "RPC" ):
303+ try :
304+ response = await self ._wait_for_response ()
305+
306+ if response is None :
307+ continue
308+
309+ nvtx_mark_debug (
310+ f"RPC.response.{ 'streaming' if response .is_streaming else 'sync' } " ,
311+ color = "black" ,
312+ category = "RPC" )
313+
314+ # Optimize: Check debug flag before expensive string operations
315+ # This avoids holding GIL for f-string evaluation when debug is disabled
316+ if enable_llmapi_debug () or logger .level == 'debug' :
317+ logger_debug (
318+ f"RPC Client received response: request_id={ response .request_id } , "
319+ f"is_streaming={ response .is_streaming } , "
320+ f"pending_futures={ len (self ._pending_futures )} " )
321+
322+ with nvtx_range_debug ("handle_response" ,
323+ color = "purple" ,
324+ category = "RPC" ):
325+ if response .is_streaming :
326+ self ._handle_streaming_response (response )
327+ else :
328+ self ._handle_regular_response (response )
329+
330+ except asyncio .CancelledError :
331+ # Still handle cancellation for backward compatibility
332+ logger_debug ("Response reader cancelled" )
333+ break
334+ except Exception as e :
335+ await self ._handle_reader_exception (e )
336+ break
337+
338+ # Clean up any pending get task
339+ if self ._pending_get_task and not self ._pending_get_task .done ():
340+ self ._pending_get_task .cancel ()
341+ try :
342+ await self ._pending_get_task
343+ except asyncio .CancelledError :
344+ pass # Expected
345+ self ._pending_get_task = None
297346
298347 logger_debug ("Response reader exiting gracefully" )
299348 self ._reader_task = None
0 commit comments