Skip to content

Commit 7454480

Browse files
authored
[Feature] support bos download retry (#5137)
* support bos download retry * update code * update code
1 parent 43097a5 commit 7454480

File tree

5 files changed

+51
-29
lines changed

5 files changed

+51
-29
lines changed

fastdeploy/config.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -550,8 +550,6 @@ def __init__(
550550
self.use_internode_ll_two_stage: bool = False
551551
# disable sequence parallel moe
552552
self.disable_sequence_parallel_moe: bool = False
553-
# enable async download features
554-
self.enable_async_download_features: bool = False
555553

556554
self.pod_ip: str = None
557555
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).

fastdeploy/engine/args_utils.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -467,11 +467,6 @@ class EngineArgs:
467467
Url for router server, such as `0.0.0.0:30000`.
468468
"""
469469

470-
enable_async_download_features: bool = False
471-
"""
472-
Flag to enable async download features. Default is False (disabled).
473-
"""
474-
475470
def __post_init__(self):
476471
"""
477472
Post-initialization processing to set default tokenizer if not provided.
@@ -844,12 +839,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
844839
default=EngineArgs.enable_expert_parallel,
845840
help="Enable expert parallelism.",
846841
)
847-
parallel_group.add_argument(
848-
"--enable-async-download-features",
849-
action="store_true",
850-
default=EngineArgs.enable_async_download_features,
851-
help="Enable async download features.",
852-
)
853842

854843
# Load group
855844
load_group = parser.add_argument_group("Load Configuration")

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,7 @@ def _download_features(self, request: Request) -> None:
809809

810810
def download_bos_features(bos_client, features_urls):
811811
result_list = []
812-
for status, feature in download_from_bos(self.bos_client, features_urls):
812+
for status, feature in download_from_bos(self.bos_client, features_urls, retry=1):
813813
if status:
814814
llm_logger.info(f"request {request.request_id} async download feature: {feature.shape}")
815815
result_list.append(feature)
@@ -819,7 +819,7 @@ def download_bos_features(bos_client, features_urls):
819819
return error_msg
820820
return result_list
821821

822-
if not self.config.parallel_config.enable_async_download_features or not self._has_features_info(request):
822+
if not self._has_features_info(request):
823823
return None
824824

825825
if self.bos_client is None:

fastdeploy/utils.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import sys
3030
import tarfile
3131
import time
32+
import traceback
3233
from datetime import datetime
3334
from enum import Enum
3435
from http import HTTPStatus
@@ -976,33 +977,53 @@ def init_bos_client():
976977
return BosClient(cfg)
977978

978979

979-
def download_from_bos(bos_client, bos_links):
980+
def download_from_bos(bos_client, bos_links, retry: int = 0):
980981
"""
981982
Download pickled objects from Baidu Object Storage (BOS).
982983
Args:
983984
bos_client: BOS client instance
984985
bos_links: Single link or list of BOS links in format "bos://bucket-name/path/to/object"
986+
retry: Number of times to retry on failure (only retries on network-related errors)
985987
Yields:
986988
tuple: (success: bool, data: np.ndarray | error_msg: str)
987989
- On success: (True, deserialized_data)
988990
- On failure: (False, error_message) and stops processing remaining links
989991
Security Note:
990992
Uses pickle deserialization. Only use with trusted data sources.
991993
"""
994+
995+
def _bos_download(bos_client, link):
996+
if link.startswith("bos://"):
997+
link = link.replace("bos://", "")
998+
999+
bucket_name = "/".join(link.split("/")[1:-1])
1000+
object_key = link.split("/")[-1]
1001+
return bos_client.get_object_as_string(bucket_name, object_key)
1002+
9921003
if not isinstance(bos_links, list):
9931004
bos_links = [bos_links]
9941005

9951006
for link in bos_links:
9961007
try:
997-
if link.startswith("bos://"):
998-
link = link.replace("bos://", "")
999-
1000-
bucket_name = "/".join(link.split("/")[1:-1])
1001-
object_key = link.split("/")[-1]
1002-
response = bos_client.get_object_as_string(bucket_name, object_key)
1008+
response = _bos_download(bos_client, link)
10031009
yield True, pickle.loads(response)
1004-
except Exception as e:
1005-
yield False, f"link {link} download error: {str(e)}"
1010+
except Exception:
1011+
# Only retry on network-related or timeout exceptions
1012+
exceptions_msg = str(traceback.format_exc())
1013+
1014+
if "request rate is too high" not in exceptions_msg or retry <= 0:
1015+
yield False, f"Failed to download {link}: {exceptions_msg}"
1016+
break
1017+
1018+
for attempt in range(retry):
1019+
try:
1020+
llm_logger.warning(f"Retry attempt {attempt + 1}/{retry} for {link}")
1021+
response = _bos_download(bos_client, link)
1022+
yield True, pickle.loads(response)
1023+
break
1024+
except Exception:
1025+
if attempt == retry - 1: # Last attempt failed
1026+
yield False, f"Failed after {retry} retries for {link}: {str(traceback.format_exc())}"
10061027
break
10071028

10081029

tests/v1/test_resource_manager_v1.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ def setUp(self):
2020
max_num_seqs=max_num_seqs,
2121
num_gpu_blocks_override=102,
2222
max_num_batched_tokens=3200,
23-
enable_async_download_features=True,
2423
)
2524
args = asdict(engine_args)
2625

@@ -130,9 +129,9 @@ def test_download_features_image_error(self):
130129
self.manager.bos_client = mock_client
131130
result = self.manager._download_features(self.request)
132131
self.assertIsNone(result)
133-
self.assertEqual(
132+
self.assertIn(
133+
"request test_request download features error",
134134
self.request.error_message,
135-
"request test_request download features error: link bucket-name/path/to/object1 download error: network error",
136135
)
137136
self.assertEqual(self.request.error_code, 530)
138137

@@ -151,12 +150,27 @@ def test_download_features_audio_mixed(self):
151150
self.manager.bos_client = mock_client
152151
result = self.manager._download_features(self.request)
153152
self.assertIsNone(result)
154-
self.assertEqual(
153+
self.assertIn(
154+
"request test_request download features error",
155155
self.request.error_message,
156-
"request test_request download features error: link bucket-name/path/to/object2 download error: timeout",
157156
)
158157
self.assertEqual(self.request.error_code, 530)
159158

159+
def test_download_features_retry(self):
160+
"""Test image feature download with error"""
161+
mock_client = MagicMock()
162+
mock_client.get_object_as_string.side_effect = Exception(
163+
"Your request rate is too high. We have put limits on your bucket."
164+
)
165+
166+
self.request.multimodal_inputs = {"image_feature_urls": ["bos://bucket-name/path/to/object1"]}
167+
168+
self.manager.bos_client = mock_client
169+
result = self.manager._download_features(self.request)
170+
self.assertIsNone(result)
171+
self.assertIn("Failed after 1 retries for bos://bucket-name/path/to/object1", self.request.error_message)
172+
self.assertEqual(self.request.error_code, 530)
173+
160174

161175
if __name__ == "__main__":
162176
unittest.main()

0 commit comments

Comments
 (0)