[CustomOp][MM] Extract MMEncoderAttention as CustomOp and replace the backend of QwenVisionAttention with it. #30125
Open: shen-shanshan wants to merge 23 commits into vllm-project:main from shen-shanshan:vit
+1,238 −825
Changes from all commits (23 commits):
- 82648be: extract mm encoder attention as custom op. (shen-shanshan)
- 3b6bf39: fix (shen-shanshan)
- 8676aa8: update (shen-shanshan)
- 628dbfa: address comments (shen-shanshan)
- 943f8fe: Merge remote-tracking branch 'upstream/main' into vit (Isotr0py)
- 492fe5f: fix (Isotr0py)
- d3ed3b7: use vit ops wrapper (Isotr0py)
- 958bcb4: fix tpu (Isotr0py)
- 5c79209: fix assertion (Isotr0py)
- f4dd34f: fix (Isotr0py)
- 95c9548: fix torch compile (Isotr0py)
- 597ca66: update siglip2navit (Isotr0py)
- f3f8ef7: update paddleocr (Isotr0py)
- b1359d8: update glm4.1v (Isotr0py)
- eccc425: update keye (Isotr0py)
- 29c44fb: fix glm4.1v (Isotr0py)
- c561b30: update dots_ocr (Isotr0py)
- ee58ea5: update ernie45_vl (Isotr0py)
- acba2ca: Merge remote-tracking branch 'upstream/main' into vit (Isotr0py)
- 905b322: add unit tests test_vit_backend_functionality.py; fix siglip and qwen… (tjtanaa)
- 3de126b: fix qwen25 omni multimodal config passing bug (tjtanaa)
- 90673ce: fix get_rope (tjtanaa)
- ee70bf0: add qwen2.5 omni to unittest (tjtanaa)
tests/models/multimodal/generation/test_vit_backend_functionality.py: 434 additions, 0 deletions (large diff not rendered by default).
New file, 284 added lines:

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable

import torch

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.ops.vit_attn_wrappers import (
    vit_flash_attn_wrapper,
    vit_torch_sdpa_wrapper,
)
from vllm.config import MultiModalConfig
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.models.vision import get_vit_attn_backend

logger = init_logger(__name__)


def maybe_get_vit_flash_attn_backend(
    attn_backend: AttentionBackendEnum | None,
) -> Callable | None:
    # At this point we already have the attn_backend; the overriding logic is
    # done in the platform-specific implementation, so there is no need to
    # override the backend here. Just return flash_attn_varlen_func.
    if attn_backend == AttentionBackendEnum.FLASH_ATTN:
        from vllm.attention.utils.fa_utils import flash_attn_varlen_func
    elif attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func
    else:
        flash_attn_varlen_func = None

    # If attn_backend is TORCH_SDPA, we reach here and
    # flash_attn_varlen_func stays None.
    return flash_attn_varlen_func


@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
    """Multi-headed attention without any cache, used for multimodal encoder."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float | None = None,
        num_kv_heads: int | None = None,
        prefix: str = "",
        multimodal_config: MultiModalConfig | None = None,
    ) -> None:
        """
        Args:
            num_heads: Number of attention heads per partition.
            head_size: Hidden size per attention head.
            scale: Scale factor.
            num_kv_heads: Number of KV heads.
            prefix: Has no effect; it is only here to make it easier to
                swap between Attention and MultiHeadAttention.
            multimodal_config: Configs for multi-modal.
        """
        super().__init__()

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.layer_name = prefix

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Try to get the vision attention backend from multimodal_config.
        attn_backend_override = None
        if multimodal_config is not None:
            attn_backend_override = multimodal_config.mm_encoder_attn_backend

        # Get the device-specific vision attention backend.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_size,
            dtype=dtype,
            attn_backend_override=attn_backend_override,
        )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

        self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
            self.attn_backend,
        )

        logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

    @classmethod
    def enabled(cls) -> bool:
        return True

    def reshape_qkv_to_4d(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        bsz: int,
        q_len: int,
        kv_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Reshape query, key, value to 4D tensors:
        (batch_size, seq_len, num_heads, head_size)
        """
        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        return query, key, value

    def reshape_qkv_to_3d(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        bsz: int,
        q_len: int,
        kv_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Reshape query, key, value to 3D tensors:
        (batch_size * seq_len, num_heads, head_size)
        """
        query = query.view(bsz * q_len, self.num_heads, self.head_size)
        key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=1)
            value = torch.repeat_interleave(value, num_repeat, dim=1)

        return query, key, value

    def _forward_sdpa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # TODO(Isotr0py): Migrate MultiHeadAttention
        assert cu_seqlens is not None

        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)

        query, key, value = self.reshape_qkv_to_4d(
            query, key, value, bsz, q_len, kv_len
        )

        output = vit_torch_sdpa_wrapper(
            q=query,
            k=key,
            v=value,
            cu_seqlens=cu_seqlens,
        )
        return output

    def _forward_fa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        assert self.flash_attn_varlen_func is not None, (
            "Flash attention function is not set."
        )
        # TODO(Isotr0py): Migrate MultiHeadAttention
        assert cu_seqlens is not None and max_seqlen is not None

        bsz = query.shape[0]

        output = vit_flash_attn_wrapper(
            q=query,
            k=key,
            v=value,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=bsz,
            is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
        )
        return output

    def forward_native(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        if self.is_flash_attn_backend:
            return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            return self._forward_sdpa(query, key, value, cu_seqlens)
        else:
            raise ValueError(
                f"Unsupported multi-modal encoder attention backend for CUDA: "
                f"{self.attn_backend}."
            )

    def forward_cpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_xpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        assert self.is_flash_attn_backend, (
            "XPU only supports FLASH_ATTN for vision attention."
        )
        return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)

    def forward_tpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        assert self.attn_backend == AttentionBackendEnum.PALLAS, (
            f"MMEncoderAttention on TPU only supports the PALLAS backend, "
            f"but got {self.attn_backend}."
        )
        if cu_seqlens is None:
            query, key, value = (x.transpose(1, 2) for x in (query, key, value))
            from torch_xla.experimental.custom_kernel import flash_attention

            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)
            return out
        logger.warning_once(
            "PALLAS backend with cu_seqlens is not supported for ViT yet. "
            "Falling back to SDPA implementation."
        )
        return self._forward_sdpa(query, key, value, cu_seqlens)
```
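
For orientation, here is a minimal sketch of how a vision-encoder attention block (in the spirit of the QwenVisionAttention replacement named in the title) might delegate to the new CustomOp. The import path, the `VisionBlockAttention` wrapper, its projection layers, and the exact tensor layout handed to the op are illustrative assumptions, not code from this diff.

```python
# Illustrative sketch only: import path, wrapper class, and tensor layout
# are assumptions for demonstration, not code changed by this PR.
import torch
from torch import nn

# Assumed location of the CustomOp added by this PR.
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention


class VisionBlockAttention(nn.Module):
    """Hypothetical ViT attention block that delegates to MMEncoderAttention."""

    def __init__(self, embed_dim: int, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        # Backend selection (FLASH_ATTN / ROCM_AITER_FA / TORCH_SDPA / PALLAS)
        # happens inside the CustomOp, per platform.
        self.attn = MMEncoderAttention(
            num_heads=num_heads,
            head_size=self.head_size,
            scale=self.head_size**-0.5,
        )

    def forward(
        self,
        x: torch.Tensor,           # (batch, seq_len, embed_dim)
        cu_seqlens: torch.Tensor,  # cumulative sequence lengths per image
        max_seqlen: torch.Tensor,  # max sequence length, used by flash attention
    ) -> torch.Tensor:
        bsz, seq_len, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Assumed layout: hand (batch, seq, num_heads, head_size) tensors to
        # the op; it reshapes further for the selected backend.
        q, k, v = (
            t.view(bsz, seq_len, self.num_heads, self.head_size)
            for t in (q, k, v)
        )
        # The CustomOp dispatches to forward_cuda / forward_cpu / forward_xpu /
        # forward_tpu / forward_native depending on the current platform.
        out = self.attn(q, k, v, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
        return self.proj(out.reshape(bsz, seq_len, -1))
```

The appeal of this split is that per-platform dispatch lives in one CustomOp instead of being re-implemented in every vision model's attention class.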