diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
index 5ce03e49cb7..e19fffb0f1b 100644
--- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -24,6 +24,8 @@
 
 from sglang.srt.utils import is_hip
 
+is_hip_ = is_hip()
+
 
 @triton.jit
 def tanh(x):
@@ -506,7 +508,7 @@ def _decode_grouped_att_m_fwd(
     num_warps = 4
 
     extra_kargs = {}
-    if is_hip():
+    if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
@@ -561,7 +563,7 @@ def _decode_grouped_softmax_reducev_fwd(
     BLOCK_DMODEL = triton.next_power_of_2(Lv)
 
     extra_kargs = {}
-    if is_hip():
+    if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
index 87c10ed4fb6..56cc439c31e 100644
--- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -29,6 +29,8 @@
 if is_cuda_available:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
+is_hip_ = is_hip()
+
 
 @triton.jit
 def tanh(x):
@@ -311,7 +313,7 @@ def extend_attention_fwd(
         num_stages = 1
 
     extra_kargs = {}
-    if is_hip():
+    if is_hip_:
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
 
     _fwd_kernel[grid](
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 3ba311b8c68..b83271f4358 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -242,15 +242,17 @@ def setup_model(self):
                 )
                 return get_model(vllm_config=vllm_config)
             except ImportError:
-                return get_model(
-                    model_config=self.vllm_model_config,
-                    load_config=self.load_config,
-                    device_config=DeviceConfig(self.device),
-                    parallel_config=None,
-                    scheduler_config=None,
-                    lora_config=None,
-                    cache_config=None,
-                )
+                pass
+
+        return get_model(
+            model_config=self.vllm_model_config,
+            load_config=self.load_config,
+            device_config=DeviceConfig(self.device),
+            parallel_config=None,
+            scheduler_config=None,
+            lora_config=None,
+            cache_config=None,
+        )
 
     def get_model_config_params(self):
         sig = inspect.signature(VllmModelConfig.__init__)
diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py
index f0dfa8f85a0..a985c8dda9e 100644
--- a/test/srt/test_srt_engine.py
+++ b/test/srt/test_srt_engine.py
@@ -152,15 +152,7 @@ def test_6_engine_runtime_encode_consistency(self):
 
         self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3))
 
-    def test_7_engine_offline_throughput(self):
-        server_args = ServerArgs(
-            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
-        )
-        bench_args = BenchArgs(num_prompts=10)
-        result = throughput_test(server_args=server_args, bench_args=bench_args)
-        self.assertGreater(result["total_throughput"], 3500)
-
-    def test_8_engine_cpu_offload(self):
+    def test_7_engine_cpu_offload(self):
         prompt = "Today is a sunny day and I like"
         model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
 
@@ -190,6 +182,14 @@ def test_8_engine_cpu_offload(self):
         print(out2)
 
         self.assertEqual(out1, out2)
 
+    def test_8_engine_offline_throughput(self):
+        server_args = ServerArgs(
+            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+        )
+        bench_args = BenchArgs(num_prompts=10)
+        result = throughput_test(server_args=server_args, bench_args=bench_args)
+        self.assertGreater(result["total_throughput"], 3500)
+
 if __name__ == "__main__":
     unittest.main()