@@ -727,9 +727,9 @@ def sample_generated_shared_prefix_requests(
727
727
total_input_tokens = 0
728
728
total_output_tokens = 0
729
729
730
- for group_idx in range (num_groups ):
730
+ for group_idx in tqdm ( range (num_groups ), desc = "Generating system prompt" ):
731
731
system_prompt = system_prompts [group_idx ]
732
- for prompt_idx in range (prompts_per_group ):
732
+ for prompt_idx in tqdm ( range (prompts_per_group ), desc = "Generating questions" ):
733
733
question = questions [group_idx * prompts_per_group + prompt_idx ]
734
734
full_prompt = f"{ system_prompt } \n \n { question } "
735
735
prompt_len = len (tokenizer .encode (full_prompt ))
@@ -859,6 +859,7 @@ async def benchmark(
859
859
tokenizer : PreTrainedTokenizerBase ,
860
860
input_requests : List [Tuple [str , int , int ]],
861
861
request_rate : float ,
862
+ max_concurrency : Optional [int ],
862
863
disable_tqdm : bool ,
863
864
extra_request_body : Dict [str , Any ],
864
865
profile : bool ,
@@ -868,6 +869,15 @@ async def benchmark(
868
869
else :
869
870
raise ValueError (f"Unknown backend: { backend } " )
870
871
872
+ # From https://github.com/vllm-project/vllm/pull/9390
873
+ semaphore = asyncio .Semaphore (max_concurrency ) if max_concurrency else None
874
+
875
+ async def limited_request_func (request_func_input , pbar ):
876
+ if semaphore is None :
877
+ return await request_func (request_func_input = request_func_input , pbar = pbar )
878
+ async with semaphore :
879
+ return await request_func (request_func_input = request_func_input , pbar = pbar )
880
+
871
881
print ("Starting initial single prompt test run..." )
872
882
test_prompt , test_prompt_len , test_output_len = input_requests [0 ]
873
883
test_input = RequestFuncInput (
@@ -913,7 +923,7 @@ async def benchmark(
913
923
)
914
924
tasks .append (
915
925
asyncio .create_task (
916
- request_func (request_func_input = request_func_input , pbar = pbar )
926
+ limited_request_func (request_func_input = request_func_input , pbar = pbar )
917
927
)
918
928
)
919
929
outputs : List [RequestFuncOutput ] = await asyncio .gather (* tasks )
@@ -940,6 +950,12 @@ async def benchmark(
940
950
print ("\n {s:{c}^{n}}" .format (s = " Serving Benchmark Result " , n = 50 , c = "=" ))
941
951
print ("{:<40} {:<10}" .format ("Backend:" , backend ))
942
952
print ("{:<40} {:<10}" .format ("Traffic request rate:" , request_rate ))
953
+ print (
954
+ "{:<40} {:<10}" .format (
955
+ "Max request concurrency:" ,
956
+ max_concurrency if max_concurrency else "not set" ,
957
+ )
958
+ )
943
959
print ("{:<40} {:<10}" .format ("Successful requests:" , metrics .completed ))
944
960
print ("{:<40} {:<10.2f}" .format ("Benchmark duration (s):" , benchmark_duration ))
945
961
print ("{:<40} {:<10}" .format ("Total input tokens:" , metrics .total_input ))
@@ -1003,6 +1019,7 @@ async def benchmark(
1003
1019
"backend" : args .backend ,
1004
1020
"dataset_name" : args .dataset_name ,
1005
1021
"request_rate" : request_rate ,
1022
+ "max_concurrency" : max_concurrency ,
1006
1023
"total_input_tokens" : metrics .total_input ,
1007
1024
"total_output_tokens" : metrics .total_output ,
1008
1025
"total_output_tokens_retokenized" : metrics .total_output_retokenized ,
@@ -1090,6 +1107,10 @@ def run_benchmark(args_: argparse.Namespace):
1090
1107
global args
1091
1108
args = args_
1092
1109
1110
+ # Set default value for max_concurrency if not present
1111
+ if not hasattr (args , "max_concurrency" ):
1112
+ args .max_concurrency = None
1113
+
1093
1114
# Set global environments
1094
1115
set_ulimit ()
1095
1116
random .seed (args .seed )
@@ -1201,6 +1222,7 @@ def run_benchmark(args_: argparse.Namespace):
1201
1222
tokenizer = tokenizer ,
1202
1223
input_requests = input_requests ,
1203
1224
request_rate = args .request_rate ,
1225
+ max_concurrency = args .max_concurrency ,
1204
1226
disable_tqdm = args .disable_tqdm ,
1205
1227
extra_request_body = extra_request_body ,
1206
1228
profile = args .profile ,
@@ -1220,6 +1242,7 @@ def run_benchmark(args_: argparse.Namespace):
1220
1242
tokenizer = tokenizer ,
1221
1243
input_requests = input_requests ,
1222
1244
request_rate = rate ,
1245
+ max_concurrency = args .max_concurrency ,
1223
1246
disable_tqdm = args .disable_tqdm ,
1224
1247
extra_request_body = extra_request_body ,
1225
1248
profile = args .profile ,
@@ -1319,6 +1342,19 @@ def set_ulimit(target_soft_limit=65535):
1319
1342
help = "Number of requests per second. If this is inf, then all the requests are sent at time 0. "
1320
1343
"Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf." ,
1321
1344
)
1345
+ parser .add_argument (
1346
+ "--max-concurrency" ,
1347
+ type = int ,
1348
+ default = None ,
1349
+ help = "Maximum number of concurrent requests. This can be used "
1350
+ "to help simulate an environment where a higher level component "
1351
+ "is enforcing a maximum number of concurrent requests. While the "
1352
+ "--request-rate argument controls the rate at which requests are "
1353
+ "initiated, this argument will control how many are actually allowed "
1354
+ "to execute at a time. This means that when used in combination, the "
1355
+ "actual request rate may be lower than specified with --request-rate, "
1356
+ "if the server is not processing requests fast enough to keep up." ,
1357
+ )
1322
1358
parser .add_argument ("--seed" , type = int , default = 1 , help = "The random seed." )
1323
1359
parser .add_argument (
1324
1360
"--multi" ,
0 commit comments