Skip to content

Commit 2bcafb3

Browse files
committed
simplify rpc_client await_response
Signed-off-by: Superjomn <[email protected]>
1 parent c703c1c commit 2bcafb3

File tree

1 file changed: 87 additions and 105 deletions

tensorrt_llm/executor/rpc/rpc_client.py

Lines changed: 87 additions & 105 deletions
Original file line number | Diff line number | Diff line change
@@ -107,10 +107,9 @@ def __init__(self,
107107

108108
self._server_stopped = False
109109
self._closed = False
110-
self._stop_event = None
111110
self._loop = None
112111
self._loop_thread = None
113-
self._pending_get_task = None # Track pending socket read to avoid cancellation
112+
self._reader_asyncio_task = None # Track the asyncio task for proper cancellation
114113

115114
logger_debug(f"RPC Client initialized. Connected to {self._address}")
116115

@@ -128,39 +127,51 @@ def close(self):
128127

129128
if self._closed:
130129
return
131-
# stop the main loop
132130
self._closed = True
133131

134132
logger_debug("RPC Client closing")
135133

136-
if self._stop_event and self._loop:
137-
# Use call_soon_threadsafe since set() is not a coroutine
138-
self._loop.call_soon_threadsafe(self._stop_event.set)
139-
140-
if self._reader_task:
141-
try:
142-
self._reader_task.result(timeout=1.0)
143-
except concurrent.futures.TimeoutError:
144-
logger.warning(
145-
"Reader task did not exit gracefully, cancelling")
146-
self._reader_task.cancel()
147-
except Exception as e:
148-
# Task might have already finished or been cancelled
149-
logger_debug(f"Reader task cleanup: {e}")
134+
# Cancel the reader task first to avoid socket closure errors
135+
if self._reader_task and not self._reader_task.done():
136+
if self._loop and self._loop.is_running(
137+
) and self._reader_asyncio_task:
138+
try:
139+
# Cancel the asyncio task in its event loop
140+
async def cancel_reader_task():
141+
if self._reader_asyncio_task and not self._reader_asyncio_task.done(
142+
):
143+
self._reader_asyncio_task.cancel()
144+
try:
145+
await self._reader_asyncio_task
146+
except asyncio.CancelledError:
147+
pass # Expected
148+
149+
cancel_future = asyncio.run_coroutine_threadsafe(
150+
cancel_reader_task(), self._loop)
151+
cancel_future.result(timeout=2.0)
152+
logger_debug("Reader task cancelled successfully")
153+
except concurrent.futures.TimeoutError:
154+
logger.warning("Reader task did not exit gracefully")
155+
except Exception as e:
156+
logger_debug(f"Reader task cleanup: {e}")
150157
self._reader_task = None
158+
self._reader_asyncio_task = None
151159

160+
# Now close the socket after reader has stopped
161+
if self._client_socket:
162+
self._client_socket.close()
163+
self._client_socket = None
164+
165+
# Stop the event loop
152166
if self._loop and self._loop.is_running():
153167
self._loop.call_soon_threadsafe(self._loop.stop)
154168
if self._loop_thread:
155-
self._loop_thread.join()
169+
self._loop_thread.join(timeout=2.0)
156170
self._loop_thread = None
171+
157172
if self._executor:
158173
self._executor.shutdown(wait=True)
159174

160-
if self._client_socket:
161-
self._client_socket.close()
162-
self._client_socket = None
163-
164175
logger_debug("RPC Client closed")
165176

166177
def _handle_streaming_response(self, response: RPCResponse):
@@ -259,101 +270,73 @@ def safe_set_exception(f=future, exc=exception):
259270
for queue in self._streaming_queues.values():
260271
await queue.put(RPCResponse("", None, exception, False, 0, 'error'))
261272

262-
async def _wait_for_response(self) -> Optional[RPCResponse]:
263-
"""Wait for a response from the socket with timeout.
273+
async def _wait_for_response(self) -> RPCResponse:
274+
"""Wait for a response from the socket.
264275
265276
Returns:
266-
RPCResponse if available, None if timeout
277+
RPCResponse from the server
267278
"""
268-
# Reuse pending task or create new one to avoid cancelling ZMQ operations
269-
if self._pending_get_task is None or self._pending_get_task.done():
270-
self._pending_get_task = asyncio.create_task(
271-
self._client_socket.get_async())
272-
273-
try:
274-
# Use wait with a done callback instead of wait_for to avoid cancellation
275-
done, pending = await asyncio.wait(
276-
[self._pending_get_task],
277-
timeout=0.1, # Check stop event every 100ms
278-
return_when=asyncio.FIRST_COMPLETED)
279-
280-
if done:
281-
# Task completed, get the result
282-
response = await self._pending_get_task
283-
self._pending_get_task = None # Clear so we create a new one next time
284-
return response
285-
else:
286-
# Timeout - task is still pending, will be reused next iteration
287-
return None
288-
289-
except Exception as e:
290-
# If there's an error, clear the task
291-
self._pending_get_task = None
292-
raise e
279+
# Directly await the socket - cancellation will be handled by task cancellation
280+
return await self._client_socket.get_async()
293281

294282
async def _response_reader(self):
295283
"""Task to read responses from the socket and set results on futures."""
296284
logger_debug("Response reader started")
297285

298-
with customized_gc_thresholds(10000):
299-
while not self._stop_event.is_set():
300-
with nvtx_range_debug("response_reader",
301-
color="cyan",
302-
category="RPC"):
303-
try:
304-
response = await self._wait_for_response()
305-
306-
if response is None:
307-
continue
308-
309-
nvtx_mark_debug(
310-
f"RPC.response.{'streaming' if response.is_streaming else 'sync'}",
311-
color="black",
312-
category="RPC")
313-
314-
# Optimize: Check debug flag before expensive string operations
315-
# This avoids holding GIL for f-string evaluation when debug is disabled
316-
if enable_llmapi_debug() or logger.level == 'debug':
317-
logger_debug(
318-
f"RPC Client received response: request_id={response.request_id}, "
319-
f"is_streaming={response.is_streaming}, "
320-
f"pending_futures={len(self._pending_futures)}")
321-
322-
with nvtx_range_debug("handle_response",
323-
color="purple",
324-
category="RPC"):
325-
if response.is_streaming:
326-
self._handle_streaming_response(response)
327-
else:
328-
self._handle_regular_response(response)
329-
330-
except asyncio.CancelledError:
331-
# Still handle cancellation for backward compatibility
332-
logger_debug("Response reader cancelled")
333-
break
334-
except Exception as e:
335-
await self._handle_reader_exception(e)
336-
break
337-
338-
# Clean up any pending get task
339-
if self._pending_get_task and not self._pending_get_task.done():
340-
self._pending_get_task.cancel()
341-
try:
342-
await self._pending_get_task
343-
except asyncio.CancelledError:
344-
pass # Expected
345-
self._pending_get_task = None
346-
347-
logger_debug("Response reader exiting gracefully")
348-
self._reader_task = None
286+
try:
287+
with customized_gc_thresholds(10000):
288+
while True:
289+
with nvtx_range_debug("response_reader",
290+
color="cyan",
291+
category="RPC"):
292+
try:
293+
response = await self._wait_for_response()
294+
295+
nvtx_mark_debug(
296+
f"RPC.response.{'streaming' if response.is_streaming else 'sync'}",
297+
color="black",
298+
category="RPC")
299+
300+
# Optimize: Check debug flag before expensive string operations
301+
# This avoids holding GIL for f-string evaluation when debug is disabled
302+
if enable_llmapi_debug() or logger.level == 'debug':
303+
logger_debug(
304+
f"RPC Client received response: request_id={response.request_id}, "
305+
f"is_streaming={response.is_streaming}, "
306+
f"pending_futures={len(self._pending_futures)}"
307+
)
308+
309+
with nvtx_range_debug("handle_response",
310+
color="purple",
311+
category="RPC"):
312+
if response.is_streaming:
313+
self._handle_streaming_response(response)
314+
else:
315+
self._handle_regular_response(response)
316+
317+
except Exception as e:
318+
await self._handle_reader_exception(e)
319+
break
320+
321+
except asyncio.CancelledError:
322+
logger_debug("Response reader cancelled")
323+
finally:
324+
logger_debug("Response reader exiting gracefully")
325+
self._reader_task = None
326+
self._reader_asyncio_task = None
349327

350328
def _start_response_reader_lazily(self):
351329
if self._reader_task is None or self._reader_task.done():
352330
# Ensure we have a persistent background loop
353331
self._ensure_event_loop()
354-
# Always create the reader task on the persistent loop
355-
future = asyncio.run_coroutine_threadsafe(self._response_reader(),
356-
self._loop)
332+
333+
# Wrapper to track the asyncio task
334+
async def run_reader():
335+
self._reader_asyncio_task = asyncio.current_task()
336+
await self._response_reader()
337+
338+
# Start the reader task on the persistent loop
339+
future = asyncio.run_coroutine_threadsafe(run_reader(), self._loop)
357340
# Store the concurrent.futures.Future
358341
self._reader_task = future
359342

@@ -425,7 +408,6 @@ def _ensure_event_loop(self):
425408

426409
def run_loop():
427410
asyncio.set_event_loop(self._loop)
428-
self._stop_event = asyncio.Event()
429411
self._loop.run_forever()
430412

431413
self._loop_thread = threading.Thread(target=run_loop,

0 commit comments

Comments
 (0)