Skip to content

Conversation

@shen-shanshan
Copy link
Contributor

@shen-shanshan shen-shanshan commented Dec 5, 2025

Purpose

To avoid maintaining a variety of modeling files in vllm-ascend, we propose to remove all files in models dir in vllm-ascend. After this, the only thing a vllm plugin need to do is just registering their custom device-specific OOT ops to vllm when adding a new model. To achieve this, there are some refactors need to be done both in vllm and vllm-ascend, such as extracting some general layers as CustomOp, find more details at vllm-project/vllm-ascend#4084.

Following #27919 and #27147, this PR has unified the getting logic of vit_attn_backend and extracted MMEncoderAttention as a CustomOp.

To be specific, vision attention backend should only be checked and overwritten in the platform-specific implementation. We should not overwrite this logic in any other places, such as model_executor/models/<model_name>.py. In addition, I have moved scattered forward dispatch logic into this CustomOp to avoid verification for current_platform in any other places.

To minimize the influence, I only replaced the backend of QwenVisionAttention with this CustomOp and have tested this PR both on Ascend A2 NPU and NVIDIA A100 GPU (TODO). I will modify other modeling files and delete the old MultiHeadAttention in the future if this PR could be merged.

Test Plan

Test Result

✅ Ascend A2 NPU

Run Qwen2.5-VL:

vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct \
--max_model_len 16384 \
--max-num-batched-tokens 16384 \
--tensor-parallel-size 2 \
--enforce-eager

Output:

{"id":"chatcmpl-b4e3053f30ab2442","object":"chat.completion","created":1764922950,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the image is \"TONGYI Qwen.\" The word \"TONGYI\" is written in blue, and \"Qwen\" is written in gray. The font appears to be modern and clean, with \"TONGYI\" being slightly larger than \"Qwen.\" The design includes a geometric, abstract shape on the left side of the logo, which complements the text.","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"stop","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":78,"total_tokens":162,"completion_tokens":84,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}

Run Qwen3-VL:

vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct \
--max_model_len 16384 \
--tensor-parallel-size 2 \
--enforce-eager

Output:

{"id":"chatcmpl-97571fbda8267bd1","object":"chat.completion","created":1764923306,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the illustration is **“TONGYI Qwen”**.\n\n### How it looks:\n- **“TONGYI”** is written in **uppercase letters** in a **bold, modern sans-serif font**, colored **blue**.\n- **“Qwen”** is written in **lowercase letters** in a **slightly thinner, elegant sans-serif font**, colored **dark gray**.\n- The two lines of text are stacked vertically, with “TONG","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":112,"total_tokens":212,"completion_tokens":100,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}

NVIDIA A100 GPU

TO BE DONE...


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: shen-shanshan [email protected]
Co-authored-by: Isotr0py [email protected]
Co-authored-by: tjtanaa [email protected]

@mergify mergify bot added qwen Related to Qwen models nvidia rocm Related to AMD ROCm tpu Related to Google TPUs labels Dec 5, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request is a good step towards refactoring the attention mechanisms and making the codebase more modular by introducing MMEncoderAttention as a CustomOp. The unification of the vision attention backend logic is also a welcome improvement.

I've found a critical bug in vllm/attention/layer.py where a variable was renamed but not all its usages were updated, which would cause a runtime error. I've also pointed out an instance of code duplication in the new mm_encoder_attention.py file that should be addressed to improve maintainability.

Once these issues are resolved, this PR will be a solid contribution to the project's architecture.

Copy link

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

self.attn_backend = attn_backend
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,

P1 Badge PaddleOCR vision attention calls helper with outdated signature

maybe_get_vit_flash_attn_backend now only accepts the backend and returns a single function, but the PaddleOCR vision attention still unpacks two return values and passes attn_backend_override. Instantiating this module will now raise TypeError: maybe_get_vit_flash_attn_backend() got an unexpected keyword argument 'attn_backend_override', preventing the model from loading.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

@shen-shanshan shen-shanshan marked this pull request as draft December 5, 2025 09:51
@mergify
Copy link

mergify bot commented Dec 8, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @shen-shanshan.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 8, 2025
@shen-shanshan shen-shanshan marked this pull request as ready for review December 8, 2025 12:11
@mergify mergify bot removed the needs-rebase label Dec 8, 2025
Copy link

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

@shen-shanshan
Copy link
Contributor Author

CC @Isotr0py @tjtanaa @DarkLight1337

Copy link
Member

@Isotr0py Isotr0py left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this effort! I left some initial comments, and will further look into this tomorrow. PTAL!

Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: tjtanaa <[email protected]>

Signed-off-by: shen-shanshan <[email protected]>
Signed-off-by: shen-shanshan <[email protected]>
Signed-off-by: shen-shanshan <[email protected]>
Signed-off-by: shen-shanshan <[email protected]>
@mergify
Copy link

mergify bot commented Dec 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @shen-shanshan.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 10, 2025
@mergify mergify bot removed the needs-rebase label Dec 10, 2025
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
@Isotr0py
Copy link
Member

Isotr0py commented Dec 10, 2025

TODO: Models to check (both FA/SDPA)

  • DotsOCR (dots_ocr.py)
  • Ernie45_vl (ernie45_vl.py)
  • GLM4.1-V (glm4_1v.py)
  • Keye (keye.py)
  • PaddleOCR (paddleocr_vl.py)
  • Qwen2-VL (qwen2_vl.py)
  • Qwen2.5-VL (qwen2_5_vl.py)
  • Qwen3-VL (qwen3_vl.py)
  • Ovis2.5 (siglip2navit.py)

(We really need UT for get_multimodal_embeddings per model, manual validation costs too much efforts 😅 )

@shen-shanshan
Copy link
Contributor Author

TODO: Models to check (both FA/SDPA)

  • DotsOCR (dots_ocr.py)
  • Ernie45_vl (ernie45_vl.py)
  • GLM4.1-V (glm4_1v.py)
  • Keye (keye.py)
  • PaddleOCR (paddleocr_vl.py)
  • Qwen2-VL (qwen2_vl.py)
  • Qwen2.5-VL (qwen2_5_vl.py)
  • Qwen3-VL (qwen3_vl.py)
  • Ovis2.5 (siglip2navit.py)

(We really need UT for get_multimodal_embeddings per model, manual validation costs too much efforts 😅 )

I can help add related UT.

@Isotr0py
Copy link
Member

Thanks, let's add the test in a following PR before we migrate MultiHeadAttention to MMEncoderAttention.

Only several models use Qwen2-style attention and have small variations, so manual validation is still fine for this PR :)

Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

nvidia qwen Related to Qwen models rocm Related to AMD ROCm tpu Related to Google TPUs

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

3 participants