
[CI] Fix test cases (#2137)
Authored by merrymercy on Nov 23, 2024
Commit a78d8f8 · 1 parent: c5f8650
Showing 4 changed files with 27 additions and 21 deletions.

@@ -24,6 +24,8 @@
 
 from sglang.srt.utils import is_hip
 
+is_hip_ = is_hip()
+
 
 @triton.jit
 def tanh(x):
@@ -501,7 +503,7 @@ def _decode_grouped_att_m_fwd(
     num_warps = 4
 
     extra_kargs = {}
-    if is_hip():
+    if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
@@ -557,7 +559,7 @@ def _decode_grouped_softmax_reducev_fwd(
     BLOCK_DMODEL = triton.next_power_of_2(Lv)
 
     extra_kargs = {}
-    if is_hip():
+    if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
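
The first two changed files make the same fix: the HIP check is evaluated once at module import and the cached boolean is reused inside the kernel-launch helpers, instead of calling is_hip() on every launch. A minimal sketch of the pattern, using the real is_hip helper from sglang.srt.utils but a hypothetical launch helper:

    from sglang.srt.utils import is_hip

    # Evaluated once at import time; the platform cannot change while the
    # process is running, so re-checking on every kernel launch is pure
    # overhead on the hot path.
    is_hip_ = is_hip()

    def build_extra_kargs():  # hypothetical helper, for illustration only
        extra_kargs = {}
        if is_hip_:  # a name lookup instead of a function call per launch
            # ROCm-specific Triton launch overrides, as in the diff above
            extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
        return extra_kargs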

@@ -29,6 +29,8 @@
 if is_cuda_available:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
+is_hip_ = is_hip()
+
 
 @triton.jit
 def tanh(x):
@@ -311,7 +313,7 @@ def extend_attention_fwd(
     num_stages = 1
 
     extra_kargs = {}
-    if is_hip():
+    if is_hip_:
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
 
     _fwd_kernel[grid](
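
The extra_kargs dict is forwarded to the Triton kernel launch, where the AMD backend consumes waves_per_eu, matrix_instr_nonkdim, and kpack as compiler hints (see the ROCm and Triton links in the first diff). A self-contained sketch of that flow with a toy kernel; the kernel body, block size, and launch helper are placeholder assumptions, not code from this commit:

    import triton
    import triton.language as tl

    @triton.jit
    def _copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

    def launch(src, dst, extra_kargs):
        n = src.numel()
        grid = (triton.cdiv(n, 1024),)
        # On CUDA extra_kargs is {}, so the splat is a no-op; on ROCm the
        # backend-specific keys tune occupancy and MFMA instruction shape.
        _copy_kernel[grid](src, dst, n, BLOCK=1024, **extra_kargs)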
20 changes: 11 additions & 9 deletions python/sglang/srt/model_executor/model_runner.py

@@ -242,15 +242,17 @@ def setup_model(self):
             )
             return get_model(vllm_config=vllm_config)
         except ImportError:
-            return get_model(
-                model_config=self.vllm_model_config,
-                load_config=self.load_config,
-                device_config=DeviceConfig(self.device),
-                parallel_config=None,
-                scheduler_config=None,
-                lora_config=None,
-                cache_config=None,
-            )
+            pass
+
+        return get_model(
+            model_config=self.vllm_model_config,
+            load_config=self.load_config,
+            device_config=DeviceConfig(self.device),
+            parallel_config=None,
+            scheduler_config=None,
+            lora_config=None,
+            cache_config=None,
+        )
 
     def get_model_config_params(self):
         sig = inspect.signature(VllmModelConfig.__init__)
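
The model_runner.py change moves the fallback get_model call out of the except block: the ImportError from the newer vLLM code path is swallowed with pass, and the keyword-argument fallback runs at function scope. A minimal sketch of why that ordering matters; primary and fallback are hypothetical stand-ins for the two get_model call styles:

    def load_with_fallback(primary, fallback):
        try:
            return primary()  # newer API; may raise ImportError
        except ImportError:
            pass  # suppress only the import failure itself
        # Running the fallback here, outside the handler, means any exception
        # it raises gets a clean traceback instead of being chained to the
        # ImportError ("During handling of the above exception, another
        # exception occurred").
        return fallback()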
18 changes: 9 additions & 9 deletions test/srt/test_srt_engine.py

@@ -152,15 +152,7 @@ def test_6_engine_runtime_encode_consistency(self):
 
         self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3))
 
-    def test_7_engine_offline_throughput(self):
-        server_args = ServerArgs(
-            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
-        )
-        bench_args = BenchArgs(num_prompts=10)
-        result = throughput_test(server_args=server_args, bench_args=bench_args)
-        self.assertGreater(result["total_throughput"], 3500)
-
-    def test_8_engine_cpu_offload(self):
+    def test_7_engine_cpu_offload(self):
        prompt = "Today is a sunny day and I like"
        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
 
@@ -190,6 +182,14 @@ def test_8_engine_cpu_offload(self):
         print(out2)
         self.assertEqual(out1, out2)
 
+    def test_8_engine_offline_throughput(self):
+        server_args = ServerArgs(
+            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+        )
+        bench_args = BenchArgs(num_prompts=10)
+        result = throughput_test(server_args=server_args, bench_args=bench_args)
+        self.assertGreater(result["total_throughput"], 3500)
+
 
 if __name__ == "__main__":
     unittest.main()
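
The test file swaps the method numbers rather than moving the bodies because unittest runs test methods in name-sorted order, not definition order; renumbering makes the CPU-offload test execute before the offline-throughput benchmark. A quick standalone demonstration of the sorting behavior:

    import unittest

    class OrderDemo(unittest.TestCase):
        def test_2_runs_second(self):
            print("second")

        def test_1_runs_first(self):
            print("first, despite being defined second in the class body")

    if __name__ == "__main__":
        # TestLoader.sortTestMethodsUsing sorts method names alphabetically,
        # so test_1_* executes before test_2_* regardless of definition order.
        unittest.main(verbosity=2)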
