Skip to content

Commit 38959b1

Browse files
committed
refactor RPC param
Signed-off-by: Superjomn <[email protected]>
1 parent 9995d02 commit 38959b1

File tree

7 files changed

+83
-48
lines changed

7 files changed

+83
-48
lines changed
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from .rpc_client import RPCClient
2-
from .rpc_common import (RPCCancelled, RPCError, RPCRequest, RPCResponse,
3-
RPCStreamingError, RPCTimeout)
2+
from .rpc_common import (RPCCancelled, RPCError, RPCParams, RPCRequest,
3+
RPCResponse, RPCStreamingError, RPCTimeout)
44
from .rpc_server import RPCServer, Server
55

66
__all__ = [
77
"RPCClient", "RPCServer", "Server", "RPCError", "RPCTimeout",
8-
"RPCCancelled", "RPCStreamingError", "RPCRequest", "RPCResponse"
8+
"RPCCancelled", "RPCStreamingError", "RPCRequest", "RPCResponse",
9+
"RPCParams"
910
]

tensorrt_llm/executor/rpc/rpc_client.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from ...logger import logger
88
from ..ipc import ZeroMqQueue
9-
from .rpc_common import (RPCCancelled, RPCRequest, RPCResponse,
9+
from .rpc_common import (RPCCancelled, RPCParams, RPCRequest, RPCResponse,
1010
RPCStreamingError, RPCTimeout)
1111

1212

@@ -164,33 +164,32 @@ def _start_response_reader_lazily(self):
164164
# Store the concurrent.futures.Future
165165
self._reader_task = future
166166

167-
async def _call_async(self, __rpc_method_name, *args, **kwargs):
167+
async def _call_async(self, method_name, *args, **kwargs):
168168
"""Async version of RPC call.
169169
Args:
170-
__rpc_method_name: Method name to call
170+
method_name: Method name to call
171171
*args: Positional arguments
172172
**kwargs: Keyword arguments
173-
__rpc_timeout: The timeout (seconds) for the RPC call.
174-
__rpc_need_response: Whether the RPC call needs a response.
175-
If set to False, the remote call will return immediately.
173+
__rpc_params: Optional RPCParams carrying the per-call options (timeout in seconds, need_response); defaults to RPCParams().
176174
177175
Returns:
178176
The result of the remote method call
179177
"""
180178
logger.debug(
181-
f"RPC client calling method: {__rpc_method_name} with args: {args} and kwargs: {kwargs}"
179+
f"RPC client calling method: {method_name} with args: {args} and kwargs: {kwargs}"
182180
)
183181
if self._server_stopped:
184182
raise RPCCancelled("Server is shutting down, request cancelled")
185183

186184
self._start_response_reader_lazily()
187-
need_response = kwargs.pop("__rpc_need_response", True)
188-
timeout = kwargs.pop("__rpc_timeout", self._timeout)
185+
rpc_params = kwargs.pop("__rpc_params", RPCParams())
186+
need_response = rpc_params.need_response
187+
timeout = rpc_params.timeout if rpc_params.timeout is not None else self._timeout
189188

190189
request_id = uuid.uuid4().hex
191190
logger.debug(f"RPC client sending request: {request_id}")
192191
request = RPCRequest(request_id,
193-
__rpc_method_name,
192+
method_name,
194193
args,
195194
kwargs,
196195
need_response,
@@ -216,7 +215,7 @@ async def _call_async(self, __rpc_method_name, *args, **kwargs):
216215
raise
217216
except asyncio.TimeoutError:
218217
raise RPCTimeout(
219-
f"Request '{__rpc_method_name}' timed out after {timeout}s")
218+
f"Request '{method_name}' timed out after {timeout}s")
220219
except Exception as e:
221220
raise e
222221
finally:
@@ -241,11 +240,11 @@ def run_loop():
241240
import time
242241
time.sleep(0.1)
243242

244-
def _call_sync(self, __rpc_method_name, *args, **kwargs):
243+
def _call_sync(self, method_name, *args, **kwargs):
245244
"""Synchronous version of RPC call."""
246245
self._ensure_event_loop()
247246
future = asyncio.run_coroutine_threadsafe(
248-
self._call_async(__rpc_method_name, *args, **kwargs), self._loop)
247+
self._call_async(method_name, *args, **kwargs), self._loop)
249248
return future.result()
250249

251250
def call_async(self, name: str, *args, **kwargs):
@@ -263,7 +262,9 @@ def call_async(self, name: str, *args, **kwargs):
263262
Example:
264263
result = await client.call_async('remote_method', arg1, arg2, key=value)
265264
"""
266-
return self._call_async(name, *args, **kwargs, __rpc_need_response=True)
265+
if "__rpc_params" not in kwargs:
266+
kwargs["__rpc_params"] = RPCParams(need_response=True)
267+
return self._call_async(name, *args, **kwargs)
267268

268269
def call_future(self, name: str, *args,
269270
**kwargs) -> concurrent.futures.Future:
@@ -331,7 +332,8 @@ async def call_streaming(self, name: str, *args,
331332
raise RPCCancelled("Server is shutting down, request cancelled")
332333

333334
self._start_response_reader_lazily()
334-
timeout = kwargs.pop("__rpc_timeout", self._timeout)
335+
rpc_params = kwargs.pop("__rpc_params", RPCParams())
336+
timeout = rpc_params.timeout if rpc_params.timeout is not None else self._timeout
335337

336338
request_id = uuid.uuid4().hex
337339
queue = asyncio.Queue()
@@ -379,7 +381,9 @@ async def call_streaming(self, name: str, *args,
379381
def get_server_attr(self, name: str):
380382
""" Get the attribute of the RPC server.
381383
This is mainly used for testing. """
382-
return self._call_sync("__rpc_get_attr", name, __rpc_timeout=10)
384+
return self._call_sync("__rpc_get_attr",
385+
name,
386+
__rpc_params=RPCParams(timeout=10))
383387

384388
def __getattr__(self, name):
385389
"""
@@ -395,7 +399,8 @@ def __init__(self, client, method_name):
395399

396400
def __call__(self, *args, **kwargs):
397401
"""Default synchronous call"""
398-
mode = kwargs.pop("__rpc_mode", "sync")
402+
rpc_params = kwargs.get("__rpc_params", RPCParams())
403+
mode = rpc_params.mode
399404
if mode == "sync":
400405
return self.client._call_sync(self.method_name, *args,
401406
**kwargs)

tensorrt_llm/executor/rpc/rpc_common.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,19 @@
11
from typing import Any, Literal, NamedTuple, Optional
22

33

4+
class RPCParams(NamedTuple):
5+
""" Parameters for RPC calls. """
6+
7+
# seconds to wait for the response
8+
timeout: Optional[float] = None
9+
10+
# whether the client needs the response; if False, the call returns immediately
11+
need_response: bool = True
12+
13+
# mode for RPC calls: "sync", "async", or "future"
14+
mode: str = "sync"
15+
16+
417
# --- Custom Exceptions ---
518
class RPCError(Exception):
619
"""Custom exception for RPC-related errors raised on the client side.

tensorrt_llm/executor/rpc_proxy.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .request import GenerationRequest
1515
from .result import GenerationResult
1616
from .rpc import RPCClient
17+
from .rpc.rpc_common import RPCParams
1718
from .rpc_worker import RpcWorker
1819
from .utils import (ErrorResponse, create_mpi_comm_session,
1920
get_spawn_proxy_process_env, is_llm_response)
@@ -90,7 +91,8 @@ def main_loop_task(self):
9091
clock = 0
9192
while not self._shutdown_event.is_set():
9293
if clock % 1 == 0:
93-
responses = self.fetch_responses_remote()
94+
responses = self.fetch_responses_remote(
95+
) # RPC request => RPC server; result => RPC client
9496
self.handle_responses(responses)
9597
if clock % 10 == 0:
9698
stats = self.fetch_stats_remote() # TODO
@@ -144,7 +146,8 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
144146
logprob_params = self._get_logprob_params(request)
145147

146148
# submit is a fire-and-forget operation, don't need to wait for response
147-
self.rpc_client.submit(request, __rpc_need_response=False)
149+
self.rpc_client.submit(request,
150+
__rpc_params=RPCParams(need_response=False))
148151

149152
result = GenerationResult(
150153
request,
@@ -157,16 +160,19 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
157160
return result
158161

159162
def fetch_responses_remote(self):
160-
return self.rpc_client.fetch_responses(__rpc_timeout=20)
163+
return self.rpc_client.fetch_responses(__rpc_params=RPCParams(
164+
timeout=20))
161165

162166
def fetch_stats_remote(self):
163167
return self.rpc_client.fetch_stats()
164168

165169
def setup_engine_remote(self):
166-
return self.rpc_client.setup_engine(__rpc_timeout=60 * 20) # 20 min
170+
return self.rpc_client.setup_engine(
171+
__rpc_params=RPCParams(timeout=60 * 20)) # 20 min
167172

168173
def shutdown_remote(self):
169-
self.rpc_client.shutdown(__rpc_timeout=60 * 20) # 20 min
174+
self.rpc_client.shutdown(__rpc_params=RPCParams(timeout=60 *
175+
20)) # 20 min
170176

171177
def abort_request(self, request_id: int) -> None:
172178
return self.rpc_client.abort_request(request_id)

tensorrt_llm/executor/rpc_worker.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ def fetch_responses(self) -> list:
5858
qsize = self._response_queue.qsize()
5959
return [self._response_queue.get() for _ in range(qsize)]
6060

61+
# async generator variant of fetch_responses, intended for streaming performance
62+
async def fetch_responses_async(self) -> list:
63+
while not self.shutdown_event.is_set():
64+
responses = self.fetch_responses() # will block
65+
yield responses # batch the responses to optimize IPC performance
66+
6167
def shutdown(self):
6268
logger.debug(f"RPC worker {mpi_rank()} is shutting down")
6369
self.shutdown_event.set()

tests/unittest/executor/test_rpc.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import pytest
55

66
from tensorrt_llm.executor.rpc import (RPCCancelled, RPCClient, RPCError,
7-
RPCServer, RPCStreamingError, RPCTimeout)
7+
RPCParams, RPCServer, RPCStreamingError,
8+
RPCTimeout)
89

910

1011
class RpcServerWrapper(RPCServer):
@@ -126,7 +127,7 @@ def get_task_submitted(self) -> bool:
126127
with RpcServerWrapper(App(),
127128
addr="ipc:///tmp/rpc_test_no_wait") as server:
128129
with RPCClient("ipc:///tmp/rpc_test_no_wait") as client:
129-
client.send_task(__rpc_need_response=False)
130+
client.send_task(__rpc_params=RPCParams(need_response=False))
130131
time.sleep(
131132
0.1
132133
) # wait for some time to make sure the task is submitted
@@ -209,7 +210,8 @@ def task(self):
209210
with RPCClient(addr) as client:
210211
client.shutdown_server()
211212
pending_futures = [
212-
client.task(__rpc_mode="future") for _ in range(10)
213+
client.task(__rpc_params=RPCParams(mode="future"))
214+
for _ in range(10)
213215
]
214216

215217
for future in pending_futures:
@@ -238,7 +240,7 @@ def slow_method(self):
238240
with RPCClient("ipc:///tmp/rpc_test_timeout",
239241
timeout=0.5) as client:
240242
with pytest.raises(RPCError) as exc_info:
241-
client.slow_method(__rpc_timeout=0.5)
243+
client.slow_method(__rpc_params=RPCParams(timeout=0.5))
242244

243245
error = exc_info.value
244246
# Should be either a timeout error or RPC error indicating timeout
@@ -306,14 +308,14 @@ def send_task(self) -> None:
306308
with RPCClient("ipc:///tmp/rpc_test_no_wait") as client:
307309
time_start = time.time()
308310
for i in range(100):
309-
client.send_task(__rpc_need_response=False)
311+
client.send_task(__rpc_params=RPCParams(need_response=False))
310312
time_end = time.time()
311313

312314
no_wait_time = time_end - time_start
313315

314316
time_start = time.time()
315317
for i in range(100):
316-
client.send_task(__rpc_need_response=True)
318+
client.send_task(__rpc_params=RPCParams(need_response=True))
317319
time_end = time.time()
318320
wait_time = time_end - time_start
319321

@@ -340,7 +342,8 @@ def cal(self, n: int):
340342

341343
time_start = time.time()
342344
for i in range(100):
343-
ret = client.cal(i, __rpc_timeout=10) # sync call
345+
ret = client.cal(
346+
i, __rpc_params=RPCParams(timeout=10)) # sync call
344347
assert ret == i * 2, f"{ret} != {i * 2}"
345348
time_end = time.time()
346349
print(
@@ -378,7 +381,7 @@ def teardown_method(self):
378381

379382
def run_sync_timeout_test(self):
380383
with pytest.raises(RPCTimeout) as exc_info:
381-
self.client.slow_operation(2.0, __rpc_timeout=0.1)
384+
self.client.slow_operation(2.0, __rpc_params=RPCParams(timeout=0.1))
382385
assert "timed out" in str(
383386
exc_info.value), f"Timeout message not found: {exc_info.value}"
384387

@@ -387,26 +390,25 @@ def run_async_timeout_test(self):
387390

388391
async def async_timeout():
389392
with pytest.raises(RPCTimeout) as exc_info:
390-
await self.client.call_async('slow_operation',
391-
2.0,
392-
__rpc_timeout=0.1)
393+
await self.client.call_async(
394+
'slow_operation', 2.0, __rpc_params=RPCParams(timeout=0.1))
393395
assert "timed out" in str(
394396
exc_info.value), f"Timeout message not found: {exc_info.value}"
395397

396398
asyncio.run(async_timeout())
397399

398400
def run_sync_success_test(self):
399-
result = self.client.slow_operation(0.1, __rpc_timeout=10.0)
401+
result = self.client.slow_operation(
402+
0.1, __rpc_params=RPCParams(timeout=10.0))
400403
assert result == "completed"
401404
print(f"final result: {result}")
402405

403406
def run_async_success_test(self):
404407
import asyncio
405408

406409
async def async_success():
407-
result = await self.client.call_async('slow_operation',
408-
0.1,
409-
__rpc_timeout=10.0)
410+
result = await self.client.call_async(
411+
'slow_operation', 0.1, __rpc_params=RPCParams(timeout=10.0))
410412
assert result == "completed"
411413
print(f"final result: {result}")
412414
return result
@@ -457,7 +459,8 @@ def foo(self, delay: int):
457459
time.sleep(0.1)
458460
with RPCClient("ipc:///tmp/rpc_test_shutdown") as client:
459461
# This task should be continued after server shutdown
460-
res = client.foo(10, __rpc_timeout=12, __rpc_mode="future")
462+
res = client.foo(10,
463+
__rpc_params=RPCParams(timeout=12, mode="future"))
461464

462465
# The shutdown will block until all pending requests are finished
463466
server.shutdown()
@@ -589,7 +592,7 @@ async def test_streaming_timeout(self):
589592
# Set short timeout
590593
with pytest.raises(RPCTimeout):
591594
async for value in client.streaming_timeout.call_streaming(
592-
delay=2.0, __rpc_timeout=0.5):
595+
delay=2.0, __rpc_params=RPCParams(timeout=0.5)):
593596
pass # Should timeout before first yield
594597

595598
@pytest.mark.asyncio

tests/unittest/executor/test_rpc_worker.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from test_worker_base import create_fake_executor_config
88

99
from tensorrt_llm.executor.request import GenerationRequest
10-
from tensorrt_llm.executor.rpc import RPCClient
10+
from tensorrt_llm.executor.rpc import RPCClient, RPCParams
1111
from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy
1212
from tensorrt_llm.executor.rpc_worker import RpcWorker
1313
from tensorrt_llm.sampling_params import SamplingParams
@@ -43,14 +43,14 @@ def create_rpc_client(self, addr: str):
4343
def test_main_loop(self):
4444
pool, addr = self.create_tp1_worker_process()
4545
client = self.create_rpc_client(addr)
46-
client.setup_engine(__rpc_timeout=120)
46+
client.setup_engine(__rpc_params=RPCParams(timeout=120))
4747
time.sleep(1)
4848

4949
def process_request():
5050
ret = client.submit(GenerationRequest(
5151
prompt_token_ids=[3, 4, 5],
5252
sampling_params=SamplingParams(max_tokens=10)),
53-
__rpc_need_response=False)
53+
__rpc_params=RPCParams(need_response=False))
5454
assert ret is None # need_response = False
5555

5656
print(f"submit result: {ret}")
@@ -69,7 +69,7 @@ def process_request_streaming():
6969
prompt_token_ids=[3, 4, 5],
7070
sampling_params=SamplingParams(max_tokens=10),
7171
streaming=True),
72-
__rpc_need_response=False)
72+
__rpc_params=RPCParams(need_response=False))
7373
assert ret is None
7474
print("submit result: ", ret)
7575

@@ -80,7 +80,8 @@ def process_request_streaming():
8080

8181
while not results:
8282
time.sleep(1)
83-
results.extend(client.fetch_responses(__rpc_timeout=10))
83+
results.extend(
84+
client.fetch_responses(__rpc_params=RPCParams(timeout=10)))
8485
print(f"try fetch_responses result: {results}")
8586
print(f"fetch_responses result: {results}")
8687
assert results
@@ -90,7 +91,7 @@ def process_request_streaming():
9091
process_request_streaming()
9192

9293
print("call shutdown")
93-
client.shutdown(__rpc_timeout=10)
94+
client.shutdown(__rpc_params=RPCParams(timeout=10))
9495
pool.shutdown()
9596
client.close()
9697

0 commit comments

Comments
 (0)