Skip to content

Commit 109d48e

Browse files
authored
[Feature] support async download features (#5003)
* support async download features * add test case * update code
1 parent bde97e0 commit 109d48e

File tree

10 files changed

+433
-75
lines changed

10 files changed

+433
-75
lines changed

fastdeploy/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,8 @@ def __init__(
550550
self.use_internode_ll_two_stage: bool = False
551551
# disable sequence parallel moe
552552
self.disable_sequence_parallel_moe: bool = False
553+
# enable async download features
554+
self.enable_async_download_features: bool = False
553555

554556
self.pod_ip: str = None
555557
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).

fastdeploy/engine/args_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,11 @@ class EngineArgs:
467467
Url for router server, such as `0.0.0.0:30000`.
468468
"""
469469

470+
enable_async_download_features: bool = False
471+
"""
472+
Flag to enable async download features. Default is False (disabled).
473+
"""
474+
470475
def __post_init__(self):
471476
"""
472477
Post-initialization processing to set default tokenizer if not provided.
@@ -849,6 +854,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
849854
default=EngineArgs.enable_expert_parallel,
850855
help="Enable expert parallelism.",
851856
)
857+
parallel_group.add_argument(
858+
"--enable-async-download-features",
859+
action="store_true",
860+
default=EngineArgs.enable_async_download_features,
861+
help="Enable async download features.",
862+
)
852863

853864
# Load group
854865
load_group = parser.add_argument_group("Load Configuration")

fastdeploy/engine/common_engine.py

Lines changed: 27 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,7 @@
5151
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
5252
from fastdeploy.trace.constants import LoggingEventName
5353
from fastdeploy.trace.trace_logger import print as trace_print
54-
from fastdeploy.utils import (
55-
EngineError,
56-
check_download_links,
57-
envs,
58-
get_logger,
59-
init_bos_client,
60-
llm_logger,
61-
)
54+
from fastdeploy.utils import EngineError, envs, get_logger, llm_logger
6255

6356
try:
6457
TokenProcessor = load_token_processor_plugins()
@@ -808,7 +801,7 @@ def _fetch_request():
808801
else:
809802
raise
810803
# 2. Schedule requests
811-
tasks = self.resource_manager.schedule()
804+
tasks, error_tasks = self.resource_manager.schedule()
812805

813806
# 3. Send to engine
814807
if tasks:
@@ -833,7 +826,16 @@ def _fetch_request():
833826
trace_print(LoggingEventName.REQUEST_SCHEDULE_END, task.request_id, getattr(task, "user", ""))
834827
trace_print(LoggingEventName.INFERENCE_START, task.request_id, getattr(task, "user", ""))
835828
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
836-
else:
829+
830+
# 4. Response error tasks
831+
if error_tasks:
832+
for request_id, failed in error_tasks:
833+
if failed is None:
834+
llm_logger.warning(f"Request {request_id} has no error, skip sending error response.")
835+
continue
836+
self._send_error_response(request_id, failed)
837+
838+
if not tasks and not error_tasks:
837839
time.sleep(0.005)
838840

839841
except RuntimeError as e:
@@ -909,24 +911,6 @@ def _insert_zmq_task_to_scheduler(self):
909911
self.llm_logger.error(f"Receive request error: {err_msg}")
910912
results.append((request.request_id, err_msg))
911913

912-
if self._has_features_info(request) and err_msg is None:
913-
if self.bos_client is None:
914-
self.bos_client = init_bos_client()
915-
916-
download_urls = []
917-
inputs = request.multimodal_inputs
918-
if inputs.get("video_feature_urls") is not None:
919-
download_urls.extend(inputs.get("video_feature_urls"))
920-
if inputs.get("image_feature_urls") is not None:
921-
download_urls.extend(inputs.get("image_feature_urls"))
922-
if inputs.get("audio_feature_urls") is not None:
923-
download_urls.extend(inputs.get("audio_feature_urls"))
924-
925-
err_msg = check_download_links(self.bos_client, download_urls)
926-
if err_msg:
927-
llm_logger.error(f"Receive request {request.request_id} download error: {err_msg}")
928-
results.append((request.request_id, err_msg))
929-
930914
if err_msg is None:
931915
insert_task.append(request)
932916

@@ -948,21 +932,27 @@ def _insert_zmq_task_to_scheduler(self):
948932
main_process_metrics.num_requests_waiting.inc(1)
949933
continue
950934

951-
error_result = RequestOutput(
952-
request_id=request_id,
953-
finished=True,
954-
error_code=500,
955-
error_msg=failed,
956-
)
957-
# Since the request is not in scheduler
958-
# Send result by zmq directly
959-
self.send_response_server.send_response(request_id, [error_result])
935+
self._send_error_response(request_id, failed)
960936
except Exception as e:
961937
self.llm_logger.error(
962938
f"Error happened while receiving new request from zmq, details={e}, "
963939
f"traceback={traceback.format_exc()}"
964940
)
965941

942+
def _send_error_response(self, request_id, error_msg, error_code: int = 500):
    """
    Send an error response directly to the client for a request that the
    scheduler does not track.

    Args:
        request_id: ID of the failed request.
        error_msg: Human-readable error description returned to the client.
        error_code: HTTP-style status code attached to the response (default 500).
    """
    # Use the instance logger for consistency with the rest of this class
    # (other methods here log via self.llm_logger).
    self.llm_logger.error(
        f"Send error response to client, request_id: {request_id}, error_msg: {error_msg}, error_code: {error_code}"
    )
    error_result = RequestOutput(
        request_id=request_id,
        finished=True,
        error_code=error_code,
        error_msg=error_msg,
    )
    # Since the request is not in scheduler
    # Send result by zmq directly
    self.send_response_server.send_response(request_id, [error_result])
955+
966956
def _decode_token(self, token_ids, req_id, is_end):
967957
delta_text = ""
968958
if envs.FD_ENABLE_RETURN_TEXT:
@@ -977,19 +967,6 @@ def _decode_token(self, token_ids, req_id, is_end):
977967
del self.data_processor.decode_status[req_id]
978968
return delta_text, token_ids
979969

980-
def _has_features_info(self, task):
981-
inputs = task.multimodal_inputs
982-
if inputs is None or len(inputs) == 0:
983-
return False
984-
985-
if (
986-
(inputs.get("video_feature_urls") is not None and len(inputs["video_feature_urls"]) > 0)
987-
or (inputs.get("image_feature_urls") is not None and len(inputs["image_feature_urls"]) > 0)
988-
or (inputs.get("audio_feature_urls") is not None and len(inputs["audio_feature_urls"]) > 0)
989-
):
990-
return True
991-
return False
992-
993970
def _zmq_send_generated_tokens(self):
994971
"""
995972
Receive output for zmq

fastdeploy/engine/request.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,10 @@ def __init__(
173173
# dp
174174
self.dp_rank = dp_rank
175175

176+
self.async_process_futures = []
177+
self.error_message = None
178+
self.error_code = None
179+
176180
@classmethod
177181
def from_dict(cls, d: dict):
178182
data_processor_logger.debug(f"{d}")

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from fastdeploy.metrics.metrics import main_process_metrics
4545
from fastdeploy.multimodal.hasher import MultimodalHasher
4646
from fastdeploy.platforms import current_platform
47-
from fastdeploy.utils import llm_logger
47+
from fastdeploy.utils import download_from_bos, init_bos_client, llm_logger
4848

4949

5050
@dataclass
@@ -195,6 +195,9 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l
195195
max_processor_cache_in_bytes = int(config.cache_config.max_processor_cache * 1024 * 1024 * 1024)
196196
self.processor_cache = ProcessorCacheManager(max_processor_cache_in_bytes)
197197

198+
self.bos_client = None
199+
self.async_preprocess_pool = ThreadPoolExecutor(max_workers=4)
200+
198201
def allocated_slots(self, request: Request):
199202
return len(request.block_tables) * self.config.cache_config.block_size
200203

@@ -500,6 +503,7 @@ def schedule(self):
500503
with self.lock:
501504
scheduled_reqs: list[Request] = []
502505
preempted_reqs: list[Request] = []
506+
error_reqs: list[tuple[str, str]] = []
503507
token_budget = self.config.scheduler_config.max_num_batched_tokens
504508

505509
# First, schedule the RUNNING requests.
@@ -629,6 +633,7 @@ def _allocate_decode_and_extend():
629633
req_index += 1
630634
# schedule the WAITING requests.
631635
if not preempted_reqs:
636+
skip_requests: list[Request] = []
632637
while self.waiting and token_budget > 0:
633638
if len(self.running) == self.max_num_seqs:
634639
break
@@ -639,6 +644,17 @@ def _allocate_decode_and_extend():
639644
):
640645
break
641646
if request.status == RequestStatus.WAITING:
647+
result = self._waiting_async_process(request)
648+
if result is None:
649+
error_reqs.append((request.request_id, request.error_message))
650+
self.waiting.popleft()
651+
continue
652+
elif result is True:
653+
# skip current request, try next request
654+
skip_requests.append(request)
655+
self.waiting.popleft()
656+
continue
657+
642658
self._update_mm_hashes(request)
643659
# Enable prefix caching
644660
if self.config.cache_config.enable_prefix_caching:
@@ -725,12 +741,102 @@ def _allocate_decode_and_extend():
725741
else:
726742
llm_logger.error("Unknown request status type")
727743

744+
for req in skip_requests:
745+
# move waiting request to end of the deque
746+
self.waiting.append(req)
747+
728748
if scheduled_reqs:
729749
llm_logger.debug(f"schedued_reqs: {scheduled_reqs}")
730750

731751
self.update_metrics()
732752

733-
return scheduled_reqs
753+
return scheduled_reqs, error_reqs
754+
755+
def _waiting_async_process(self, request: Request) -> None:
756+
"""
757+
Check if async preprocessing is complete for a request.
758+
Args:
759+
request: The request to check
760+
Returns:
761+
None: If an error occurred during preprocessing
762+
True: If preprocessing is still in progress (request should be skipped)
763+
False: If preprocessing is complete (request can be scheduled)
764+
"""
765+
for future in request.async_process_futures:
766+
if future.done():
767+
if request.get("error_message") is not None:
768+
return None
769+
else:
770+
return True
771+
request.async_process_futures = []
772+
return False
773+
774+
def _apply_async_preprocess(self, request: Request) -> None:
775+
request.async_process_futures.append(self.async_preprocess_pool.submit(self._download_features, request))
776+
777+
def _has_features_info(self, task):
778+
inputs = task.multimodal_inputs
779+
if inputs is None or len(inputs) == 0:
780+
return False
781+
782+
if (
783+
(inputs.get("video_feature_urls") is not None and len(inputs["video_feature_urls"]) > 0)
784+
or (inputs.get("image_feature_urls") is not None and len(inputs["image_feature_urls"]) > 0)
785+
or (inputs.get("audio_feature_urls") is not None and len(inputs["audio_feature_urls"]) > 0)
786+
):
787+
return True
788+
return False
789+
790+
def _download_features(self, request: "Request") -> None:
    """
    Download multimodal features for a request from BOS.

    Note:
        1. On success, downloaded features are stored into
           request.multimodal_inputs under "video_features",
           "image_features" and "audio_features".
        2. On failure, request.error_message and request.error_code are
           set and downloading stops at the first failed modality.

    Args:
        request (Request): request object carrying multimodal feature URLs.
    """

    def download_bos_features(bos_client, features_urls):
        # Returns the list of downloaded features, or an error string on failure.
        result_list = []
        # Fixed: use the bos_client argument instead of closing over self.bos_client.
        for status, feature in download_from_bos(bos_client, features_urls):
            if status:
                llm_logger.info(f"request {request.request_id} async download feature: {feature.shape}")
                result_list.append(feature)
            else:
                error_msg = f"request {request.request_id} download features error: {feature}"
                llm_logger.error(error_msg)
                return error_msg
        return result_list

    if not self.config.parallel_config.enable_async_download_features or not self._has_features_info(request):
        return None

    # Lazily create the BOS client on first use.
    if self.bos_client is None:
        self.bos_client = init_bos_client()

    inputs = request.multimodal_inputs
    # Download each modality's features; stop at the first failure.
    for url_key, feature_key in (
        ("video_feature_urls", "video_features"),
        ("image_feature_urls", "image_features"),
        ("audio_feature_urls", "audio_features"),
    ):
        urls = inputs.get(url_key)
        if urls:
            result = download_bos_features(self.bos_client, urls)
            if isinstance(result, str):  # download error
                request.error_message = result
                # 530 marks a feature-download failure, distinct from the generic 500.
                request.error_code = 530
                return None
            inputs[feature_key] = result
734840

735841
def get_available_position(self) -> int:
736842
position = 0
@@ -788,6 +894,7 @@ def get_prefix_cached_blocks(self, request: Request):
788894

789895
def add_request(self, request: Request) -> None:
790896
with self.lock:
897+
self._apply_async_preprocess(request)
791898
self.waiting.append(request)
792899
self.requests[request.request_id] = request
793900

0 commit comments

Comments
 (0)