Skip to content

Commit c703c1c

Browse files
committed
.fix cancel future in zmq
Signed-off-by: Superjomn <[email protected]>
1 parent 1a6cae3 commit c703c1c

File tree

2 files changed

+109
-61
lines changed

2 files changed

+109
-61
lines changed

tensorrt_llm/executor/rpc/rpc_client.py

Lines changed: 108 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
import zmq
88

9-
from tensorrt_llm._utils import nvtx_mark_debug
9+
from tensorrt_llm._utils import (customized_gc_thresholds, nvtx_mark_debug,
10+
nvtx_range_debug)
1011

11-
from ..._utils import nvtx_range_debug
1212
from ...llmapi.utils import (AsyncQueue, _SyncQueue, enable_llmapi_debug,
1313
logger_debug)
1414
from ...logger import logger
@@ -110,6 +110,7 @@ def __init__(self,
110110
self._stop_event = None
111111
self._loop = None
112112
self._loop_thread = None
113+
self._pending_get_task = None # Track pending socket read to avoid cancellation
113114

114115
logger_debug(f"RPC Client initialized. Connected to {self._address}")
115116

@@ -196,21 +197,34 @@ def _handle_regular_response(self, response: RPCResponse):
196197
future, target_loop = future_info
197198

198199
if not future.done():
199-
if response.error is None:
200-
if enable_llmapi_debug() or logger.level == 'debug':
200+
201+
def safe_set_result():
202+
"""Safely set result on future, handling race conditions."""
203+
try:
204+
if not future.done():
205+
if response.error is None:
206+
future.set_result(response.result)
207+
else:
208+
future.set_exception(response.error)
209+
except asyncio.InvalidStateError:
210+
# Future was cancelled or completed between the check and set
211+
# This is expected in high-load scenarios, just log and continue
212+
if enable_llmapi_debug() or logger.level == 'debug':
213+
logger_debug(
214+
f"Future already done for request_id: {response.request_id}, skipping"
215+
)
216+
217+
if enable_llmapi_debug() or logger.level == 'debug':
218+
if response.error is None:
201219
logger_debug(
202220
f"Setting result for request_id: {response.request_id}"
203221
)
204-
target_loop.call_soon_threadsafe(future.set_result,
205-
response.result)
206-
else:
207-
# Use the original RPCError from the response
208-
if enable_llmapi_debug() or logger.level == 'debug':
222+
else:
209223
logger_debug(
210224
f"Setting exception for request_id: {response.request_id}, error: {response.error}"
211225
)
212-
target_loop.call_soon_threadsafe(future.set_exception,
213-
response.error)
226+
227+
target_loop.call_soon_threadsafe(safe_set_result)
214228
else:
215229
if enable_llmapi_debug() or logger.level == 'debug':
216230
logger_debug(
@@ -229,8 +243,17 @@ async def _handle_reader_exception(self, exception: Exception):
229243
# Propagate exception to all pending futures
230244
for (future, target_loop) in self._pending_futures.values():
231245
if not future.done():
232-
target_loop.call_soon_threadsafe(future.set_exception,
233-
exception)
246+
247+
def safe_set_exception(f=future, exc=exception):
248+
"""Safely set exception on future, handling race conditions."""
249+
try:
250+
if not f.done():
251+
f.set_exception(exc)
252+
except asyncio.InvalidStateError:
253+
# Future was cancelled or completed, this is fine
254+
pass
255+
256+
target_loop.call_soon_threadsafe(safe_set_exception)
234257

235258
# Also signal error to streaming queues
236259
for queue in self._streaming_queues.values():
@@ -242,58 +265,84 @@ async def _wait_for_response(self) -> Optional[RPCResponse]:
242265
Returns:
243266
RPCResponse if available, None if timeout
244267
"""
268+
# Reuse pending task or create new one to avoid cancelling ZMQ operations
269+
if self._pending_get_task is None or self._pending_get_task.done():
270+
self._pending_get_task = asyncio.create_task(
271+
self._client_socket.get_async())
272+
245273
try:
246-
response: RPCResponse = await asyncio.wait_for(
247-
self._client_socket.get_async(),
248-
timeout=0.1 # Check stop event every 100ms
249-
)
250-
return response
251-
except asyncio.TimeoutError:
252-
# Timeout is expected - just check stop event and continue
253-
return None
274+
# Use wait with a done callback instead of wait_for to avoid cancellation
275+
done, pending = await asyncio.wait(
276+
[self._pending_get_task],
277+
timeout=0.1, # Check stop event every 100ms
278+
return_when=asyncio.FIRST_COMPLETED)
279+
280+
if done:
281+
# Task completed, get the result
282+
response = await self._pending_get_task
283+
self._pending_get_task = None # Clear so we create a new one next time
284+
return response
285+
else:
286+
# Timeout - task is still pending, will be reused next iteration
287+
return None
288+
289+
except Exception as e:
290+
# If there's an error, clear the task
291+
self._pending_get_task = None
292+
raise e
254293

255294
async def _response_reader(self):
256295
"""Task to read responses from the socket and set results on futures."""
257296
logger_debug("Response reader started")
258297

259-
while not self._stop_event.is_set():
260-
with nvtx_range_debug("response_reader",
261-
color="cyan",
262-
category="RPC"):
263-
try:
264-
response = await self._wait_for_response()
265-
266-
if response is None:
267-
continue
268-
269-
nvtx_mark_debug(
270-
f"RPC.response.{'streaming' if response.is_streaming else 'sync'}",
271-
color="black",
272-
category="RPC")
273-
274-
# Optimize: Check debug flag before expensive string operations
275-
# This avoids holding GIL for f-string evaluation when debug is disabled
276-
if enable_llmapi_debug() or logger.level == 'debug':
277-
logger_debug(
278-
f"RPC Client received response: request_id={response.request_id}, "
279-
f"is_streaming={response.is_streaming}, "
280-
f"pending_futures={len(self._pending_futures)}")
281-
282-
with nvtx_range_debug("handle_response",
283-
color="purple",
284-
category="RPC"):
285-
if response.is_streaming:
286-
self._handle_streaming_response(response)
287-
else:
288-
self._handle_regular_response(response)
289-
290-
except asyncio.CancelledError:
291-
# Still handle cancellation for backward compatibility
292-
logger_debug("Response reader cancelled")
293-
break
294-
except Exception as e:
295-
await self._handle_reader_exception(e)
296-
break
298+
with customized_gc_thresholds(10000):
299+
while not self._stop_event.is_set():
300+
with nvtx_range_debug("response_reader",
301+
color="cyan",
302+
category="RPC"):
303+
try:
304+
response = await self._wait_for_response()
305+
306+
if response is None:
307+
continue
308+
309+
nvtx_mark_debug(
310+
f"RPC.response.{'streaming' if response.is_streaming else 'sync'}",
311+
color="black",
312+
category="RPC")
313+
314+
# Optimize: Check debug flag before expensive string operations
315+
# This avoids holding GIL for f-string evaluation when debug is disabled
316+
if enable_llmapi_debug() or logger.level == 'debug':
317+
logger_debug(
318+
f"RPC Client received response: request_id={response.request_id}, "
319+
f"is_streaming={response.is_streaming}, "
320+
f"pending_futures={len(self._pending_futures)}")
321+
322+
with nvtx_range_debug("handle_response",
323+
color="purple",
324+
category="RPC"):
325+
if response.is_streaming:
326+
self._handle_streaming_response(response)
327+
else:
328+
self._handle_regular_response(response)
329+
330+
except asyncio.CancelledError:
331+
# Still handle cancellation for backward compatibility
332+
logger_debug("Response reader cancelled")
333+
break
334+
except Exception as e:
335+
await self._handle_reader_exception(e)
336+
break
337+
338+
# Clean up any pending get task
339+
if self._pending_get_task and not self._pending_get_task.done():
340+
self._pending_get_task.cancel()
341+
try:
342+
await self._pending_get_task
343+
except asyncio.CancelledError:
344+
pass # Expected
345+
self._pending_get_task = None
297346

298347
logger_debug("Response reader exiting gracefully")
299348
self._reader_task = None

tensorrt_llm/llmapi/llm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,7 @@ def __init__(self,
128128
self._orchestrator_type = kwargs.get("orchestrator_type", None)
129129
self._llm_id = None
130130

131-
# Enable GC NVTX profiling if environment variable is set
132-
self.gc_nvtx_watcher_handle = gc_nvtx_watcher()
131+
self._gc_nvtx_watcher_handle = gc_nvtx_watcher()
133132

134133
log_level = logger.level
135134
logger.set_level("info") # force display the backend

0 commit comments

Comments
 (0)