[WIP] support flashinfer_mla#113
Conversation
Signed-off-by: Hank <hcc.mayday@gmail.com>
There was a problem hiding this comment.
Code Review
This pull request adds support for the flashinfer_mla attention backend. The changes include adding the selection logic for this new backend in platform.py and implementing the backend itself in a new file vllm_metax/v1/attention/backends/mla/flashinfer_mla.py.
My review has identified several critical issues:
- The backend selection logic in
platform.pyhas a bug that makesflashinfer_mlaunreachable as a default backend. - The new
flashinfer_mla.pyimplementation contains several bugs, including incorrect tensor slicing, usage of an undefined attribute, and incorrect parameters being passed to theplanmethod of the flashinfer wrapper, which would lead to runtime errors and incorrect attention calculations.
I have provided detailed comments and suggestions to fix these issues.
| if use_flashinfermla: | ||
| if use_v1: | ||
| from vllm.v1.attention.backends.utils import ( | ||
| set_kv_cache_layout) | ||
| set_kv_cache_layout("HND") | ||
| logger.info_once( | ||
| "Using FlashInfer MLA backend on V1 engine.") | ||
| return ("vllm_metax.v1.attention.backends.mla." | ||
| "flashinfer_mla.MacaFlashInferMLABackend") | ||
| else: | ||
| logger.warning( | ||
| "FlashInfer MLA backend is only supported on V1 engine" | ||
| ) | ||
|
|
There was a problem hiding this comment.
This if use_flashinfermla: block is unreachable when selected_backend is None. This is because use_triton on line 232 will be True, and the check for it on line 254 will execute and return first. This prevents flashinfer_mla from being selected as a default backend.
To fix this, the attention backend selection logic should be mutually exclusive. Consider reordering the if statements to check for use_flashinfermla before use_triton, or using an if/elif/else structure for clarity.
| paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + seq_lens_cpu] | ||
| paged_kv_len_arr = self.paged_kv_len_arr[:seq_lens_cpu] |
There was a problem hiding this comment.
Slicing self.paged_kv_indptr_cpu and self.paged_kv_len_arr with seq_lens_cpu (which is a tensor) is incorrect and will raise a TypeError. You should use num_decode_tokens (an integer) for slicing. Additionally, paged_kv_len_arr should be sliced from self.paged_kv_len_arr_cpu, not self.paged_kv_len_arr.
| paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + seq_lens_cpu] | |
| paged_kv_len_arr = self.paged_kv_len_arr[:seq_lens_cpu] | |
| paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + num_decode_tokens] | |
| paged_kv_len_arr = self.paged_kv_len_arr_cpu[:num_decode_tokens] |
| self.paged_kv_last_page_len_cpu[ | ||
| num_decode_tokens:num_input_tokens].fill_(1) |
| head_dim_ckv=self.num_kv_heads, | ||
| head_dim_kpe=self.mla_dims.qk_rope_head_dim, | ||
| page_size=self.page_size, | ||
| causal=False, | ||
| sm_scale=1.0, # TODO(Hank) dummy value for testing |
There was a problem hiding this comment.
There are two issues with the parameters passed to decode_metadata.decode_wrapper.plan:
head_dim_ckvis incorrectly set toself.num_kv_heads(which is 1 for MLA). It should be the dimension of the compressed KV cache, i.e.,self.mla_dims.kv_lora_rank.sm_scaleis set to a dummy value of1.0, which will produce incorrect attention scores. It should be calculated based on the head dimension for the decode path (kv_lora_rank + qk_rope_head_dim).
| head_dim_ckv=self.num_kv_heads, | |
| head_dim_kpe=self.mla_dims.qk_rope_head_dim, | |
| page_size=self.page_size, | |
| causal=False, | |
| sm_scale=1.0, # TODO(Hank) dummy value for testing | |
| head_dim_ckv=self.mla_dims.kv_lora_rank, | |
| head_dim_kpe=self.mla_dims.qk_rope_head_dim, | |
| page_size=self.page_size, | |
| causal=False, | |
| sm_scale=1.0 / (self.mla_dims.kv_lora_rank + self.mla_dims.qk_rope_head_dim)**0.5, |
| else: | ||
| logger.warning( | ||
| "FlashInfer MLA backend is only supported on V1 engine" | ||
| ) |
There was a problem hiding this comment.
When use_flashinfermla is true but use_v1 is false, the current implementation only logs a warning and then falls through to the default backend logic. This can be confusing, as it may trigger a misleading 'Selected MLA backend is not valid' warning. If a user explicitly selects FLASHINFER_MLA with a V0 engine, it should be a hard error. If it's a default selection, it should silently try the next available backend.
| else: | |
| logger.warning( | |
| "FlashInfer MLA backend is only supported on V1 engine" | |
| ) | |
| else: | |
| if selected_backend == _Backend.FLASHINFER_MLA: | |
| raise ValueError( | |
| "FlashInfer MLA backend is only supported on V1 engine" | |
| ) |
Purpose
This pr is for supporting flashinfer_mla on flashinfer+metax
Test Plan
Test Result
(Optional) Documentation Update
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.