|
1 | | -import asyncio |
2 | | -import atexit |
3 | 1 | import os |
4 | | -import threading |
5 | 2 | from typing import Any, Dict, List, Optional, Tuple |
6 | 3 |
|
7 | 4 | try: |
|
19 | 16 | from tensorrt_llm._utils import get_free_port |
20 | 17 | from tensorrt_llm.logger import logger |
21 | 18 |
|
22 | | -from .._utils import nvtx_range_debug |
23 | | -from ..llmapi.tracer import global_tracer |
24 | | -from ..llmapi.utils import _SyncQueue, logger_debug |
| 19 | +from ..llmapi.utils import logger_debug |
25 | 20 | from .executor import GenerationExecutor |
26 | 21 | from .postproc_worker import PostprocWorkerConfig |
27 | 22 | from .ray_gpu_worker import RayGPUWorker, RayWorkerWrapper |
28 | | -from .request import GenerationRequest |
29 | | -from .result import GenerationResult |
30 | | -from .rpc import RPCClient |
31 | | -from .rpc.rpc_common import get_unique_ipc_addr |
32 | | -from .utils import ErrorResponse, is_llm_response |
| 23 | +from .rpc_proxy import RpcExecutorMixin |
33 | 24 |
|
34 | 25 | __all__ = [ |
35 | 26 | "RayExecutor", |
36 | 27 | ] |
37 | 28 |
|
38 | 29 |
|
39 | | -class RayExecutor(GenerationExecutor): |
| 30 | +class RayExecutor(RpcExecutorMixin, GenerationExecutor): |
40 | 31 |
|
41 | 32 | def __init__(self, |
42 | 33 | worker_kwargs: Dict, |
@@ -83,135 +74,22 @@ def __init__(self, |
83 | 74 | self.tp_size = tp_size |
84 | 75 | self.master_address = ray.util.get_node_ip_address() |
85 | 76 | self.master_port = get_free_port() |
86 | | - |
87 | | - self.rpc_addr = get_unique_ipc_addr() |
88 | | - self.rpc_client = RPCClient(self.rpc_addr) |
89 | | - |
90 | | - self._results = {} |
91 | | - self._shutdown_event = threading.Event() |
92 | | - self.main_loop_task_obj = None |
93 | | - self.main_loop = None |
| 77 | + self.init_rpc_executor() |
94 | 78 |
|
95 | 79 | worker_kwargs = dict(**worker_kwargs, |
96 | 80 | postproc_worker_config=postproc_worker_config, |
97 | 81 | is_llm_executor=is_llm_executor, |
98 | 82 | rpc_addr=self.rpc_addr) |
99 | | - |
100 | 83 | self.create_workers(RayGPUWorker, worker_kwargs) |
101 | | - |
102 | | - logger.info("Setting up engine via RPC") |
103 | 84 | self.setup_engine_remote() |
104 | | - self.setup_mainloop() |
| 85 | + self.setup_mainloop(tasks=[self._fetch_responses_loop_async], |
| 86 | + thread_name="ray_executor_main_loop") |
105 | 87 | except Exception as e: |
106 | 88 | # Clean up Ray resources early if initialization fails
107 | 89 | self.shutdown() |
108 | 90 | logger.error(f"Failed to initialize RayExecutor: {e}") |
109 | 91 | raise e |
110 | 92 |
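
> Note: the body of `RpcExecutorMixin` is not shown in this diff. Judging from the attributes the deleted `__init__` lines below used to set up, `init_rpc_executor()` presumably consolidates that state; here is a minimal sketch under that assumption (the actual implementation in `rpc_proxy` may differ):

```python
# Hedged sketch of RpcExecutorMixin.init_rpc_executor(), reconstructed from
# the fields the deleted __init__ code initialized. Not the actual mixin.
import threading

from .rpc import RPCClient
from .rpc.rpc_common import get_unique_ipc_addr


class RpcExecutorMixin:

    def init_rpc_executor(self):
        # One IPC address/client pair shared with the RPC workers.
        self.rpc_addr = get_unique_ipc_addr()
        self.rpc_client = RPCClient(self.rpc_addr)
        # In-flight results keyed by client_id, plus main-loop state.
        self._results = {}
        self._shutdown_event = threading.Event()
        self.main_loop_task_obj = None
        self.main_loop = None
```
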
|
111 | | - @staticmethod |
112 | | - def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle): |
113 | | - state, _, _ = actor_handle._serialization_helper() |
114 | | - return ray.actor.ActorHandle._deserialization_helper(state, |
115 | | - weak_ref=True) |
116 | | - |
117 | | - async def _generic_fetch_loop_async(self, fetch_method_name: str, |
118 | | - handler_method, method_name: str): |
119 | | - # TODO copied from GenerationExecutorRpcProxy, need refactoring. |
120 | | - """Generic method for fetching data in a loop from RPC worker. |
121 | | -
|
122 | | - Args: |
123 | | - fetch_method_name: Name of the RPC client method to call |
124 | | - handler_method: The handler method to call with the fetched data |
125 | | - method_name: Name of the method for logging |
126 | | - """ |
127 | | - try: |
128 | | - fetch_method = getattr(self.rpc_client, fetch_method_name) |
129 | | - async for data in fetch_method().remote_streaming(): |
130 | | - if self._shutdown_event.is_set(): |
131 | | - return |
132 | | - handler_method(data) |
133 | | - except asyncio.CancelledError: |
134 | | - logger.debug(f"{method_name} task cancelled") |
135 | | - except Exception as e: |
136 | | - logger.error(f"Error in {method_name}: {e}") |
137 | | - raise |
138 | | - |
139 | | - async def _fetch_responses_loop_async(self): |
140 | | - # TODO copied from GenerationExecutorRpcProxy, need refactoring. |
141 | | - await self._generic_fetch_loop_async( |
142 | | - fetch_method_name="fetch_responses_loop_async", |
143 | | - handler_method=self.handle_responses, |
144 | | - method_name="_fetch_responses_loop_async") |
145 | | - |
146 | | - def setup_mainloop(self): |
147 | | - # TODO copied from GenerationExecutorRpcProxy, need refactoring. |
148 | | - async def main_loop_task(): |
149 | | - await self._fetch_responses_loop_async() |
150 | | - |
151 | | - def _run_main_loop_task(): |
152 | | - """Local method to run the main loop task.""" |
153 | | - self.main_loop = asyncio.new_event_loop() |
154 | | - asyncio.set_event_loop(self.main_loop) |
155 | | - |
156 | | - self.main_loop_task_obj = self.main_loop.create_task( |
157 | | - main_loop_task()) |
158 | | - try: |
159 | | - self.main_loop.run_until_complete(self.main_loop_task_obj) |
160 | | - except asyncio.CancelledError: |
161 | | - pass # Task cancellation is expected during shutdown |
162 | | - finally: |
163 | | - self.main_loop.close() |
164 | | - |
165 | | - self.main_loop_thread = threading.Thread(target=_run_main_loop_task, |
166 | | - daemon=True, |
167 | | - name="ray_executor_main_loop") |
168 | | - self.main_loop_thread.start() |
169 | | - atexit.register(self.shutdown) |
170 | | - |
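
> The new call site passes `tasks=[self._fetch_responses_loop_async]` and a `thread_name`, so the mixin's `setup_mainloop` is presumably a parametrized generalization of the single-task method deleted above. Continuing the hedged sketch under that assumption (the `asyncio.gather` fan-out is an inference, not confirmed by this diff):

```python
    # Continuing the mixin sketch: the deleted single-task loop above,
    # generalized to a list of async tasks and a configurable thread name.
    def setup_mainloop(self, tasks, thread_name="rpc_executor_main_loop"):

        async def main_loop_task():
            # Run all fetch loops concurrently until shutdown/cancellation.
            await asyncio.gather(*(task() for task in tasks))

        def _run_main_loop_task():
            self.main_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.main_loop)
            self.main_loop_task_obj = self.main_loop.create_task(
                main_loop_task())
            try:
                self.main_loop.run_until_complete(self.main_loop_task_obj)
            except asyncio.CancelledError:
                pass  # Task cancellation is expected during shutdown
            finally:
                self.main_loop.close()

        self.main_loop_thread = threading.Thread(target=_run_main_loop_task,
                                                 daemon=True,
                                                 name=thread_name)
        self.main_loop_thread.start()
        atexit.register(self.shutdown)
```

(This sketch assumes `import asyncio`, `import atexit`, and `import threading` at module scope, matching the imports this refactor removes from the file.)
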
171 | | - def setup_engine_remote(self): |
172 | | - return self.collective_rpc("setup_engine", non_block=False) |
173 | | - |
174 | | - def handle_responses(self, responses: list[GenerationResult]) -> bool: |
175 | | - # TODO copied from GenerationExecutorRpcProxy, need refactoring. |
176 | | - async_queues = [] |
177 | | - event_loop = None |
178 | | - |
179 | | - def process_res(res: list): |
180 | | - for r in res: |
181 | | - client_id = r.client_id |
182 | | - nonlocal event_loop |
183 | | - nonlocal async_queues |
184 | | - |
185 | | - if client_id not in self._results: |
186 | | - logger.warning( |
187 | | - f"Received response for unknown client_id: {client_id}") |
188 | | - continue |
189 | | - |
190 | | - queue = self._results[client_id].queue |
191 | | - if isinstance(queue, _SyncQueue): |
192 | | - queue.put_nowait(r) |
193 | | - async_queues.append(queue) |
194 | | - # all the loops are identical |
195 | | - event_loop = event_loop or queue.loop |
196 | | - else: |
197 | | - queue.put(r) |
198 | | - |
199 | | - if (is_llm_response(r) and r.result.is_final) or isinstance( |
200 | | - r, ErrorResponse): |
201 | | - self._results.pop(client_id) |
202 | | - |
203 | | - # Handle the case where responses might not be a list of lists |
204 | | - if responses and not isinstance(responses[0], list): |
205 | | - # If responses is a flat list, wrap it |
206 | | - responses = [responses] |
207 | | - |
208 | | - for res in responses: |
209 | | - global_tracer().log_instant("RPC.get") |
210 | | - process_res(res) |
211 | | - |
212 | | - if async_queues: |
213 | | - _SyncQueue.notify_many(event_loop, async_queues) |
214 | | - |
215 | 93 | def create_workers(self, worker_cls, worker_kwargs): |
216 | 94 | # When set to a fraction, this allows Ray to schedule
217 | 95 | # multiple actors on a single GPU for colocation use cases.
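
> Fractional `num_gpus` is standard Ray behavior rather than anything specific to this file; a small self-contained illustration (actor and method names here are hypothetical):

```python
# Illustration only: two Ray actors colocated on one physical GPU by each
# reserving a 0.5 GPU fraction. Requires a CUDA-visible GPU to run.
import ray

ray.init()


@ray.remote
class ColocatedWorker:

    def device_ids(self):
        # Both actors report the same GPU id under fractional scheduling.
        return ray.get_gpu_ids()


workers = [ColocatedWorker.options(num_gpus=0.5).remote() for _ in range(2)]
print(ray.get([w.device_ids.remote() for w in workers]))
```
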
@@ -287,31 +165,12 @@ def collective_rpc(self, |
287 | 165 | **kwargs)) |
288 | 166 | return refs if non_block else ray.get(refs) |
289 | 167 |
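
> For context, the `non_block` flag controls whether callers receive raw Ray `ObjectRef`s or resolved values. A brief usage sketch, where `executor` stands in for a constructed `RayExecutor`:

```python
# Blocking form, as setup_engine_remote() below uses it:
results = executor.collective_rpc("setup_engine", non_block=False)

# Non-blocking form: receive ObjectRefs now, resolve them explicitly later.
refs = executor.collective_rpc("setup_engine", non_block=True)
results = ray.get(refs)  # yields the same values as non_block=False
```
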
|
290 | | - def submit(self, request: GenerationRequest) -> GenerationResult: |
291 | | - """ |
292 | | - Low-level API to the executor. Return a "future" GenerationResult |
293 | | - Low-level API to the executor. Returns a "future" GenerationResult
294 | | - which can be waited on.
295 | | - """ |
296 | | - request.set_id(self._get_next_client_id()) |
297 | | - logprob_params = self._get_logprob_params(request) |
298 | | - |
299 | | - with nvtx_range_debug("rpc_submit"): |
300 | | - self.rpc_client.submit(request).remote(need_response=False) |
301 | | - |
302 | | - result = GenerationResult( |
303 | | - request, |
304 | | - background_error_handler=self._handle_background_error, |
305 | | - executor=self, |
306 | | - disaggregated_params=request.disaggregated_params, |
307 | | - logprob_params=logprob_params) |
308 | | - self._results[request.id] = result |
309 | | - |
310 | | - return result |
311 | | - |
312 | 168 | def start(self): |
313 | 169 | pass |
314 | 170 |
|
| 171 | + def setup_engine_remote(self): |
| 172 | + return self.collective_rpc("setup_engine", non_block=False) |
| 173 | + |
315 | 174 | def report_device_ids(self) -> list[str]: |
316 | 175 | gpu_ids = self.call_all_ray_workers("report_device_id", |
317 | 176 | leader_only=False, |
|