diff --git a/vllm_ascend/ops/mm_encoder_attention.py b/vllm_ascend/ops/mm_encoder_attention.py
new file mode 100644
index 00000000000..99829a8136e
--- /dev/null
+++ b/vllm_ascend/ops/mm_encoder_attention.py
@@ -0,0 +1,101 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import einops
+import torch
+import torch.nn.functional as F
+import torch_npu
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
+from vllm.config import MultiModalConfig
+
+import vllm_ascend.envs as envs_ascend
+
+MIN_PAD_SIZE = 64  # min_size to pad weight
+MAX_PAD_SIZE = 128  # max_size to pad weight
+
+
+class AscendMMEncoderAttention(MMEncoderAttention):
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float | None = None,
+        num_kv_heads: int | None = None,
+        # This has no effect, it is only here to make it easier to swap
+        # between Attention and MultiHeadAttention
+        prefix: str = "",
+        multimodal_config: MultiModalConfig | None = None,
+    ) -> None:
+        super().__init__(
+            num_heads=num_heads,
+            head_size=head_size,
+            scale=scale,
+            num_kv_heads=num_kv_heads,
+            prefix=prefix,
+            multimodal_config=multimodal_config,
+        )
+
+    def forward_oot(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        cu_seqlens: torch.Tensor | None = None,
+        max_seqlen: torch.Tensor
+        | None = None,  # Only used for Flash Attention
+    ):
+        bsz, q_len = query.size()[:2]
+        kv_len = key.size(1)
+
+        # q/k/v: [b, s, head, head_dim] -> [b * s, head, head_dim]
+        q, k, v = self.reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len)
+
+        enable_pad = (envs_ascend.USE_OPTIMIZED_MODEL
+                      and self.head_size > MIN_PAD_SIZE
+                      and self.head_size < MAX_PAD_SIZE)
+
+        if enable_pad:
+            origin_shape = q.shape[-1]
+            pad_len = MAX_PAD_SIZE - origin_shape
+            # q/k/v: [b * s, head, head_dim] -> [b * s, head, MAX_PAD_SIZE]
+            q = F.pad(q, (0, pad_len), mode="constant", value=0)
+            k = F.pad(k, (0, pad_len), mode="constant", value=0)
+            v = F.pad(v, (0, pad_len), mode="constant", value=0)
+
+        context_layer = torch.empty_like(q)
+        cu_seqlens = torch.diff(cu_seqlens).to("cpu")
+
+        # operator requires pta version >= 2.5.1
+        torch_npu._npu_flash_attention_unpad(
+            query=q,
+            key=k,
+            value=v,
+            seq_len=cu_seqlens,
+            scale_value=self.head_size**-0.5,
+            num_heads=self.num_heads,
+            num_kv_heads=self.num_heads,
+            out=context_layer,
+        )
+
+        if enable_pad:
+            context_layer = context_layer[..., :origin_shape]
+
+        context_layer = einops.rearrange(context_layer,
+                                         "(b s) h d -> s b (h d)",
+                                         b=bsz).contiguous()
+        return context_layer
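Note for reviewers: `forward_oot` keeps the existing pad-to-128 workaround. When `USE_OPTIMIZED_MODEL` is set and the head dim lies strictly between `MIN_PAD_SIZE` (64) and `MAX_PAD_SIZE` (128), q/k/v are zero-padded on the head dim before `torch_npu._npu_flash_attention_unpad` and the pad is sliced off afterwards, while the softmax scale is still taken from the original `head_size`. The snippet below is a plain-PyTorch sketch of why that round trip is numerically a no-op; `padded_attention` and the shapes are illustrative only, and it assumes a PyTorch version recent enough to expose the `scale=` argument of `scaled_dot_product_attention`.

```python
# Sketch only: no torch_npu, hypothetical helper and shapes.
import torch
import torch.nn.functional as F

MAX_PAD_SIZE = 128


def padded_attention(q, k, v, pad_to=MAX_PAD_SIZE):
    d = q.shape[-1]
    scale = d**-0.5  # scale still comes from the *original* head size
    # Zero-pad the head dim; padded q/k dims contribute 0 to the logits,
    # padded v dims produce zero outputs that are sliced off again.
    qp, kp, vp = (F.pad(t, (0, pad_to - d)) for t in (q, k, v))
    out = F.scaled_dot_product_attention(qp, kp, vp, scale=scale)
    return out[..., :d]


torch.manual_seed(0)
# [batch, heads, seq, head_dim] with head_dim=80 (between 64 and 128)
q, k, v = (torch.randn(2, 16, 40, 80) for _ in range(3))
ref = F.scaled_dot_product_attention(q, k, v, scale=80**-0.5)
assert torch.allclose(padded_attention(q, k, v), ref, atol=1e-5)
```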
diff --git a/vllm_ascend/patch/worker/patch_qwen2_5_vl.py b/vllm_ascend/patch/worker/patch_qwen2_5_vl.py
index 1c2f356bd41..f4edc118a8f 100644
--- a/vllm_ascend/patch/worker/patch_qwen2_5_vl.py
+++ b/vllm_ascend/patch/worker/patch_qwen2_5_vl.py
@@ -18,7 +18,6 @@
 import einops
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 import torch_npu
 from vllm.model_executor.models.qwen2_5_vl import (
     Qwen2_5_VisionAttention, Qwen2_5_VLForConditionalGeneration,
@@ -26,7 +25,6 @@
 from vllm.model_executor.models.qwen2_vl import Qwen2VisionAttention
 from vllm.model_executor.models.vision import run_dp_sharded_mrope_vision_model
 
-import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 
 MIN_PAD_SIZE = 64  # min_size to pad weight
@@ -47,7 +45,6 @@ def forward(
         x, _ = self.qkv(x)
         seq_len, batch_size, _ = x.shape
 
-        # Split q k v.
         qkv = einops.rearrange(
             x,
             "s b (three head head_dim) -> b s three head head_dim",
@@ -55,10 +52,6 @@
             head=self.num_attention_heads_per_partition,
         )
         q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
-        origin_shape = q.shape[-1]
-
-        # Convert cumulative tensor to intervals and move it to cpu.
-        cu_seqlens = torch.diff(cu_seqlens).to("cpu")
 
         cos = torch.cat((rotary_pos_emb_cos, rotary_pos_emb_cos), dim=-1)
         sin = torch.cat((rotary_pos_emb_sin, rotary_pos_emb_sin), dim=-1)
@@ -67,43 +60,14 @@
         q = torch_npu.npu_rotary_mul(q, cos, sin)
         k = torch_npu.npu_rotary_mul(k, cos, sin)
 
-        q, k, v = [
-            einops.rearrange(x, "b s h d -> (b s) h d").contiguous()
-            for x in (q, k, v)
-        ]
-
-        enable_pad = (envs_ascend.USE_OPTIMIZED_MODEL
-                      and self.hidden_size_per_attention_head > MIN_PAD_SIZE
-                      and self.hidden_size_per_attention_head < MAX_PAD_SIZE)
-
-        if enable_pad:
-            pad_len = MAX_PAD_SIZE - origin_shape
-            # q/k/v: [b * s, head, head_dim] -> [b * s, head, MAX_PAD_SIZE]
-            q = F.pad(q, (0, pad_len), mode="constant", value=0)
-            k = F.pad(k, (0, pad_len), mode="constant", value=0)
-            v = F.pad(v, (0, pad_len), mode="constant", value=0)
-
-        context_layer = torch.empty_like(q)
-
-        # operator requires pta version >= 2.5.1
-        torch_npu._npu_flash_attention_unpad(
+        context_layer = self.attn(
             query=q,
             key=k,
             value=v,
-            seq_len=cu_seqlens,
-            scale_value=self.hidden_size_per_attention_head**-0.5,
-            num_heads=self.num_attention_heads_per_partition,
-            num_kv_heads=self.num_attention_heads_per_partition,
-            out=context_layer,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
         )
 
-        if enable_pad:
-            context_layer = context_layer[..., :origin_shape]
-
-        context_layer = einops.rearrange(context_layer,
-                                         "(b s) h d -> s b (h d)",
-                                         b=batch_size).contiguous()
-
         output, _ = self.proj(context_layer)
         return output
 
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 85031bf63a9..5dbcfe763eb 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -666,6 +666,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
                                         AscendReplicatedLinear,
                                         AscendRowParallelLinear)
     from vllm_ascend.ops.mla import AscendMultiHeadLatentAttention
+    from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention
     from vllm_ascend.ops.rotary_embedding import (
         AscendDeepseekScalingRotaryEmbedding, AscendMRotaryEmbedding,
         AscendRotaryEmbedding, AscendYaRNRotaryEmbedding)
@@ -694,6 +695,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
         "FusedMoE": AscendFusedMoE,
         "SharedFusedMoE": AscendSharedFusedMoE,
         "MultiHeadLatentAttentionWrapper": AscendMultiHeadLatentAttention,
+        "MMEncoderAttention": AscendMMEncoderAttention,
     }
 
     for name, op_cls in REGISTERED_ASCEND_OPS.items():
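With `MMEncoderAttention` mapped to `AscendMMEncoderAttention` in `REGISTERED_ASCEND_OPS`, the patched `Qwen2_5_VisionAttention.forward` only builds q/k/v, applies the rotary embedding, and calls `self.attn(...)`; the per-sequence length handling (`torch.diff(cu_seqlens).to("cpu")`) and the `[seq, batch, hidden]` output layout expected by `self.proj` now live in the op. A tiny shape walk-through of that final `einops.rearrange` and of the `cu_seqlens` conversion, with illustrative sizes only:

```python
# Sketch only: illustrative sizes, not the patch itself.
import einops
import torch

b, s, h, d = 2, 3, 4, 8
ctx = torch.randn(b * s, h, d)  # kernel output: [b * s, head, head_dim]
# Fold heads back into a hidden dim and swap to [seq, batch, hidden].
out = einops.rearrange(ctx, "(b s) h d -> s b (h d)", b=b)
print(out.shape)  # torch.Size([3, 2, 32])

# cu_seqlens handling: cumulative offsets -> per-sequence lengths.
cu_seqlens = torch.tensor([0, 4, 10, 16])
print(torch.diff(cu_seqlens))  # tensor([4, 6, 6])
```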