Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion vllm_metax/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
] or (selected_backend is None and is_flashmla_supported()[0])
use_triton = selected_backend == _Backend.TRITON_MLA or (
selected_backend is None)
use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
selected_backend is None and block_size in [32, 64])

def _get_version(name, import_suffix) -> str:
if use_v1:
Expand All @@ -252,6 +254,20 @@ def _get_version(name, import_suffix) -> str:
if use_triton:
return _get_version("Maca Triton MLA",
"triton_mla.MacaTritonMLABackend")
if use_flashinfermla:
if use_v1:
from vllm.v1.attention.backends.utils import (
set_kv_cache_layout)
set_kv_cache_layout("HND")
logger.info_once(
"Using FlashInfer MLA backend on V1 engine.")
return ("vllm_metax.v1.attention.backends.mla."
"flashinfer_mla.MacaFlashInferMLABackend")
else:
logger.warning(
"FlashInfer MLA backend is only supported on V1 engine"
)
Comment on lines +266 to +269
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

When use_flashinfermla is true but use_v1 is false, the current implementation only logs a warning and then falls through to the default backend logic. This can be confusing, as it may trigger a misleading 'Selected MLA backend is not valid' warning. If a user explicitly selects FLASHINFER_MLA with a V0 engine, it should be a hard error. If it's a default selection, it should silently try the next available backend.

Suggested change
else:
logger.warning(
"FlashInfer MLA backend is only supported on V1 engine"
)
else:
if selected_backend == _Backend.FLASHINFER_MLA:
raise ValueError(
"FlashInfer MLA backend is only supported on V1 engine"
)


Comment on lines +257 to +270
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This if use_flashinfermla: block is unreachable when selected_backend is None. This is because use_triton on line 232 will be True, and the check for it on line 254 will execute and return first. This prevents flashinfer_mla from being selected as a default backend.

To fix this, the attention backend selection logic should be mutually exclusive. Consider reordering the if statements to check for use_flashinfermla before use_triton, or using an if/elif/else structure for clarity.

# default mla
logger.warning(
"Selected MLA backend is not valid, falling back to Triton MLA."
Expand Down Expand Up @@ -398,7 +414,7 @@ def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
if cls.is_device_capability(100):
supported = True
elif fp8_attention and will_use_fa:
from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
from vllm_metax.attention.utils.fa_utils import flash_attn_supports_fp8
supported = flash_attn_supports_fp8()
return supported

Expand Down
313 changes: 313 additions & 0 deletions vllm_metax/v1/attention/backends/mla/flashinfer_mla.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import ClassVar, Optional, Union

from flashinfer import BatchDecodeWithPagedKVCacheWrapper
import torch
from flashinfer.mla import BatchMLAPagedAttentionWrapper

from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm_metax.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonDecodeMetadata, MLACommonImpl,
MLACommonMetadata, MLACommonMetadataBuilder)
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.utils import cdiv, is_pin_memory_available
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
get_kv_cache_layout, get_per_layer_parameters,
infer_global_hyperparameters, split_decodes_and_prefills)

# yapf: enable
from vllm.v1.kv_cache_interface import AttentionSpec

logger = init_logger(__name__)

FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024


class MacaFlashInferMLABackend(MLACommonBackend):

@staticmethod
def get_name() -> str:
return "FLASHINFER_MLA"

@staticmethod
def get_metadata_cls() -> type["FlashInferMLAMetadata"]:
return FlashInferMLAMetadata

@staticmethod
def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
return FlashInferMLAMetadataBuilder

@staticmethod
def get_impl_cls() -> type["FlashInferMLAImpl"]:
return FlashInferMLAImpl


@dataclass
class FlashInferMLADecodeMetadata(MLACommonDecodeMetadata):
decode_wrapper: Optional[BatchMLAPagedAttentionWrapper] = None
qo_indptr_gpu: Optional[torch.Tensor] = None
paged_kv_indptr_gpu: Optional[torch.Tensor] = None


@dataclass
class FlashInferMLAMetadata(MLACommonMetadata[FlashInferMLADecodeMetadata]):
pass


class FlashInferMLAMetadataBuilder(
MLACommonMetadataBuilder[FlashInferMLAMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH

reorder_batch_threshold: int = 1

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.cache_config = vllm_config.cache_config
self.model_config = vllm_config.model_config
self.compilation_config = vllm_config.compilation_config

self._workspace_buffer = None
self._decode_wrapper = None # Wrapper for decode (general shape)

max_num_pages_per_req = cdiv(self.model_config.max_model_len,
self.kv_cache_spec.block_size)
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req

self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\
decode_mode() == CUDAGraphMode.FULL)
if self.enable_cuda_graph:
# For full cudagraph capture, one `decode_wrapper` for each batch
# size is needed for FlashInfer.
self._decode_wrappers_cudagraph: dict[
int, BatchMLAPagedAttentionWrapper] = {}
self._decode_cudagraph_max_bs = min(
max_num_reqs, self.compilation_config.max_capture_size)

self.num_qo_heads = self.model_config.get_num_attention_heads(
vllm_config.parallel_config)
self.num_kv_heads = self.kv_cache_spec.num_kv_heads
self.head_dim = self.kv_cache_spec.head_size
MacaFlashInferMLABackend.validate_head_size(self.head_dim)
self.page_size = self.kv_cache_spec.block_size

self.cache_dtype = self.cache_config.cache_dtype
# Maca do not support fp8 kv cache
assert self.kv_cache_spec.dtype == self.model_config.dtype
self.kv_cache_dtype = self.kv_cache_spec.dtype

self.q_data_type = self.model_config.dtype

# Preparing persistent buffers (device-side)
self.qo_indptr = torch.arange(0,
max_num_reqs + 1,
dtype=torch.int32,
device=self.device)
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device=self.device)
self.paged_kv_indices = torch.zeros(
max_num_pages, # max num pages possible
dtype=torch.int32,
device=self.device)
self.paged_kv_len_arr = torch.zeros(max_num_reqs,
dtype=torch.int32,
device=self.device)

# host-side buffer
pin_memory = is_pin_memory_available()
self.qo_indptr_cpu = torch.arange(0,
max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)
self.paged_kv_indptr_cpu = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)
self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
self.paged_kv_indptr_buffer = torch.zeros_like(
self.paged_kv_indptr_cpu, pin_memory=pin_memory)
self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)
self.paged_kv_len_arr_cpu = torch.zeros(max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)
self.paged_kv_len_arr_np = (self.paged_kv_len_arr_cpu.numpy())

def _get_workspace_buffer(self):
if self._workspace_buffer is None:
self._workspace_buffer = torch.zeros(
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.device)
return self._workspace_buffer

def _get_decode_wrapper(self,
batch_size: int,
use_cudagraph: bool = False):
if use_cudagraph:
decode_wrapper = self._decode_wrappers_cudagraph.get(
batch_size, None)
else:
decode_wrapper = self._decode_wrapper

if decode_wrapper is None:
if use_cudagraph:
paged_qo_indptr = self.qo_indptr[:batch_size + 1]
paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
paged_kv_indices = self.paged_kv_indices
paged_kv_len_arr = self.paged_kv_len_arr[:batch_size]
else:
paged_qo_indptr = None
paged_kv_indptr = None
paged_kv_indices = None
paged_kv_len_arr = None
decode_wrapper = BatchMLAPagedAttentionWrapper(
self._get_workspace_buffer(),
use_cuda_graph=use_cudagraph,
qo_indptr=paged_qo_indptr,
kv_indptr=paged_kv_indptr,
kv_indices=paged_kv_indices,
kv_len_arr=paged_kv_len_arr
# Tensor cores are enabled by default because the perf would be
# at least as good as cuda cores for all attention ops in latest
# gpus.
)

# save the decode wrapper
if use_cudagraph:
self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
else:
self._decode_wrapper = decode_wrapper

return decode_wrapper

def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int) -> FlashInferMLAMetadata:
decode_metadata = FlashInferMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
)

paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + seq_lens_cpu]
paged_kv_len_arr = self.paged_kv_len_arr[:seq_lens_cpu]
Comment on lines +208 to +209
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Slicing self.paged_kv_indptr_cpu and self.paged_kv_len_arr with seq_lens_cpu (which is a tensor) is incorrect and will raise a TypeError. You should use num_decode_tokens (an integer) for slicing. Additionally, paged_kv_len_arr should be sliced from self.paged_kv_len_arr_cpu, not self.paged_kv_len_arr.

Suggested change
paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + seq_lens_cpu]
paged_kv_len_arr = self.paged_kv_len_arr[:seq_lens_cpu]
paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + num_decode_tokens]
paged_kv_len_arr = self.paged_kv_len_arr_cpu[:num_decode_tokens]

use_cudagraph = (self.enable_cuda_graph and num_decode_tokens
<= self._decode_cudagraph_max_bs)
if use_cudagraph:
num_input_tokens = (
self.vllm_config.pad_for_cudagraph(num_decode_tokens))
# Carefully fulfill the padding region with reasonable value
# on cpu.
# Make sure paged_kv_indptr_cpu is not decreasing
self.paged_kv_indptr_cpu[1 + num_decode_tokens:1 +
num_input_tokens].fill_(
paged_kv_indptr_cpu[-1])
# Fill the remaining paged_kv_last_page_len_cpu with 1.
# This is because flashinfer treats 0 as a full page
# instead of empty.
self.paged_kv_last_page_len_cpu[
num_decode_tokens:num_input_tokens].fill_(1)
Comment on lines +224 to +225
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

self.paged_kv_last_page_len_cpu is not defined in this class. This appears to be a copy-paste error from another backend implementation and will raise an AttributeError at runtime. This logic seems unnecessary for flashinfer MLA and should be removed.

else:
num_input_tokens = num_decode_tokens

decode_metadata.decode_wrapper = self._get_decode_wrapper(
num_input_tokens, use_cudagraph=use_cudagraph)

decode_metadata.decode_wrapper.plan(
qo_indptr=self.qo_indptr_cpu,
kv_indptr=self.paged_kv_indptr_cpu[:num_input_tokens + 1],
kv_indices=self.paged_kv_indices,
kv_len_arr=self.paged_kv_len_arr_cpu[:num_input_tokens],
num_heads=self.num_qo_heads,
head_dim_ckv=self.num_kv_heads,
head_dim_kpe=self.mla_dims.qk_rope_head_dim,
page_size=self.page_size,
causal=False,
sm_scale=1.0, # TODO(Hank) dummy value for testing
Comment on lines +238 to +242
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There are two issues with the parameters passed to decode_metadata.decode_wrapper.plan:

  1. head_dim_ckv is incorrectly set to self.num_kv_heads (which is 1 for MLA). It should be the dimension of the compressed KV cache, i.e., self.mla_dims.kv_lora_rank.
  2. sm_scale is set to a dummy value of 1.0, which will produce incorrect attention scores. It should be calculated based on the head dimension for the decode path (kv_lora_rank + qk_rope_head_dim).
Suggested change
head_dim_ckv=self.num_kv_heads,
head_dim_kpe=self.mla_dims.qk_rope_head_dim,
page_size=self.page_size,
causal=False,
sm_scale=1.0, # TODO(Hank) dummy value for testing
head_dim_ckv=self.mla_dims.kv_lora_rank,
head_dim_kpe=self.mla_dims.qk_rope_head_dim,
page_size=self.page_size,
causal=False,
sm_scale=1.0 / (self.mla_dims.kv_lora_rank + self.mla_dims.qk_rope_head_dim)**0.5,

q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype)


class FlashInferMLAImpl(MLACommonImpl[FlashInferMLAMetadata]):

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)

unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"FlashInferMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap")

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferMLAImpl")

if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashInferMLA V1 with FP8 KV cache not yet supported")

def _forward_decode(
self,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashInferMLAMetadata,
layer: AttentionLayer,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None

if isinstance(q, tuple):
q_nope, q_pe = q
q = torch.cat([q_nope, q_pe], dim=-1)

# Initialize the MLA wrapper
mla_wrapper = attn_metadata.decode.decode_wrapper
head_dim_ckv = q_nope.shape[-1]
head_dim_kpe = q_pe.shape[-1]

# Run the MLA computation
o = mla_wrapper.run(
q_nope=q_nope,
q_pe=q_pe,
ckv_cache=kv_c_and_k_pe_cache[:, :, :head_dim_ckv],
kpe_cache=kv_c_and_k_pe_cache[:, :, head_dim_ckv:head_dim_ckv +
head_dim_kpe],
return_lse=False)

# Return the output tensor and None for LSE (pending support)
return o, None