@@ -107,10 +107,9 @@ def __init__(self,
107107
108108 self ._server_stopped = False
109109 self ._closed = False
110- self ._stop_event = None
111110 self ._loop = None
112111 self ._loop_thread = None
113- self ._pending_get_task = None # Track pending socket read to avoid cancellation
112+ self ._reader_asyncio_task = None # Track the asyncio task for proper cancellation
114113
115114 logger_debug (f"RPC Client initialized. Connected to { self ._address } " )
116115
@@ -128,39 +127,51 @@ def close(self):
128127
129128 if self ._closed :
130129 return
131- # stop the main loop
132130 self ._closed = True
133131
134132 logger_debug ("RPC Client closing" )
135133
136- if self ._stop_event and self ._loop :
137- # Use call_soon_threadsafe since set() is not a coroutine
138- self ._loop .call_soon_threadsafe (self ._stop_event .set )
139-
140- if self ._reader_task :
141- try :
142- self ._reader_task .result (timeout = 1.0 )
143- except concurrent .futures .TimeoutError :
144- logger .warning (
145- "Reader task did not exit gracefully, cancelling" )
146- self ._reader_task .cancel ()
147- except Exception as e :
148- # Task might have already finished or been cancelled
149- logger_debug (f"Reader task cleanup: { e } " )
134+ # Cancel the reader task first to avoid socket closure errors
135+ if self ._reader_task and not self ._reader_task .done ():
136+ if self ._loop and self ._loop .is_running (
137+ ) and self ._reader_asyncio_task :
138+ try :
139+ # Cancel the asyncio task in its event loop
140+ async def cancel_reader_task ():
141+ if self ._reader_asyncio_task and not self ._reader_asyncio_task .done (
142+ ):
143+ self ._reader_asyncio_task .cancel ()
144+ try :
145+ await self ._reader_asyncio_task
146+ except asyncio .CancelledError :
147+ pass # Expected
148+
149+ cancel_future = asyncio .run_coroutine_threadsafe (
150+ cancel_reader_task (), self ._loop )
151+ cancel_future .result (timeout = 2.0 )
152+ logger_debug ("Reader task cancelled successfully" )
153+ except concurrent .futures .TimeoutError :
154+ logger .warning ("Reader task did not exit gracefully" )
155+ except Exception as e :
156+ logger_debug (f"Reader task cleanup: { e } " )
150157 self ._reader_task = None
158+ self ._reader_asyncio_task = None
151159
160+ # Now close the socket after reader has stopped
161+ if self ._client_socket :
162+ self ._client_socket .close ()
163+ self ._client_socket = None
164+
165+ # Stop the event loop
152166 if self ._loop and self ._loop .is_running ():
153167 self ._loop .call_soon_threadsafe (self ._loop .stop )
154168 if self ._loop_thread :
155- self ._loop_thread .join ()
169+ self ._loop_thread .join (timeout = 2.0 )
156170 self ._loop_thread = None
171+
157172 if self ._executor :
158173 self ._executor .shutdown (wait = True )
159174
160- if self ._client_socket :
161- self ._client_socket .close ()
162- self ._client_socket = None
163-
164175 logger_debug ("RPC Client closed" )
165176
166177 def _handle_streaming_response (self , response : RPCResponse ):
@@ -259,101 +270,73 @@ def safe_set_exception(f=future, exc=exception):
259270 for queue in self ._streaming_queues .values ():
260271 await queue .put (RPCResponse ("" , None , exception , False , 0 , 'error' ))
261272
262- async def _wait_for_response (self ) -> Optional [ RPCResponse ] :
263- """Wait for a response from the socket with timeout .
273+ async def _wait_for_response (self ) -> RPCResponse :
274+ """Wait for a response from the socket.
264275
265276 Returns:
266- RPCResponse if available, None if timeout
277+ RPCResponse from the server
267278 """
268- # Reuse pending task or create new one to avoid cancelling ZMQ operations
269- if self ._pending_get_task is None or self ._pending_get_task .done ():
270- self ._pending_get_task = asyncio .create_task (
271- self ._client_socket .get_async ())
272-
273- try :
274- # Use wait with a done callback instead of wait_for to avoid cancellation
275- done , pending = await asyncio .wait (
276- [self ._pending_get_task ],
277- timeout = 0.1 , # Check stop event every 100ms
278- return_when = asyncio .FIRST_COMPLETED )
279-
280- if done :
281- # Task completed, get the result
282- response = await self ._pending_get_task
283- self ._pending_get_task = None # Clear so we create a new one next time
284- return response
285- else :
286- # Timeout - task is still pending, will be reused next iteration
287- return None
288-
289- except Exception as e :
290- # If there's an error, clear the task
291- self ._pending_get_task = None
292- raise e
279+ # Directly await the socket - cancellation will be handled by task cancellation
280+ return await self ._client_socket .get_async ()
293281
294282 async def _response_reader (self ):
295283 """Task to read responses from the socket and set results on futures."""
296284 logger_debug ("Response reader started" )
297285
298- with customized_gc_thresholds (10000 ):
299- while not self ._stop_event .is_set ():
300- with nvtx_range_debug ("response_reader" ,
301- color = "cyan" ,
302- category = "RPC" ):
303- try :
304- response = await self ._wait_for_response ()
305-
306- if response is None :
307- continue
308-
309- nvtx_mark_debug (
310- f"RPC.response.{ 'streaming' if response .is_streaming else 'sync' } " ,
311- color = "black" ,
312- category = "RPC" )
313-
314- # Optimize: Check debug flag before expensive string operations
315- # This avoids holding GIL for f-string evaluation when debug is disabled
316- if enable_llmapi_debug () or logger .level == 'debug' :
317- logger_debug (
318- f"RPC Client received response: request_id={ response .request_id } , "
319- f"is_streaming={ response .is_streaming } , "
320- f"pending_futures={ len (self ._pending_futures )} " )
321-
322- with nvtx_range_debug ("handle_response" ,
323- color = "purple" ,
324- category = "RPC" ):
325- if response .is_streaming :
326- self ._handle_streaming_response (response )
327- else :
328- self ._handle_regular_response (response )
329-
330- except asyncio .CancelledError :
331- # Still handle cancellation for backward compatibility
332- logger_debug ("Response reader cancelled" )
333- break
334- except Exception as e :
335- await self ._handle_reader_exception (e )
336- break
337-
338- # Clean up any pending get task
339- if self ._pending_get_task and not self ._pending_get_task .done ():
340- self ._pending_get_task .cancel ()
341- try :
342- await self ._pending_get_task
343- except asyncio .CancelledError :
344- pass # Expected
345- self ._pending_get_task = None
346-
347- logger_debug ("Response reader exiting gracefully" )
348- self ._reader_task = None
286+ try :
287+ with customized_gc_thresholds (10000 ):
288+ while True :
289+ with nvtx_range_debug ("response_reader" ,
290+ color = "cyan" ,
291+ category = "RPC" ):
292+ try :
293+ response = await self ._wait_for_response ()
294+
295+ nvtx_mark_debug (
296+ f"RPC.response.{ 'streaming' if response .is_streaming else 'sync' } " ,
297+ color = "black" ,
298+ category = "RPC" )
299+
300+ # Optimize: Check debug flag before expensive string operations
301+ # This avoids holding GIL for f-string evaluation when debug is disabled
302+ if enable_llmapi_debug () or logger .level == 'debug' :
303+ logger_debug (
304+ f"RPC Client received response: request_id={ response .request_id } , "
305+ f"is_streaming={ response .is_streaming } , "
306+ f"pending_futures={ len (self ._pending_futures )} "
307+ )
308+
309+ with nvtx_range_debug ("handle_response" ,
310+ color = "purple" ,
311+ category = "RPC" ):
312+ if response .is_streaming :
313+ self ._handle_streaming_response (response )
314+ else :
315+ self ._handle_regular_response (response )
316+
317+ except Exception as e :
318+ await self ._handle_reader_exception (e )
319+ break
320+
321+ except asyncio .CancelledError :
322+ logger_debug ("Response reader cancelled" )
323+ finally :
324+ logger_debug ("Response reader exiting gracefully" )
325+ self ._reader_task = None
326+ self ._reader_asyncio_task = None
349327
350328 def _start_response_reader_lazily (self ):
351329 if self ._reader_task is None or self ._reader_task .done ():
352330 # Ensure we have a persistent background loop
353331 self ._ensure_event_loop ()
354- # Always create the reader task on the persistent loop
355- future = asyncio .run_coroutine_threadsafe (self ._response_reader (),
356- self ._loop )
332+
333+ # Wrapper to track the asyncio task
334+ async def run_reader ():
335+ self ._reader_asyncio_task = asyncio .current_task ()
336+ await self ._response_reader ()
337+
338+ # Start the reader task on the persistent loop
339+ future = asyncio .run_coroutine_threadsafe (run_reader (), self ._loop )
357340 # Store the concurrent.futures.Future
358341 self ._reader_task = future
359342
@@ -425,7 +408,6 @@ def _ensure_event_loop(self):
425408
426409 def run_loop ():
427410 asyncio .set_event_loop (self ._loop )
428- self ._stop_event = asyncio .Event ()
429411 self ._loop .run_forever ()
430412
431413 self ._loop_thread = threading .Thread (target = run_loop ,
0 commit comments