add qwen3_next ops: fused_qkvzba_split_reshape #4747
Conversation
Code Review
This pull request introduces new Triton-based operators, fused_qkvzba_split_reshape and rope_forward_triton, to optimize performance for the qwen3_next model on Ascend hardware. The changes include a new Triton kernel for fused QKVZBA splitting and reshaping, and a new RoPE implementation using Triton. The review identifies two critical issues: a potential bug in rotary_embedding.py that might break graph fusion due to a removed return value assignment, and an import typo in patch_qwen3_next.py that will lead to a runtime error.
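For context, a fused "QKVZBA split + reshape" kernel replaces a chain of per-tensor splits and reshapes that would otherwise run as separate PyTorch ops. Below is a minimal unfused sketch of that pattern, not the kernel added in this PR; all tensor names and sizes (num_k_heads, num_v_heads, head dims) are illustrative assumptions rather than the actual Qwen3-Next values.

```python
import torch

# Hypothetical sizes, for illustration only; the real values come from the
# Qwen3-Next GatedDeltaNet config, not from this sketch.
num_tokens, num_k_heads, num_v_heads, head_k, head_v = 8, 4, 8, 128, 128

# mixed_qkvz packs q, k, v and the gating tensor z into one projection output;
# mixed_ba packs b and a. The Triton kernel in this PR fuses splits/reshapes
# like the ones below into a single pass over memory.
mixed_qkvz = torch.randn(num_tokens,
                         num_k_heads * head_k * 2 + num_v_heads * head_v * 2)
mixed_ba = torch.randn(num_tokens, num_v_heads * 2)

q, k, v, z = torch.split(
    mixed_qkvz,
    [num_k_heads * head_k, num_k_heads * head_k,
     num_v_heads * head_v, num_v_heads * head_v],
    dim=-1)
b, a = torch.split(mixed_ba, [num_v_heads, num_v_heads], dim=-1)

# Reshape each chunk into a (tokens, heads, head_dim) layout.
q = q.view(num_tokens, num_k_heads, head_k)
k = k.view(num_tokens, num_k_heads, head_k)
v = v.view(num_tokens, num_v_heads, head_v)
z = z.view(num_tokens, num_v_heads, head_v)
```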
vllm_ascend/ops/rotary_embedding.py (Outdated)
    key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
    # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
    # This method requires head_size and rotary_dim equal 128 and neox_style is True
    torch_npu.npu_apply_rotary_pos_emb(query, key, self.cos, self.sin)
The previous implementation assigned the result of torch_npu.npu_apply_rotary_pos_emb back to query, key. A comment in the original code indicated this was crucial for graph fusion to work correctly, even though the function modifies tensors in-place. This assignment has been removed. If this is no longer a concern, it would be helpful to add a comment explaining why. Otherwise, this could be a critical bug and the assignment should be restored.
Suggested change:
    - torch_npu.npu_apply_rotary_pos_emb(query, key, self.cos, self.sin)
    + query, key = torch_npu.npu_apply_rotary_pos_emb(query, key, self.cos, self.sin)
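For reference, the operation under discussion is a neox-style rotary position embedding; per the comment in the diff, npu_apply_rotary_pos_emb requires head_size and rotary_dim of 128 and neox_style=True. A minimal out-of-place PyTorch sketch of that rotation, assuming precomputed cos/sin tables that broadcast against query and key, looks like this:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Neox-style: split the last (head) dimension in half and rotate.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    # cos/sin are assumed to broadcast against q/k of shape
    # (1, num_tokens, num_heads, head_size), with head_size == rotary_dim.
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot
```

Whether the Ascend op mutates its inputs in place or returns new tensors, reassigning the outputs keeps the dataflow explicit, which is what the original comment said mattered for graph fusion.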
patch_qwen3_next.py
    from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet
    from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
    from vllm.triton_utils import tl, triton
    from vll.config import (CUDAGraphMode, get_current_vllm_config)
There is a typo in this import: vll.config should be vllm.config. As written, this import will fail with a ModuleNotFoundError at runtime.
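Assuming the intended module is vllm.config, which exports both of these names, the corrected import would be:

```python
from vllm.config import CUDAGraphMode, get_current_vllm_config
```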
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?