101 changes: 101 additions & 0 deletions vllm_ascend/ops/mm_encoder_attention.py
@@ -0,0 +1,101 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import einops
import torch
import torch.nn.functional as F
import torch_npu
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention

Check failure on line 22 in vllm_ascend/ops/mm_encoder_attention.py
GitHub Actions / lint / pre-commit
Cannot find implementation or library stub for module named "vllm.attention.layers.mm_encoder_attention" [import-not-found]
from vllm.config import MultiModalConfig

import vllm_ascend.envs as envs_ascend

MIN_PAD_SIZE = 64 # min_size to pad weight
MAX_PAD_SIZE = 128 # max_size to pad weight


class AscendMMEncoderAttention(MMEncoderAttention):

def __init__(
self,
num_heads: int,
head_size: int,
scale: float | None = None,
num_kv_heads: int | None = None,
# This has no effect, it is only here to make it easier to swap
# between Attention and MultiHeadAttention
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None:
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
prefix=prefix,
multimodal_config=multimodal_config,
)

def forward_oot(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor
| None = None, # Only used for Flash Attention
):
bsz, q_len = query.size()[:2]
kv_len = key.size(1)

# q/k/v: [b, s, head, head_dim] -> [b * s, head, head_dim]
q, k, v = self.reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len)

enable_pad = (envs_ascend.USE_OPTIMIZED_MODEL
and self.head_size > MIN_PAD_SIZE
and self.head_size < MAX_PAD_SIZE)

if enable_pad:
origin_shape = q.shape[-1]
pad_len = MAX_PAD_SIZE - origin_shape
# q/k/v: [b * s, head, head_dim] -> [b * s, head, MAX_PAD_SIZE]
q = F.pad(q, (0, pad_len), mode="constant", value=0)
k = F.pad(k, (0, pad_len), mode="constant", value=0)
v = F.pad(v, (0, pad_len), mode="constant", value=0)

context_layer = torch.empty_like(q)
cu_seqlens = torch.diff(cu_seqlens).to("cpu")
Contributor comment (critical):
The `cu_seqlens` tensor is used without checking whether it is None. The function signature allows `cu_seqlens` to be None, so calling `torch.diff` on it would raise a TypeError and crash at runtime. Please add a check that `cu_seqlens` is not None before using it (see the sketch after this diff).

Suggested change:
        cu_seqlens = torch.diff(cu_seqlens).to("cpu")
        if cu_seqlens is None:
            raise ValueError("cu_seqlens cannot be None for AscendMMEncoderAttention")
        cu_seqlens = torch.diff(cu_seqlens).to("cpu")


# operator requires pta version >= 2.5.1
torch_npu._npu_flash_attention_unpad(
query=q,
key=k,
value=v,
seq_len=cu_seqlens,
scale_value=self.head_size**-0.5,
num_heads=self.num_heads,
num_kv_heads=self.num_heads,
Contributor comment (high):
The number of key-value heads (num_kv_heads) is hardcoded to self.num_heads. This is incorrect for models that use grouped-query attention (GQA) or multi-query attention (MQA), where the number of key-value heads differs from the number of query heads. The num_kv_heads parameter is passed during initialization and should be available as self.num_kv_heads.

Suggested change:
            num_kv_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
out=context_layer,
)

if enable_pad:
context_layer = context_layer[..., :origin_shape]

context_layer = einops.rearrange(context_layer,
"(b s) h d -> s b (h d)",
b=bsz).contiguous()
return context_layer
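
Note (not part of the PR): a minimal, self-contained sketch of the two preprocessing steps forward_oot performs before calling the NPU kernel — padding the head dimension up to MAX_PAD_SIZE when 64 < head_size < 128, and converting the cumulative cu_seqlens offsets into per-sequence lengths — including the None guard suggested in the review comment above. The helper names and toy shapes here are illustrative assumptions, not code from this PR.

    import torch
    import torch.nn.functional as F

    MIN_PAD_SIZE = 64
    MAX_PAD_SIZE = 128

    def pad_head_dim(q, k, v, head_size):
        # Zero-pad the last (head_dim) axis up to MAX_PAD_SIZE when
        # 64 < head_size < 128, mirroring the enable_pad branch above.
        if MIN_PAD_SIZE < head_size < MAX_PAD_SIZE:
            pad_len = MAX_PAD_SIZE - head_size
            q = F.pad(q, (0, pad_len))
            k = F.pad(k, (0, pad_len))
            v = F.pad(v, (0, pad_len))
        return q, k, v

    def cu_seqlens_to_lengths(cu_seqlens):
        # Guard against None (the reviewer's point), then turn cumulative
        # offsets such as [0, 3, 8] into per-sequence lengths [3, 5] on CPU.
        if cu_seqlens is None:
            raise ValueError("cu_seqlens cannot be None")
        return torch.diff(cu_seqlens).to("cpu")

    # Toy example: two sequences of lengths 3 and 5, 8 heads, head_dim 80.
    q = torch.randn(8, 8, 80)  # [total_tokens, num_heads, head_dim]
    k = torch.randn(8, 8, 80)
    v = torch.randn(8, 8, 80)
    cu_seqlens = torch.tensor([0, 3, 8])

    q, k, v = pad_head_dim(q, k, v, head_size=80)
    seq_lens = cu_seqlens_to_lengths(cu_seqlens)
    print(q.shape, seq_lens)  # torch.Size([8, 8, 128]) tensor([3, 5])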
42 changes: 3 additions & 39 deletions vllm_ascend/patch/worker/patch_qwen2_5_vl.py
@@ -18,15 +18,13 @@
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_npu
from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VisionAttention, Qwen2_5_VLForConditionalGeneration,
Qwen2_5_VLImageInputs, Qwen2_5_VLVideoInputs)
from vllm.model_executor.models.qwen2_vl import Qwen2VisionAttention
from vllm.model_executor.models.vision import run_dp_sharded_mrope_vision_model

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_forward_context import set_ascend_forward_context

MIN_PAD_SIZE = 64 # min_size to pad weight
@@ -47,18 +45,13 @@ def forward(
x, _ = self.qkv(x)
seq_len, batch_size, _ = x.shape

# Split q k v.
qkv = einops.rearrange(
x,
"s b (three head head_dim) -> b s three head head_dim",
three=3,
head=self.num_attention_heads_per_partition,
)
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
origin_shape = q.shape[-1]

# Convert cumulative tensor to intervals and move it to cpu.
cu_seqlens = torch.diff(cu_seqlens).to("cpu")

cos = torch.cat((rotary_pos_emb_cos, rotary_pos_emb_cos), dim=-1)
sin = torch.cat((rotary_pos_emb_sin, rotary_pos_emb_sin), dim=-1)
@@ -67,43 +60,14 @@
q = torch_npu.npu_rotary_mul(q, cos, sin)
k = torch_npu.npu_rotary_mul(k, cos, sin)

q, k, v = [
einops.rearrange(x, "b s h d -> (b s) h d").contiguous()
for x in (q, k, v)
]

enable_pad = (envs_ascend.USE_OPTIMIZED_MODEL
and self.hidden_size_per_attention_head > MIN_PAD_SIZE
and self.hidden_size_per_attention_head < MAX_PAD_SIZE)

if enable_pad:
pad_len = MAX_PAD_SIZE - origin_shape
# q/k/v: [b * s, head, head_dim] -> [b * s, head, MAX_PAD_SIZE]
q = F.pad(q, (0, pad_len), mode="constant", value=0)
k = F.pad(k, (0, pad_len), mode="constant", value=0)
v = F.pad(v, (0, pad_len), mode="constant", value=0)

context_layer = torch.empty_like(q)

# operator requires pta version >= 2.5.1
torch_npu._npu_flash_attention_unpad(
context_layer = self.attn(
query=q,
key=k,
value=v,
seq_len=cu_seqlens,
scale_value=self.hidden_size_per_attention_head**-0.5,
num_heads=self.num_attention_heads_per_partition,
num_kv_heads=self.num_attention_heads_per_partition,
out=context_layer,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)

if enable_pad:
context_layer = context_layer[..., :origin_shape]

context_layer = einops.rearrange(context_layer,
"(b s) h d -> s b (h d)",
b=batch_size).contiguous()

output, _ = self.proj(context_layer)
return output

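Note (not part of the PR): after this change the patched forward keeps only the rotary-embedding step on the NPU (torch_npu.npu_rotary_mul) and delegates attention to self.attn. As a rough CPU reference for what that rotary step computes — assuming the rotate-half convention commonly documented for npu_rotary_mul, which may differ in detail — a minimal sketch:

    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Split the last dim in half and rotate: (x1, x2) -> (-x2, x1).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def rotary_mul_reference(x, cos, sin):
        # Reference for x * cos + rotate_half(x) * sin, with cos/sin already
        # duplicated along the last dim as in the patched forward above.
        return x * cos + rotate_half(x) * sin

    # Toy shapes: [batch, seq, heads, head_dim], per-position angles
    # broadcast across heads.
    x = torch.randn(1, 4, 2, 8)
    angles = torch.randn(1, 4, 1, 4)
    cos = torch.cat((angles.cos(), angles.cos()), dim=-1)
    sin = torch.cat((angles.sin(), angles.sin()), dim=-1)
    out = rotary_mul_reference(x, cos, sin)  # same shape as x
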
2 changes: 2 additions & 0 deletions vllm_ascend/utils.py
@@ -666,6 +666,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
AscendReplicatedLinear,
AscendRowParallelLinear)
from vllm_ascend.ops.mla import AscendMultiHeadLatentAttention
from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention
from vllm_ascend.ops.rotary_embedding import (
AscendDeepseekScalingRotaryEmbedding, AscendMRotaryEmbedding,
AscendRotaryEmbedding, AscendYaRNRotaryEmbedding)
@@ -694,6 +695,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
"FusedMoE": AscendFusedMoE,
"SharedFusedMoE": AscendSharedFusedMoE,
"MultiHeadLatentAttentionWrapper": AscendMultiHeadLatentAttention,
"MMEncoderAttention": AscendMMEncoderAttention,
}

for name, op_cls in REGISTERED_ASCEND_OPS.items():
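
Note (not part of the PR): the utils.py change extends a simple name-to-class registry so that vLLM's generic MMEncoderAttention is swapped for the Ascend implementation at registration time. A toy sketch of that substitution pattern, with stand-in classes rather than the real vLLM/vllm-ascend types:

    # Stand-in classes; the real types live in vLLM and vllm-ascend.
    class MMEncoderAttention:
        def forward(self) -> str:
            return "generic attention path"

    class AscendMMEncoderAttention(MMEncoderAttention):
        def forward(self) -> str:
            return "NPU attention path"

    # Name -> implementation mapping, analogous to REGISTERED_ASCEND_OPS above.
    REGISTERED_OPS = {"MMEncoderAttention": AscendMMEncoderAttention}

    def resolve(name: str) -> type:
        # Fall back to the generic class when no override is registered.
        return REGISTERED_OPS.get(name, MMEncoderAttention)

    attn = resolve("MMEncoderAttention")()
    print(attn.forward())  # "NPU attention path"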