Skip to content

Commit 1915b16

Browse files
committed
better reuse code between remote methods
Signed-off-by: chunweiy <[email protected]> Signed-off-by: chunweiy <[email protected]>
1 parent 57b986d commit 1915b16

File tree

3 files changed

+108
-41
lines changed

3 files changed

+108
-41
lines changed

tensorrt_llm/executor/rpc/rpc_client.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,48 +20,52 @@ def __init__(self, client: 'RPCClient', method_name: str, *args, **kwargs):
2020
self.args = args
2121
self.kwargs = kwargs
2222

23+
def _prepare_and_call(self, timeout: Optional[float], need_response: bool,
                      mode: str, call_method: str) -> Any:
    """Attach the RPC parameters to the stored kwargs and dispatch the call.

    Shared implementation behind the ``remote*`` entry points: it records the
    per-call options under the ``"__rpc_params"`` key of ``self.kwargs`` and
    then invokes the client method named by ``call_method``.

    Args:
        timeout: Timeout for the RPC call
        need_response: Whether a response is expected
        mode: The RPC mode ("sync", "async", "future")
        call_method: Name of the attribute on ``self.client`` to invoke

    Returns:
        Whatever the selected client method returns.
    """
    # The options travel inside the kwargs dict under a reserved key.
    self.kwargs["__rpc_params"] = RPCParams(timeout=timeout,
                                            need_response=need_response,
                                            mode=mode)
    # Resolve the client entry point by name, then forward the stored call.
    dispatch = getattr(self.client, call_method)
    return dispatch(self.method_name, *self.args, **self.kwargs)
42+
2343
def remote(self,
2444
timeout: Optional[float] = None,
2545
need_response: bool = True) -> Any:
2646
"""Synchronous remote call with optional RPC parameters."""
27-
rpc_params = RPCParams(timeout=timeout,
28-
need_response=need_response,
29-
mode="sync")
30-
self.kwargs["__rpc_params"] = rpc_params
31-
return self.client._call_sync(self.method_name, *self.args,
32-
**self.kwargs)
47+
return self._prepare_and_call(timeout, need_response, "sync",
48+
"_call_sync")
3349

3450
def remote_async(self,
3551
timeout: Optional[float] = None,
3652
need_response: bool = True):
3753
"""Asynchronous remote call that returns a coroutine."""
38-
rpc_params = RPCParams(timeout=timeout,
39-
need_response=need_response,
40-
mode="async")
41-
self.kwargs["__rpc_params"] = rpc_params
42-
return self.client._call_async(self.method_name, *self.args,
43-
**self.kwargs)
54+
return self._prepare_and_call(timeout, need_response, "async",
55+
"_call_async")
4456

4557
def remote_future(self,
4658
timeout: Optional[float] = None,
4759
need_response: bool = True) -> concurrent.futures.Future:
4860
"""Remote call that returns a Future object."""
49-
rpc_params = RPCParams(timeout=timeout,
50-
need_response=need_response,
51-
mode="future")
52-
self.kwargs["__rpc_params"] = rpc_params
53-
return self.client.call_future(self.method_name, *self.args,
54-
**self.kwargs)
61+
return self._prepare_and_call(timeout, need_response, "future",
62+
"call_future")
5563

5664
def remote_streaming(self,
5765
timeout: Optional[float] = None) -> AsyncIterator[Any]:
5866
"""Remote call for streaming results."""
59-
rpc_params = RPCParams(timeout=timeout,
60-
need_response=True,
61-
mode="async")
62-
self.kwargs["__rpc_params"] = rpc_params
63-
return self.client.call_streaming(self.method_name, *self.args,
64-
**self.kwargs)
67+
# Streaming always needs a response
68+
return self._prepare_and_call(timeout, True, "async", "call_streaming")
6569

6670

6771
class RPCClient:
@@ -309,7 +313,7 @@ async def _call_async(self, method_name, *args, **kwargs):
309313
res = await future
310314
else:
311315
# Wait for the response, bounded by the caller's timeout (before this commit the wait was padded by 1 extra second)
312-
res = await asyncio.wait_for(future, timeout + 1)
316+
res = await asyncio.wait_for(future, timeout)
313317
logger_debug(
314318
f"RPC Client _call_async: Got result for request_id: {request_id}: {res}"
315319
)
@@ -361,7 +365,7 @@ def _call_sync(self, method_name, *args, **kwargs):
361365
f"RPC Client _call_sync: Got result for {method_name}: {result}")
362366
return result
363367

364-
def call_async(self, name: str, *args, **kwargs):
368+
def call_async(self, name: str, *args, **kwargs) -> Any:
365369
"""
366370
Call a remote method asynchronously.
367371
@@ -408,7 +412,7 @@ def _async_to_sync():
408412

409413
return self._executor.submit(_async_to_sync)
410414

411-
def call_sync(self, name: str, *args, **kwargs):
415+
def call_sync(self, name: str, *args, **kwargs) -> Any:
412416
"""
413417
Call a remote method synchronously (blocking).
414418
@@ -476,7 +480,7 @@ async def call_streaming(self, name: str, *args,
476480
response = await queue.get()
477481
else:
478482
response = await asyncio.wait_for(queue.get(),
479-
timeout=timeout + 1)
483+
timeout=timeout)
480484

481485
logger_debug(
482486
f"RPC Client call_streaming received [{response.stream_status}] response: {response}",

tensorrt_llm/executor/rpc/rpc_common.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import time
2+
from dataclasses import dataclass
13
from typing import Any, Literal, NamedTuple, Optional
24

35

@@ -48,14 +50,22 @@ class RPCStreamingError(RPCError):
4850
"""Exception for streaming-related errors."""
4951

5052

51-
class RPCRequest(NamedTuple):
53+
@dataclass
class RPCRequest:
    """A single RPC call description.

    Carries the target method name, its positional and keyword arguments,
    and the per-call options (response expectation, timeout, streaming flag).
    """
    request_id: str
    method_name: str
    args: tuple
    kwargs: dict
    need_response: bool = True
    # None is accepted here: the server-side timeout handling checks for
    # ``timeout is not None`` before applying it.
    timeout: Optional[float] = 0.5
    is_streaming: bool = False
    # Unix timestamp when request was created
    creation_timestamp: Optional[float] = None

    def __post_init__(self):
        """Initialize creation_timestamp if not provided."""
        # Wall-clock time (time.time()) is stored; the server subtracts it
        # from its own time.time() reading to measure how long the request
        # has been pending.
        if self.creation_timestamp is None:
            self.creation_timestamp = time.time()
5969

6070

6171
class RPCResponse(NamedTuple):

tensorrt_llm/executor/rpc/rpc_server.py

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -292,18 +292,46 @@ async def _worker_routine(self, stop_event: threading.Event):
292292
if req.method_name not in ["_rpc_shutdown", "shutdown"]:
293293
self._num_pending_requests -= 1
294294

295+
def _calculate_adjusted_timeout(self,
296+
req: RPCRequest,
297+
is_streaming: bool = False) -> float:
298+
"""Calculate adjusted timeout based on pending overhead.
299+
300+
Args:
301+
req: The RPC request
302+
is_streaming: Whether this is for a streaming request
303+
304+
Returns:
305+
The adjusted timeout value
306+
"""
307+
adjusted_timeout = req.timeout
308+
if req.creation_timestamp is not None and req.timeout is not None and req.timeout > 0:
309+
pending_time = time.time() - req.creation_timestamp
310+
adjusted_timeout = max(0.1, req.timeout -
311+
pending_time) # Keep at least 0.1s timeout
312+
if pending_time > 0.1: # Only log if significant pending time
313+
method_type = "streaming " if is_streaming else ""
314+
logger_debug(
315+
f"RPC Server adjusted timeout for {method_type}{req.method_name}: "
316+
f"original={req.timeout}s, pending={pending_time:.3f}s, adjusted={adjusted_timeout:.3f}s"
317+
)
318+
return adjusted_timeout
319+
295320
async def _process_request(self, req: RPCRequest) -> Optional[RPCResponse]:
296321
"""Process a request. Returns None for streaming requests (handled separately)."""
297322
func = self._functions[req.method_name]
298323

324+
# Calculate adjusted timeout based on pending overhead
325+
adjusted_timeout = self._calculate_adjusted_timeout(req)
326+
299327
try:
300328
if inspect.iscoroutinefunction(func):
301329
# Execute async function directly in event loop, no need to run in executor due to the GIL
302330
logger_debug(
303331
f"RPC Server running async task {req.method_name} in dispatcher"
304332
)
305333
result = await asyncio.wait_for(func(*req.args, **req.kwargs),
306-
timeout=req.timeout)
334+
timeout=adjusted_timeout)
307335
else:
308336
# Execute sync function in thread executor
309337
loop = asyncio.get_running_loop()
@@ -317,7 +345,7 @@ def call_with_kwargs():
317345
# TODO: let num worker control the pool size
318346
result = await asyncio.wait_for(loop.run_in_executor(
319347
self._executor, call_with_kwargs),
320-
timeout=req.timeout)
348+
timeout=adjusted_timeout)
321349

322350
logger_debug(f"RPC Server returned result for request {req}")
323351
response = RPCResponse(req.request_id, result)
@@ -354,6 +382,10 @@ async def _process_streaming_request(self, req: RPCRequest):
354382

355383
sequence_number = 0
356384

385+
# Calculate adjusted timeout based on pending overhead
386+
adjusted_timeout = self._calculate_adjusted_timeout(req,
387+
is_streaming=True)
388+
357389
try:
358390
logger_debug(f"RPC Server running streaming task {req.method_name}")
359391
# Send start signal
@@ -362,16 +394,37 @@ async def _process_streaming_request(self, req: RPCRequest):
362394
'start'))
363395
sequence_number += 1
364396

365-
# Stream the results
366-
async for result in func(*req.args, **req.kwargs):
367-
logger_debug(
368-
f"RPC Server got data and ready to send result {result}")
369-
response = RPCResponse(req.request_id, result, None, True,
370-
sequence_number, 'data')
371-
if not await self._send_response(req, response):
372-
# Stop streaming after a pickle error
373-
return
374-
sequence_number += 1
397+
# Apply timeout to the entire streaming operation if specified
398+
if adjusted_timeout is not None and adjusted_timeout > 0:
399+
# Create a task for the async generator with timeout
400+
async def stream_with_timeout():
401+
nonlocal sequence_number
402+
async for result in func(*req.args, **req.kwargs):
403+
logger_debug(
404+
f"RPC Server got data and ready to send result {result}"
405+
)
406+
response = RPCResponse(req.request_id, result, None,
407+
True, sequence_number, 'data')
408+
if not await self._send_response(req, response):
409+
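# Stop streaming after a pickle error. NOTE(review): this return only exits stream_with_timeout(); unlike the same return in the no-timeout branch, the enclosing method keeps running after wait_for() completes - confirm this is intended.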
410+
return
411+
sequence_number += 1
412+
413+
# Use wait_for for timeout handling
414+
await asyncio.wait_for(stream_with_timeout(),
415+
timeout=adjusted_timeout)
416+
else:
417+
# No timeout specified, stream normally
418+
async for result in func(*req.args, **req.kwargs):
419+
logger_debug(
420+
f"RPC Server got data and ready to send result {result}"
421+
)
422+
response = RPCResponse(req.request_id, result, None, True,
423+
sequence_number, 'data')
424+
if not await self._send_response(req, response):
425+
# Stop streaming after a pickle error
426+
return
427+
sequence_number += 1
375428

376429
# Send end signal
377430
await self._client_socket.put_async(

0 commit comments

Comments
 (0)