4 | 4 | from pathlib import Path
5 | 5 | from queue import Queue
6 | 6 | from threading import Event
7 | | -from typing import Any, Optional, Type, Union
| 7 | +from typing import Any, AsyncGenerator, Optional, Type, Union
8 | 8 |
9 | 9 | import ray |
10 | 10 | import torch |
11 | 11 |
12 | | -from .._utils import mpi_rank, ray_use_rpc |
| 12 | +from .._utils import nvtx_range_debug, ray_use_rpc |
13 | 13 | from ..bindings import executor as tllm |
14 | 14 | from ..builder import Engine |
15 | 15 | from ..llmapi.llm_args import BaseLlmArgs |
@@ -236,53 +236,63 @@ def enqueue_request(self, |
236 | 236 | request: GenerationRequest, |
237 | 237 | result_wait_queue: Queue | None = None) -> int: |
238 | 238 | # TODO: remove this. Originally we didn't have to handle all the request id dicts.
| 239 | + # raise ValueError("enqueue_request should not be called.") |
239 | 240 | return self._enqueue_request(request, result_wait_queue) |
240 | 241 |
| 242 | + def start(self): |
| 243 | + pass |
| 244 | + |
241 | 245 | def submit(self, request: GenerationRequest): |
242 | | - print(f"RayGPUWorker {self.rank} submitted request {request.id}") |
243 | 246 | return super().submit(request) |
244 | 247 |
245 | | - async def fetch_responses_async(self, |
246 | | - timeout: Optional[float] = None) -> list: |
| 248 | + def fetch_responses(self, timeout: Optional[float] = None) -> list: |
247 | 249 | # TODO: copied from RpcWorker; needs refactoring.
248 | | - logger_debug(f"RayGPUWorker {mpi_rank()} is fetching responses async", |
| 250 | + logger_debug(f"RayGPUWorker {self.rank} is fetching responses", |
249 | 251 | color="yellow") |
250 | | - |
251 | | - responses = await asyncio.to_thread(self.await_responses, |
252 | | - timeout=timeout) |
253 | | - if self._await_response_helper: |
| 252 | + with nvtx_range_debug("RayGPUWorker.fetch_responses", |
| 253 | + color="orange", |
| 254 | + category="Worker"): |
| 255 | + # NOTE: This is a blocking call; it waits until responses are available.
| 256 | + responses = super().await_responses(timeout) |
254 | 257 | self._await_response_helper.responses_handler(responses) |
255 | 258 |
256 | | - if hasattr(self, |
257 | | - '_response_queue') and self._response_queue is not None: |
258 | | - qsize = self._response_queue.qsize() |
259 | | - logger_debug(f"RayGPUWorker returning {qsize} responses", |
260 | | - color="yellow") |
| 259 | + qsize = self._response_queue.qsize() |
| 260 | + logger_debug(f"RayGPUWorker returning {qsize} responses", |
| 261 | + color="yellow") |
| 262 | + |
| 263 | + all_responses = [] |
| 264 | + for _ in range(qsize): |
| 265 | + # The queue contains batches of responses, so extend the list |
| 266 | + all_responses.extend(self._response_queue.get()) |
| 267 | + return all_responses |
261 | 268 |
262 | | - all_responses = [] |
263 | | - for _ in range(qsize): |
264 | | - all_responses.extend(self._response_queue.get()) |
265 | | - return all_responses |
| 269 | + async def fetch_responses_async(self, |
| 270 | + timeout: Optional[float] = None) -> list: |
| 271 | + # TODO: copied from RpcWorker; needs refactoring.
| 272 | + # A truly asynchronous version of fetch_responses.
| 273 | + logger_debug(f"RayGPUWorker {self.rank} is fetching responses async", |
| 274 | + color="yellow") |
266 | 275 |
267 | | - return responses if responses else [] |
| 276 | + # Run the blocking fetch in a worker thread so the event loop is not blocked.
| 277 | + responses = await asyncio.to_thread(self.fetch_responses, |
| 278 | + timeout=timeout) |
| 279 | + return responses |
268 | 280 |
269 | 281 | # for streaming performance |
270 | | - async def fetch_responses_loop_async(self): |
| 282 | + async def fetch_responses_loop_async(self) -> AsyncGenerator[list, None]: |
271 | 283 | # TODO copied from RpcWorker, need refactoring. |
272 | | - shutdown_event = getattr(self, 'shutdown_event', Event()) |
273 | | - |
274 | | - while not shutdown_event.is_set(): |
| 284 | + while not self.shutdown_event.is_set(): |
275 | 285 | responses = await self.fetch_responses_async() |
276 | | - if responses: |
| 286 | + if responses: # Only yield if there are actual responses |
277 | 287 | logger_debug( |
278 | | - f"RayGPUWorker {mpi_rank()} yielding responses: {responses}", |
| 288 | + f"RayGPUWorker {self.rank} is yielding responses: {responses}", |
279 | 289 | color="yellow") |
280 | | - yield responses |
| 290 | + yield responses # yield in batches to reduce IPC overhead
281 | 291 | else: |
| 292 | + # Yield control to the event loop to avoid busy-waiting when there are no responses.
282 | 293 | await asyncio.sleep(0) |
283 | | - |
284 | 294 | logger_debug( |
285 | | - f"RayGPUWorker {mpi_rank()} quitting fetch_responses_loop_async", |
| 295 | + f"RayGPUWorker {self.rank} quitting fetch_responses_loop_async", |
286 | 296 | color="yellow") |
287 | 297 |
288 | 298 | def shutdown(self): |
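For reviewers, here is a minimal sketch of how the streaming path introduced in this diff might be consumed. It is illustrative only: `worker` stands in for an in-process `RayGPUWorker` instance (not a Ray actor handle, whose remote streaming-call syntax differs), and `drain_responses` / `handle_response` are hypothetical names, not part of the API in this change.

```python
import asyncio


async def drain_responses(worker, handle_response) -> None:
    # `fetch_responses_loop_async` is an async generator that yields
    # lists (batches) of responses until `worker.shutdown_event` is set.
    async for batch in worker.fetch_responses_loop_async():
        # Each yielded item is a whole batch; unpacking it here keeps
        # the batching (and its IPC savings) on the transport side.
        for response in batch:
            handle_response(response)


# Example usage: asyncio.run(drain_responses(worker, print))
```

Because the loop yields whole batches rather than individual responses, a consumer pays one hop per batch instead of one per response, which is the IPC optimization the inline comment on the `yield` refers to.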