
Commit f6ce5de

move is_cuda_graph
1 parent d9189fe

File tree: 2 files changed (+4, -3 lines)

vllm_ascend/patch/worker/__init__.py
Lines changed: 3 additions & 0 deletions

@@ -20,6 +20,9 @@
 if HAS_TRITON:
     import vllm_ascend.patch.worker.patch_triton
 
+from vllm.config import (CUDAGraphMode, get_current_vllm_config)
+is_cuda_graph = get_current_vllm_config().compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+
 # isort: off
 import vllm_ascend.patch.platform.patch_sched_yield  # noqa
 import vllm_ascend.patch.worker.patch_distributed  # noqa

vllm_ascend/patch/worker/patch_qwen3_next.py
Lines changed: 1 addition & 3 deletions

@@ -24,8 +24,7 @@
 from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
 from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
 from vllm.triton_utils import tl, triton
-from vllm.config import (CUDAGraphMode, get_current_vllm_config)
-
+from . import is_cuda_graph
 
 class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
     def forward(
@@ -48,7 +47,6 @@ def forward(
         projected_states_ba, _ = self.in_proj_ba(hidden_states)
         # triton grid should be less than 66536
         divide_grid=projected_states_qkvz.shape[0]*triton.cdiv(self.num_k_heads, self.tp_size)
-        is_cuda_graph = get_current_vllm_config().compilation_config.cudagraph_mode != CUDAGraphMode.NONE
         if self.num_v_heads // self.num_k_heads in [1, 2, 4] and is_cuda_graph and divide_grid < 65536:
             mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
                 projected_states_qkvz,
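
Net effect of the commit: the cudagraph check is now evaluated once, when vllm_ascend.patch.worker is first imported, and patch_qwen3_next.py pulls in the resulting is_cuda_graph flag via "from . import is_cuda_graph" instead of re-querying the global vllm config on every forward() call. Below is a minimal, self-contained sketch of that pattern; the dummy config classes and the forward_dispatch() helper are illustrative stand-ins, not the real vllm / vllm-ascend API.

# Sketch only: the classes and _get_config() below stand in for
# vllm.config.get_current_vllm_config() and the CUDAGraphMode enum.
class _CompilationConfig:
    cudagraph_mode = "PIECEWISE"   # any mode other than "NONE"

class _VllmConfig:
    compilation_config = _CompilationConfig()

def _get_config():
    return _VllmConfig()

# Evaluated once at import time, mirroring the new lines in
# vllm_ascend/patch/worker/__init__.py (which use the real vllm config).
is_cuda_graph = _get_config().compilation_config.cudagraph_mode != "NONE"

def forward_dispatch(num_v_heads: int, num_k_heads: int, divide_grid: int) -> str:
    # The hot path reads the precomputed module-level flag; the fused Triton
    # kernel is only chosen for supported head ratios, with CUDA graphs
    # enabled, and while the launch grid stays under the 65536 limit.
    if num_v_heads // num_k_heads in (1, 2, 4) and is_cuda_graph and divide_grid < 65536:
        return "fused_qkvzba_split_reshape_cat path"
    return "fallback split/reshape path"

print(forward_dispatch(num_v_heads=32, num_k_heads=8, divide_grid=1024))  # -> fused path

One consequence of the move is that is_cuda_graph now reflects the vllm config that is active when the patch package is imported, rather than being re-read on each forward call.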
