From 30a0babe12a7f2f588a98ce706eb7ab1c8cac73f Mon Sep 17 00:00:00 2001 From: kavioyu Date: Sun, 24 Nov 2024 15:20:13 +0800 Subject: [PATCH] fix ci --- python/sglang/bench_offline_throughput.py | 1 + test/srt/test_eagle_infer.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py index f1c4e8f9e18..b64f1ac3986 100644 --- a/python/sglang/bench_offline_throughput.py +++ b/python/sglang/bench_offline_throughput.py @@ -282,6 +282,7 @@ def throughput_test( if bench_args.result_filename: with open(bench_args.result_filename, "a") as fout: fout.write(json.dumps(result) + "\n") + backend.shutdown() print( "\n{s:{c}^{n}}".format(s=" Offline Throughput Benchmark Result ", n=50, c="=") diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py index 41c9af18781..040966bea48 100644 --- a/test/srt/test_eagle_infer.py +++ b/test/srt/test_eagle_infer.py @@ -59,10 +59,19 @@ def test_2_eagle_offline_throughput(self): bench_args = BenchArgs(num_prompts=10) result_eagle = throughput_test(server_args=server_args, bench_args=bench_args) + server_args = ServerArgs( + model_path="meta-llama/Llama-2-7b-chat-hf", + ) + result_naive = throughput_test(server_args=server_args, bench_args=bench_args) + print("==== Throughput EAGLE ====") print(result_eagle["total_throughput"]) + print("==== Throughput Baseline ====") + print(result_naive["total_throughput"]) - self.assertGreater(result_eagle["total_throughput"], 1200.0) + self.assertGreater( + result_eagle["total_throughput"], result_naive["total_throughput"] * 1.5 + ) if __name__ == "__main__":