diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index 5353ec1380d..c86474e934a 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -859,6 +859,7 @@ async def benchmark(
     tokenizer: PreTrainedTokenizerBase,
     input_requests: List[Tuple[str, int, int]],
     request_rate: float,
+    max_concurrency: Optional[int],
     disable_tqdm: bool,
     extra_request_body: Dict[str, Any],
     profile: bool,
@@ -868,6 +869,15 @@ async def benchmark(
     else:
         raise ValueError(f"Unknown backend: {backend}")
 
+    # From https://github.com/vllm-project/vllm/pull/9390
+    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
+
+    async def limited_request_func(request_func_input, pbar):
+        if semaphore is None:
+            return await request_func(request_func_input=request_func_input, pbar=pbar)
+        async with semaphore:
+            return await request_func(request_func_input=request_func_input, pbar=pbar)
+
     print("Starting initial single prompt test run...")
     test_prompt, test_prompt_len, test_output_len = input_requests[0]
     test_input = RequestFuncInput(
@@ -913,7 +923,7 @@ async def benchmark(
         )
         tasks.append(
             asyncio.create_task(
-                request_func(request_func_input=request_func_input, pbar=pbar)
+                limited_request_func(request_func_input=request_func_input, pbar=pbar)
             )
         )
     outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
@@ -940,6 +950,12 @@ async def benchmark(
     print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
     print("{:<40} {:<10}".format("Backend:", backend))
     print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
+    print(
+        "{:<40} {:<10}".format(
+            "Max request concurrency:",
+            max_concurrency if max_concurrency else "not set",
+        )
+    )
     print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
     print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
     print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
@@ -1003,6 +1019,7 @@ async def benchmark(
         "backend": args.backend,
         "dataset_name": args.dataset_name,
         "request_rate": request_rate,
+        "max_concurrency": max_concurrency,
         "total_input_tokens": metrics.total_input,
         "total_output_tokens": metrics.total_output,
         "total_output_tokens_retokenized": metrics.total_output_retokenized,
@@ -1201,6 +1218,7 @@ def run_benchmark(args_: argparse.Namespace):
             tokenizer=tokenizer,
             input_requests=input_requests,
             request_rate=args.request_rate,
+            max_concurrency=args.max_concurrency,
            disable_tqdm=args.disable_tqdm,
             extra_request_body=extra_request_body,
             profile=args.profile,
@@ -1220,6 +1238,7 @@ def run_benchmark(args_: argparse.Namespace):
                 tokenizer=tokenizer,
                 input_requests=input_requests,
                 request_rate=rate,
+                max_concurrency=args.max_concurrency,
                 disable_tqdm=args.disable_tqdm,
                 extra_request_body=extra_request_body,
                 profile=args.profile,
@@ -1319,6 +1338,19 @@ def set_ulimit(target_soft_limit=65535):
         help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
         "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
     )
+    parser.add_argument(
+        "--max-concurrency",
+        type=int,
+        default=None,
+        help="Maximum number of concurrent requests. This can be used "
+        "to help simulate an environment where a higher-level component "
+        "is enforcing a maximum number of concurrent requests. While the "
+        "--request-rate argument controls the rate at which requests are "
+        "initiated, this argument controls how many are actually allowed "
+        "to execute at a time. This means that when used in combination, the "
+        "actual request rate may be lower than specified with --request-rate "
+        "if the server is not processing requests fast enough to keep up.",
+    )
     parser.add_argument("--seed", type=int, default=1, help="The random seed.")
     parser.add_argument(
         "--multi",
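For context on what the semaphore gating above does, here is a minimal, self-contained sketch of the same pattern, adapted from the `limited_request_func` added in this patch. `fake_request`, `run`, and the timings are hypothetical stand-ins for the real `request_func` and benchmark loop; only the `asyncio.Semaphore` logic mirrors the diff.

```python
# Standalone sketch of the --max-concurrency mechanism in this patch.
# `fake_request` is a hypothetical stand-in for request_func; the
# asyncio.Semaphore gating is the part that mirrors the diff.
import asyncio
import time
from typing import Optional


async def fake_request(i: int) -> int:
    # Pretend each request takes 0.1 s of server time.
    await asyncio.sleep(0.1)
    return i


async def run(n_requests: int, max_concurrency: Optional[int]) -> None:
    # As in the patch: skip the semaphore entirely when no limit is set,
    # so the unlimited path adds no overhead.
    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

    async def limited(i: int) -> int:
        if semaphore is None:
            return await fake_request(i)
        async with semaphore:
            return await fake_request(i)

    start = time.perf_counter()
    # All tasks are created up front (as with request_rate = inf); the
    # semaphore only bounds how many are in flight at once.
    results = await asyncio.gather(*(limited(i) for i in range(n_requests)))
    elapsed = time.perf_counter() - start
    print(f"{len(results)} requests, max_concurrency={max_concurrency}: {elapsed:.2f}s")


if __name__ == "__main__":
    asyncio.run(run(20, max_concurrency=None))  # ~0.1 s: all 20 in flight at once
    asyncio.run(run(20, max_concurrency=5))     # ~0.4 s: four waves of five
```

With the patch applied, the limit can be exercised end to end with something like `python3 -m sglang.bench_serving --backend sglang --request-rate 16 --max-concurrency 32`; `--max-concurrency` and `--request-rate` are from this diff, while the rest of the invocation is an assumption about a typical run of the existing script.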