
Commit 34aa759

refactor RPC param

Signed-off-by: Superjomn <[email protected]>
1 parent 9995d02

9 files changed (+185, -48 lines)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import time
2+
from typing import Optional, Union
3+
4+
from tensorrt_llm.logger import logger
5+
6+
from ..llmapi.utils import AsyncQueue, _SyncQueue
7+
from .executor import GenerationExecutor
8+
from .ipc import FusedIpcQueue
9+
from .result import IterationResult
10+
from .utils import IntraProcessQueue
11+
12+
13+
class ProxyBase(GenerationExecutor):
14+
15+
def __init__(self,
16+
num_postprocess_workers: int = 0,
17+
postprocess_tokenizer_dir: Optional[str] = None,
18+
is_llm_executor: Optional[bool] = None):
19+
super().__init__(num_postprocess_workers, postprocess_tokenizer_dir,
20+
is_llm_executor)
21+
22+
def _maybe_initialize_iteration_results(self):
23+
if self._is_llm_executor:
24+
if self._iter_stats_result is None:
25+
# singleton to store cpp runtime stats
26+
self._iter_stats_result = IterationResult()
27+
else:
28+
# expect more engine stats whenever new prompts are submitted
29+
self._iter_stats_result.mark_undone()
30+
31+
if self._iter_kv_events_result is None:
32+
self._iter_kv_events_result = IterationResult()
33+
else:
34+
self._iter_kv_events_result.mark_undone()
35+
36+
def _iteration_result_task(self, queue: Union[FusedIpcQueue,
37+
IntraProcessQueue],
38+
result_singleton: IterationResult) -> bool:
39+
# iteration result is not urgent, so we can sleep a bit
40+
time.sleep(0.2)
41+
42+
try:
43+
data = queue.get()
44+
except:
45+
logger.debug(
46+
"proxy.py: Error in _iteration_result_task: queue.get()")
47+
return False
48+
49+
if data is None:
50+
logger.debug("proxy.py: _iteration_result_task: data is None")
51+
return False # shutdown the thread
52+
53+
data = data if isinstance(data, list) else [data]
54+
queue = result_singleton.queue
55+
async_queues = []
56+
57+
while queue.full():
58+
queue.get()
59+
60+
try:
61+
for d in data:
62+
if d is None:
63+
logger.debug("proxy.py: _iteration_result_task: d is None")
64+
return False
65+
66+
if isinstance(queue, _SyncQueue):
67+
queue.put_nowait(d)
68+
async_queues.append(queue)
69+
else:
70+
queue.put(d)
71+
72+
if async_queues:
73+
_SyncQueue.notify_many(queue.loop, async_queues)
74+
75+
except AsyncQueue.EventLoopShutdownError:
76+
# This happens in the last loop while the generate workflow is
77+
# stopped, or when get_stats() or aget_stats() are not called by users
78+
# and therefore event loop can already be closed.
79+
logger.debug("proxy.py: EventLoopShutdownError")
80+
except Exception as e:
81+
logger.debug(f"proxy.py: Error in _iteration_result_task: {e}")
82+
raise e
83+
84+
return True # success
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from .rpc_client import RPCClient
2-
from .rpc_common import (RPCCancelled, RPCError, RPCRequest, RPCResponse,
3-
RPCStreamingError, RPCTimeout)
2+
from .rpc_common import (RPCCancelled, RPCError, RPCParams, RPCRequest,
3+
RPCResponse, RPCStreamingError, RPCTimeout)
44
from .rpc_server import RPCServer, Server
55

66
__all__ = [
77
"RPCClient", "RPCServer", "Server", "RPCError", "RPCTimeout",
8-
"RPCCancelled", "RPCStreamingError", "RPCRequest", "RPCResponse"
8+
"RPCCancelled", "RPCStreamingError", "RPCRequest", "RPCResponse",
9+
"RPCParams"
910
]

tensorrt_llm/executor/rpc/rpc_client.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from ...logger import logger
88
from ..ipc import ZeroMqQueue
9-
from .rpc_common import (RPCCancelled, RPCRequest, RPCResponse,
9+
from .rpc_common import (RPCCancelled, RPCParams, RPCRequest, RPCResponse,
1010
RPCStreamingError, RPCTimeout)
1111

1212

@@ -164,33 +164,32 @@ def _start_response_reader_lazily(self):
164164
# Store the concurrent.futures.Future
165165
self._reader_task = future
166166

167-
async def _call_async(self, __rpc_method_name, *args, **kwargs):
167+
async def _call_async(self, method_name, *args, **kwargs):
168168
"""Async version of RPC call.
169169
Args:
170-
__rpc_method_name: Method name to call
170+
method_name: Method name to call
171171
*args: Positional arguments
172172
**kwargs: Keyword arguments
173-
__rpc_timeout: The timeout (seconds) for the RPC call.
174-
__rpc_need_response: Whether the RPC call needs a response.
175-
If set to False, the remote call will return immediately.
173+
__rpc_params: RPCParams object containing RPC parameters.
176174
177175
Returns:
178176
The result of the remote method call
179177
"""
180178
logger.debug(
181-
f"RPC client calling method: {__rpc_method_name} with args: {args} and kwargs: {kwargs}"
179+
f"RPC client calling method: {method_name} with args: {args} and kwargs: {kwargs}"
182180
)
183181
if self._server_stopped:
184182
raise RPCCancelled("Server is shutting down, request cancelled")
185183

186184
self._start_response_reader_lazily()
187-
need_response = kwargs.pop("__rpc_need_response", True)
188-
timeout = kwargs.pop("__rpc_timeout", self._timeout)
185+
rpc_params = kwargs.pop("__rpc_params", RPCParams())
186+
need_response = rpc_params.need_response
187+
timeout = rpc_params.timeout if rpc_params.timeout is not None else self._timeout
189188

190189
request_id = uuid.uuid4().hex
191190
logger.debug(f"RPC client sending request: {request_id}")
192191
request = RPCRequest(request_id,
193-
__rpc_method_name,
192+
method_name,
194193
args,
195194
kwargs,
196195
need_response,
@@ -216,7 +215,7 @@ async def _call_async(self, __rpc_method_name, *args, **kwargs):
216215
raise
217216
except asyncio.TimeoutError:
218217
raise RPCTimeout(
219-
f"Request '{__rpc_method_name}' timed out after {timeout}s")
218+
f"Request '{method_name}' timed out after {timeout}s")
220219
except Exception as e:
221220
raise e
222221
finally:
@@ -241,11 +240,11 @@ def run_loop():
241240
import time
242241
time.sleep(0.1)
243242

244-
def _call_sync(self, __rpc_method_name, *args, **kwargs):
243+
def _call_sync(self, method_name, *args, **kwargs):
245244
"""Synchronous version of RPC call."""
246245
self._ensure_event_loop()
247246
future = asyncio.run_coroutine_threadsafe(
248-
self._call_async(__rpc_method_name, *args, **kwargs), self._loop)
247+
self._call_async(method_name, *args, **kwargs), self._loop)
249248
return future.result()
250249

251250
def call_async(self, name: str, *args, **kwargs):
@@ -263,7 +262,9 @@ def call_async(self, name: str, *args, **kwargs):
263262
Example:
264263
result = await client.call_async('remote_method', arg1, arg2, key=value)
265264
"""
266-
return self._call_async(name, *args, **kwargs, __rpc_need_response=True)
265+
if "__rpc_params" not in kwargs:
266+
kwargs["__rpc_params"] = RPCParams(need_response=True)
267+
return self._call_async(name, *args, **kwargs)
267268

268269
def call_future(self, name: str, *args,
269270
**kwargs) -> concurrent.futures.Future:
@@ -331,7 +332,8 @@ async def call_streaming(self, name: str, *args,
331332
raise RPCCancelled("Server is shutting down, request cancelled")
332333

333334
self._start_response_reader_lazily()
334-
timeout = kwargs.pop("__rpc_timeout", self._timeout)
335+
rpc_params = kwargs.pop("__rpc_params", RPCParams())
336+
timeout = rpc_params.timeout if rpc_params.timeout is not None else self._timeout
335337

336338
request_id = uuid.uuid4().hex
337339
queue = asyncio.Queue()
@@ -379,7 +381,9 @@ async def call_streaming(self, name: str, *args,
379381
def get_server_attr(self, name: str):
380382
""" Get the attribute of the RPC server.
381383
This is mainly used for testing. """
382-
return self._call_sync("__rpc_get_attr", name, __rpc_timeout=10)
384+
return self._call_sync("__rpc_get_attr",
385+
name,
386+
__rpc_params=RPCParams(timeout=10))
383387

384388
def __getattr__(self, name):
385389
"""
@@ -395,7 +399,8 @@ def __init__(self, client, method_name):
395399

396400
def __call__(self, *args, **kwargs):
397401
"""Default synchronous call"""
398-
mode = kwargs.pop("__rpc_mode", "sync")
402+
rpc_params = kwargs.get("__rpc_params", RPCParams())
403+
mode = rpc_params.mode
399404
if mode == "sync":
400405
return self.client._call_sync(self.method_name, *args,
401406
**kwargs)
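The net effect on the client API: the ad-hoc __rpc_timeout, __rpc_need_response, and __rpc_mode keyword arguments are replaced by a single __rpc_params keyword carrying an RPCParams value, and a per-call timeout falls back to the client-wide default when left unset. A small self-contained sketch of that fallback rule; DEFAULT_TIMEOUT is an illustrative stand-in for the client's self._timeout, not a value taken from this commit:

    from typing import Optional

    from tensorrt_llm.executor.rpc import RPCParams

    DEFAULT_TIMEOUT = 30.0  # stand-in for RPCClient's self._timeout

    def resolve_timeout(params: RPCParams,
                        default: float = DEFAULT_TIMEOUT) -> Optional[float]:
        # Mirrors the fallback in _call_async and call_streaming: an explicit
        # per-call timeout wins, otherwise the client default applies.
        return params.timeout if params.timeout is not None else default

    assert resolve_timeout(RPCParams(timeout=10)) == 10
    assert resolve_timeout(RPCParams()) == DEFAULT_TIMEOUT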

tensorrt_llm/executor/rpc/rpc_common.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,19 @@
11
from typing import Any, Literal, NamedTuple, Optional
22

33

4+
class RPCParams(NamedTuple):
5+
""" Parameters for RPC calls. """
6+
7+
# seconds to wait for the response
8+
timeout: Optional[float] = None
9+
10+
# whether the client needs the response, if False, it will return immediately
11+
need_response: bool = True
12+
13+
# mode for RPC calls: "sync", "async", or "future"
14+
mode: str = "sync"
15+
16+
417
# --- Custom Exceptions ---
518
class RPCError(Exception):
619
"""Custom exception for RPC-related errors raised on the client side.

tensorrt_llm/executor/rpc_proxy.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .request import GenerationRequest
1515
from .result import GenerationResult
1616
from .rpc import RPCClient
17+
from .rpc.rpc_common import RPCParams
1718
from .rpc_worker import RpcWorker
1819
from .utils import (ErrorResponse, create_mpi_comm_session,
1920
get_spawn_proxy_process_env, is_llm_response)
@@ -90,7 +91,8 @@ def main_loop_task(self):
9091
clock = 0
9192
while not self._shutdown_event.is_set():
9293
if clock % 1 == 0:
93-
responses = self.fetch_responses_remote()
94+
responses = self.fetch_responses_remote(
95+
) # RPC request => RPC server; result => RPC client
9496
self.handle_responses(responses)
9597
if clock % 10 == 0:
9698
stats = self.fetch_stats_remote() # TODO
@@ -144,7 +146,8 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
144146
logprob_params = self._get_logprob_params(request)
145147

146148
# submit is a fire-and-forget operation, don't need to wait for response
147-
self.rpc_client.submit(request, __rpc_need_response=False)
149+
self.rpc_client.submit(request,
150+
__rpc_params=RPCParams(need_response=False))
148151

149152
result = GenerationResult(
150153
request,
@@ -157,16 +160,19 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
157160
return result
158161

159162
def fetch_responses_remote(self):
160-
return self.rpc_client.fetch_responses(__rpc_timeout=20)
163+
return self.rpc_client.fetch_responses(__rpc_params=RPCParams(
164+
timeout=20))
161165

162166
def fetch_stats_remote(self):
163167
return self.rpc_client.fetch_stats()
164168

165169
def setup_engine_remote(self):
166-
return self.rpc_client.setup_engine(__rpc_timeout=60 * 20) # 20 min
170+
return self.rpc_client.setup_engine(
171+
__rpc_params=RPCParams(timeout=60 * 20)) # 20 min
167172

168173
def shutdown_remote(self):
169-
self.rpc_client.shutdown(__rpc_timeout=60 * 20) # 20 min
174+
self.rpc_client.shutdown(__rpc_params=RPCParams(timeout=60 *
175+
20)) # 20 min
170176

171177
def abort_request(self, request_id: int) -> None:
172178
return self.rpc_client.abort_request(request_id)

tensorrt_llm/executor/rpc_worker.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ def fetch_responses(self) -> list:
5858
qsize = self._response_queue.qsize()
5959
return [self._response_queue.get() for _ in range(qsize)]
6060

61+
# for streaming performance
62+
async def fetch_responses_async(self) -> list:
63+
while not self.shutdown_event.is_set():
64+
responses = self.fetch_responses() # will block
65+
yield responses # batching the responses to opt IPC performance
66+
6167
def shutdown(self):
6268
logger.debug(f"RPC worker {mpi_rank()} is shutting down")
6369
self.shutdown_event.set()
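fetch_responses_async is an async generator that keeps yielding whole batches of responses until the worker's shutdown_event is set, which keeps IPC traffic coarse-grained. A hedged sketch of how a consumer might drain it; worker and handle_batch are illustrative placeholders, not names from this commit:

    async def drain_responses(worker, handle_batch):
        # Each iteration receives the batch of responses currently queued on
        # the worker; the loop ends once the worker's shutdown_event is set.
        async for batch in worker.fetch_responses_async():
            handle_batch(batch)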
