Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,8 +550,6 @@ def __init__(
self.use_internode_ll_two_stage: bool = False
# disable sequence parallel moe
self.disable_sequence_parallel_moe: bool = False
# enable async download features
self.enable_async_download_features: bool = False

self.pod_ip: str = None
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
Expand Down
11 changes: 0 additions & 11 deletions fastdeploy/engine/args_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,11 +467,6 @@ class EngineArgs:
Url for router server, such as `0.0.0.0:30000`.
"""

enable_async_download_features: bool = False
"""
Flag to enable async download features. Default is False (disabled).
"""

def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
Expand Down Expand Up @@ -854,12 +849,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=EngineArgs.enable_expert_parallel,
help="Enable expert parallelism.",
)
parallel_group.add_argument(
"--enable-async-download-features",
action="store_true",
default=EngineArgs.enable_async_download_features,
help="Enable async download features.",
)

# Load group
load_group = parser.add_argument_group("Load Configuration")
Expand Down
4 changes: 2 additions & 2 deletions fastdeploy/engine/sched/resource_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,7 @@ def _download_features(self, request: Request) -> None:

def download_bos_features(bos_client, features_urls):
result_list = []
for status, feature in download_from_bos(self.bos_client, features_urls):
for status, feature in download_from_bos(self.bos_client, features_urls, retry=1):
if status:
llm_logger.info(f"request {request.request_id} async download feature: {feature.shape}")
result_list.append(feature)
Expand All @@ -809,7 +809,7 @@ def download_bos_features(bos_client, features_urls):
return error_msg
return result_list

if not self.config.parallel_config.enable_async_download_features or not self._has_features_info(request):
if not self._has_features_info(request):
return None

if self.bos_client is None:
Expand Down
39 changes: 30 additions & 9 deletions fastdeploy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import sys
import tarfile
import time
import traceback
from datetime import datetime
from enum import Enum
from http import HTTPStatus
Expand Down Expand Up @@ -976,33 +977,53 @@ def init_bos_client():
return BosClient(cfg)


def download_from_bos(bos_client, bos_links, retry: int = 0):
    """
    Download pickled objects from Baidu Object Storage (BOS).

    Links are processed in order. Each link is fetched once; if that first
    attempt fails and the failure's traceback contains the text
    "request rate is too high", the link is fetched again up to ``retry``
    more times. Any other first-attempt failure is never retried.

    Args:
        bos_client: BOS client instance; must provide
            ``get_object_as_string(bucket_name, object_key)``.
        bos_links: Single link or list of BOS links in format
            "bos://bucket-name/path/to/object".
        retry: Maximum number of extra attempts for a rate-limited link.
            Values <= 0 disable retrying.

    Yields:
        tuple: (success: bool, data | error_msg: str)
            - On success: (True, unpickled_object)
            - On failure: (False, error_message); the generator then ends,
              so the remaining links are not processed.

    Security Note:
        Uses pickle deserialization. Only use with trusted data sources.
    """

    def _bos_download(client, link):
        # Drop the scheme prefix before splitting the link into path segments.
        if link.startswith("bos://"):
            link = link.replace("bos://", "")

        segments = link.split("/")
        # NOTE(review): the first path segment is discarded and the middle
        # segments are joined as the bucket name -- confirm this matches the
        # link layout actually produced by callers.
        bucket_name = "/".join(segments[1:-1])
        object_key = segments[-1]
        return client.get_object_as_string(bucket_name, object_key)

    def _fetch(link):
        # Return (True, obj) on success or (False, traceback_text) on any error,
        # so the generator body never yields from inside a try block.
        try:
            return True, pickle.loads(_bos_download(bos_client, link))
        except Exception:
            return False, traceback.format_exc()

    if not isinstance(bos_links, list):
        bos_links = [bos_links]

    for link in bos_links:
        succeeded, result = _fetch(link)

        if not succeeded:
            # Only the BOS rate-limit error is considered transient.
            if "request rate is too high" not in result or retry <= 0:
                yield False, f"Failed to download {link}: {result}"
                return

            for attempt in range(1, retry + 1):
                llm_logger.warning(f"Retry attempt {attempt}/{retry} for {link}")
                succeeded, result = _fetch(link)
                if succeeded:
                    break
            else:
                # Every retry failed: report the last error and stop the whole
                # generator (not just the retry loop), as the contract promises.
                yield False, f"Failed after {retry} retries for {link}: {result}"
                return

        yield True, result


Expand Down
24 changes: 19 additions & 5 deletions tests/v1/test_resource_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def setUp(self):
max_num_seqs=max_num_seqs,
num_gpu_blocks_override=102,
max_num_batched_tokens=3200,
enable_async_download_features=True,
)
args = asdict(engine_args)

Expand Down Expand Up @@ -130,9 +129,9 @@ def test_download_features_image_error(self):
self.manager.bos_client = mock_client
result = self.manager._download_features(self.request)
self.assertIsNone(result)
self.assertEqual(
self.assertIn(
"request test_request download features error",
self.request.error_message,
"request test_request download features error: link bucket-name/path/to/object1 download error: network error",
)
self.assertEqual(self.request.error_code, 530)

Expand All @@ -151,12 +150,27 @@ def test_download_features_audio_mixed(self):
self.manager.bos_client = mock_client
result = self.manager._download_features(self.request)
self.assertIsNone(result)
self.assertEqual(
self.assertIn(
"request test_request download features error",
self.request.error_message,
"request test_request download features error: link bucket-name/path/to/object2 download error: timeout",
)
self.assertEqual(self.request.error_code, 530)

def test_download_features_retry(self):
"""Test that a rate-limited image feature download reports failure after its single retry."""
mock_client = MagicMock()
mock_client.get_object_as_string.side_effect = Exception(
"Your request rate is too high. We have put limits on your bucket."
)

self.request.multimodal_inputs = {"image_feature_urls": ["bos://bucket-name/path/to/object1"]}

self.manager.bos_client = mock_client
result = self.manager._download_features(self.request)
self.assertIsNone(result)
self.assertIn("Failed after 1 retries for bos://bucket-name/path/to/object1", self.request.error_message)
self.assertEqual(self.request.error_code, 530)


if __name__ == "__main__":
unittest.main()
Loading