[WIP] support flashinfer_mla#113

Draft
ILikeIneine wants to merge 1 commit into v0.10.2-dev from support-flashinfermla

Conversation

@ILikeIneine
Member

Purpose

This PR adds support for flashinfer_mla on FlashInfer + MetaX.

Test Plan

Test Result

(Optional) Documentation Update


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Signed-off-by: Hank <hcc.mayday@gmail.com>
@ILikeIneine ILikeIneine self-assigned this Oct 27, 2025
@ILikeIneine ILikeIneine marked this pull request as draft October 27, 2025 08:03
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request adds support for the flashinfer_mla attention backend. The changes include adding the selection logic for this new backend in platform.py and implementing the backend itself in a new file vllm_metax/v1/attention/backends/mla/flashinfer_mla.py.

My review has identified several critical issues:

  • The backend selection logic in platform.py has a bug that makes flashinfer_mla unreachable as a default backend.
  • The new flashinfer_mla.py implementation contains several bugs, including incorrect tensor slicing, usage of an undefined attribute, and incorrect parameters being passed to the plan method of the flashinfer wrapper, which would lead to runtime errors and incorrect attention calculations.

I have provided detailed comments and suggestions to fix these issues.

Comment thread vllm_metax/platform.py
Comment on lines +257 to +270
if use_flashinfermla:
    if use_v1:
        from vllm.v1.attention.backends.utils import (
            set_kv_cache_layout)
        set_kv_cache_layout("HND")
        logger.info_once(
            "Using FlashInfer MLA backend on V1 engine.")
        return ("vllm_metax.v1.attention.backends.mla."
                "flashinfer_mla.MacaFlashInferMLABackend")
    else:
        logger.warning(
            "FlashInfer MLA backend is only supported on V1 engine"
        )

critical

This if use_flashinfermla: block is unreachable when selected_backend is None. This is because use_triton on line 232 will be True, and the check for it on line 254 will execute and return first. This prevents flashinfer_mla from being selected as a default backend.

To fix this, the attention backend selection logic should be mutually exclusive. Consider reordering the if statements to check for use_flashinfermla before use_triton, or using an if/elif/else structure for clarity.
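The reordering described above could be sketched roughly as follows. This is a minimal illustration, not the code in platform.py: the string constants stand in for vLLM's `_Backend` enum members, and `pick_mla_backend` and its `prefer_flashinfer` flag are hypothetical names introduced only for this example. Whether FlashInfer MLA should actually be the default is the author's decision.

```python
from typing import Optional

# Hypothetical stand-ins for the _Backend enum members used in platform.py.
TRITON_MLA = "TRITON_MLA"
FLASHINFER_MLA = "FLASHINFER_MLA"


def pick_mla_backend(selected_backend: Optional[str], use_v1: bool,
                     prefer_flashinfer: bool = True) -> str:
    """Pick an MLA backend using mutually exclusive branches."""
    wants_flashinfer = selected_backend == FLASHINFER_MLA or (
        selected_backend is None and prefer_flashinfer)
    wants_triton = selected_backend == TRITON_MLA or selected_backend is None

    # FlashInfer MLA is checked first, so a default (None) selection can reach it
    # instead of being shadowed by the Triton branch.
    if wants_flashinfer and use_v1:
        return FLASHINFER_MLA
    elif wants_triton:
        return TRITON_MLA
    raise ValueError(f"Unsupported MLA backend: {selected_backend}")
```

With this ordering, `pick_mla_backend(None, use_v1=True)` reaches the FlashInfer branch, while a default selection without V1 falls back to Triton.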

Comment on lines +208 to +209
paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + seq_lens_cpu]
paged_kv_len_arr = self.paged_kv_len_arr[:seq_lens_cpu]

critical

Slicing self.paged_kv_indptr_cpu and self.paged_kv_len_arr with seq_lens_cpu (which is a tensor) is incorrect and will raise a TypeError. You should use num_decode_tokens (an integer) for slicing. Additionally, paged_kv_len_arr should be sliced from self.paged_kv_len_arr_cpu, not self.paged_kv_len_arr.

Suggested change
- paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + seq_lens_cpu]
- paged_kv_len_arr = self.paged_kv_len_arr[:seq_lens_cpu]
+ paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + num_decode_tokens]
+ paged_kv_len_arr = self.paged_kv_len_arr_cpu[:num_decode_tokens]
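To illustrate why the slice bound should be an integer count, here is a small pure-Python sketch of a preallocated indptr buffer. It uses plain lists in place of tensors, and the batch contents and `build_paged_kv_indptr` helper are hypothetical. It also assumes each decode request contributes exactly one token, so the number of requests equals `num_decode_tokens`.

```python
from itertools import accumulate


def build_paged_kv_indptr(pages_per_request):
    """Prefix sums over page counts: indptr[i]..indptr[i + 1] spans request i."""
    return [0] + list(accumulate(pages_per_request))


# Buffer preallocated for the largest batch; only a prefix is valid each step.
max_batch = 8
pages_per_request = [3, 1, 2]               # hypothetical decode batch
num_decode_tokens = len(pages_per_request)  # plain int, usable as a slice bound

indptr_buffer = [0] * (max_batch + 1)
indptr_buffer[:1 + num_decode_tokens] = build_paged_kv_indptr(pages_per_request)

# An indptr array always has one more entry than the number of requests.
paged_kv_indptr = indptr_buffer[:1 + num_decode_tokens]
print(paged_kv_indptr)  # [0, 3, 4, 6]
```

The slice length depends only on how many requests are in the batch, not on the per-request sequence lengths, which is why a tensor of sequence lengths is the wrong bound.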

Comment on lines +224 to +225
self.paged_kv_last_page_len_cpu[
    num_decode_tokens:num_input_tokens].fill_(1)

critical

self.paged_kv_last_page_len_cpu is not defined in this class. This appears to be a copy-paste error from another backend implementation and will raise an AttributeError at runtime. This logic seems unnecessary for flashinfer MLA and should be removed.

Comment on lines +238 to +242
head_dim_ckv=self.num_kv_heads,
head_dim_kpe=self.mla_dims.qk_rope_head_dim,
page_size=self.page_size,
causal=False,
sm_scale=1.0, # TODO(Hank) dummy value for testing

critical

There are two issues with the parameters passed to decode_metadata.decode_wrapper.plan:

  1. head_dim_ckv is incorrectly set to self.num_kv_heads (which is 1 for MLA). It should be the dimension of the compressed KV cache, i.e., self.mla_dims.kv_lora_rank.
  2. sm_scale is set to a dummy value of 1.0, which will produce incorrect attention scores. It should be calculated based on the head dimension for the decode path (kv_lora_rank + qk_rope_head_dim).

Suggested change
- head_dim_ckv=self.num_kv_heads,
- head_dim_kpe=self.mla_dims.qk_rope_head_dim,
- page_size=self.page_size,
- causal=False,
- sm_scale=1.0, # TODO(Hank) dummy value for testing
+ head_dim_ckv=self.mla_dims.kv_lora_rank,
+ head_dim_kpe=self.mla_dims.qk_rope_head_dim,
+ page_size=self.page_size,
+ causal=False,
+ sm_scale=1.0 / (self.mla_dims.kv_lora_rank + self.mla_dims.qk_rope_head_dim)**0.5,
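As a worked instance of the scale formula in the suggestion, the sketch below computes `1 / sqrt(head_dim)` for example dimensions. The values 512 and 64 are assumptions chosen for illustration and do not come from this PR. The choice of `kv_lora_rank + qk_rope_head_dim` as the head dimension follows the suggestion above and is worth checking against the scale the model's attention layer is actually configured with.

```python
import math


def softmax_scale(head_dim: int) -> float:
    """Standard scaled dot-product attention factor, 1 / sqrt(head_dim)."""
    return 1.0 / math.sqrt(head_dim)


# Assumed example dimensions, not taken from the PR.
kv_lora_rank = 512
qk_rope_head_dim = 64

scale = softmax_scale(kv_lora_rank + qk_rope_head_dim)  # 1 / sqrt(576) = 1 / 24
print(scale)
```

A scale of 1.0 instead of roughly 0.0417 would make the logits 24 times larger before the softmax under these assumed dimensions, which sharply changes the attention distribution.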

Comment thread vllm_metax/platform.py
Comment on lines +266 to +269
else:
    logger.warning(
        "FlashInfer MLA backend is only supported on V1 engine"
    )

high

When use_flashinfermla is true but use_v1 is false, the current implementation only logs a warning and then falls through to the default backend logic. This can be confusing, as it may trigger a misleading 'Selected MLA backend is not valid' warning. If a user explicitly selects FLASHINFER_MLA with a V0 engine, it should be a hard error. If it's a default selection, it should silently try the next available backend.

Suggested change
- else:
-     logger.warning(
-         "FlashInfer MLA backend is only supported on V1 engine"
-     )
+ else:
+     if selected_backend == _Backend.FLASHINFER_MLA:
+         raise ValueError(
+             "FlashInfer MLA backend is only supported on V1 engine"
+         )
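
The distinction between an explicit selection and a default one might look like the following standalone sketch. The `FLASHINFER_MLA` constant stands in for `_Backend.FLASHINFER_MLA`, and `resolve_flashinfer_mla` is a hypothetical helper, not a function in vLLM or vllm_metax.

```python
from typing import Optional

# Hypothetical stand-in for _Backend.FLASHINFER_MLA.
FLASHINFER_MLA = "FLASHINFER_MLA"


def resolve_flashinfer_mla(selected_backend: Optional[str],
                           use_v1: bool) -> Optional[str]:
    """Return the backend name, raise on an invalid explicit choice,
    or return None so the caller can try the next backend."""
    if use_v1:
        return FLASHINFER_MLA
    if selected_backend == FLASHINFER_MLA:
        # The user asked for this backend explicitly, so fail loudly.
        raise ValueError(
            "FlashInfer MLA backend is only supported on V1 engine")
    # Default selection: fall through silently to the next candidate.
    return None
```

Returning `None` for the default case keeps the fallback quiet, while the explicit case surfaces the misconfiguration immediately instead of producing a later, unrelated warning.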
