logger = logging.getLogger(__name__)


def _override_flashmla_sparse_backend():
    """Force the FLASHMLA_SPARSE / FLASHMLA attention backends to resolve to
    the vllm_fl custom implementations, overriding vLLM's native ones.

    Runs once at module import time.  Any failure (vLLM missing, registry API
    changed, ...) is logged and swallowed so that importing vllm_fl never
    breaks.
    """
    try:
        from vllm.attention.backends.registry import (
            AttentionBackendEnum,
            register_backend,
        )

        # Re-register FLASHMLA_SPARSE, overriding vLLM's default native
        # implementation.
        register_backend(
            AttentionBackendEnum.FLASHMLA_SPARSE,
            class_path="vllm_fl.v1.attention.backends.mla.flashmla_sparse.MacaFlashMLASparseBackend",
        )

        # Also register FLASHMLA.
        register_backend(
            AttentionBackendEnum.FLASHMLA,
            class_path="vllm_fl.v1.attention.backends.mla.flashmla.MacaFlashMLABackend",
        )

        # Fixed: report through the module logger instead of bare print(),
        # consistent with the rest of this module.
        logger.info(
            "[vllm_fl] Successfully overridden FLASHMLA_SPARSE backend "
            "to use vllm_fl implementation"
        )
    except Exception as e:
        logger.warning("[vllm_fl] Failed to override backend: %s", e)


# Executed immediately when the module is imported.
_override_flashmla_sparse_backend()


########### platform plugin ###########
def register():
    """Register the FL platform (vLLM platform-plugin entry point).

    Returns:
        The dotted path of the platform class for vLLM to load.
    """
    _patch_transformers_compat()

    # Model-specific platform patches.
    from vllm_fl.patches.glm_moe_dsa import apply_platform_patches as glm5_platform
    glm5_platform()

    # Default the worker start method to "spawn" unless the user set it.
    multiproc_method = os.environ.get("VLLM_WORKER_MULTIPROC_METHOD")
    if multiproc_method is None:
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
    _get_op_config()

    # Check if we're running on MetaX (Maca) platform.
    # If so, also register MetaX-specific components.
    if _is_metax_platform():
        logger.info("MetaX platform detected, registering MetaX-specific components")
        _register_metax_components()

    return "vllm_fl.device_platform.PlatformFL"


def _is_metax_platform() -> bool:
    """Detect if running on MetaX (Maca) platform.

    Returns False when detection is impossible (no device-info helper and no
    usable CUDA device).
    """
    try:
        # Preferred: vendor name reported by the device-info helper.
        from vllm_fl.utils import DeviceInfo

        device_info = DeviceInfo()
        return device_info.vendor_name == "metax"
    except Exception:
        # Fallback: check the CUDA device name for vendor markers.
        try:
            import torch

            device_name = torch.cuda.get_device_name(0).lower()
            return "metax" in device_name or "maca" in device_name
        except Exception:
            return False


def _register_metax_components():
    """
    Register MetaX-specific components (ops, models, quantization configs).

    Fixed: the original imported ``vllm_fl`` and called
    ``vllm_fl.register_ops()`` / ``register_model()`` /
    ``register_quant_configs()`` — but this module *is* ``vllm_fl``, so on
    MetaX each call re-entered the very function that invokes this helper,
    recursing until a RecursionError was swallowed by the broad ``except``.
    The module's own functions are now called directly, and the
    self-recursive tails have been removed from them.
    """
    try:
        try:
            register_ops()
            logger.info("Registered MetaX ops")
        except Exception as e:
            logger.warning(f"Failed to register MetaX ops: {e}")

        try:
            register_model()
            logger.info("Registered MetaX models")
        except Exception as e:
            logger.warning(f"Failed to register MetaX models: {e}")

        try:
            register_quant_configs()
            logger.info("Registered MetaX quantization configs")
        except Exception as e:
            logger.warning(f"Failed to register MetaX quant configs: {e}")

    except Exception as e:
        logger.warning(f"Error registering MetaX components: {e}")


def register_ops():
    """Register FL ops (importing vllm_fl.ops applies the op patches)."""
    import vllm_fl.ops  # noqa: F401

    # Fixed: the original additionally did ``import vllm_fl;
    # vllm_fl.register_ops()`` when on MetaX — a call to *this very
    # function*, recursing until RecursionError.  MetaX-specific
    # registration is driven by _register_metax_components() instead.


def register_quant_configs():
    """Register quantization configs."""
    # FL-specific quant configs (if any) can be added here.

    # Fixed: the original called ``vllm_fl.register_quant_configs()`` (this
    # very function) when on MetaX, recursing until RecursionError; the
    # self-call has been removed (see _register_metax_components).


# Backward compatibility: collect_env function
def collect_env() -> None:
    """Collect environment information (best effort; a quiet no-op when the
    vllm_fl.collect_env module is unavailable)."""
    try:
        from vllm_fl.collect_env import main as collect_env_main

        collect_env_main()
    except ImportError:
        logger.debug("vllm_fl.collect_env not available")


# ===== new file: vllm_fl/_custom_ops.py =====
# SPDX-License-Identifier: Apache-2.0
# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# ---------------------------------------------------
# Note:
#
# Here we only maintain the custom ops that are:
#
# - modified
# - newly added
#
# in vllm_metax compared to vllm.
#
# When *adding* new custom ops, make sure you checked the
# latest vllm/_custom_ops.py first to avoid adding duplicates.
# ---------------------------------------------------
import torch
import vllm.envs as envs


def awq_gemm(
    input: torch.Tensor,
    qweight: torch.Tensor,
    qzeros: torch.Tensor,
    scales: torch.Tensor,
    split_k_iters: int,
    temp_space: torch.Tensor,
    dtype_bf16: bool,
) -> torch.Tensor:
    """AWQ GEMM: dispatches to the Triton kernel when VLLM_USE_TRITON_AWQ is
    set, otherwise to the custom _C kernel.

    NOTE(review): both callees are passed (input, qweight, scales, qzeros,
    ...) — i.e. scales before qzeros, the reverse of this wrapper's parameter
    order.  Presumably intentional (matching the kernels' signatures);
    verify against the kernel definitions.
    """
    if envs.VLLM_USE_TRITON_AWQ:
        from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton

        return awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters)
    return torch.ops._C.awq_gemm(
        input, qweight, scales, qzeros, split_k_iters, temp_space, dtype_bf16
    )


# awq to gptq 4bit conversion
def awq_to_gptq_4bit(qweight: torch.Tensor) -> torch.Tensor:
    """Convert AWQ-packed 4-bit weights to GPTQ packing via the _C kernel.

    Returns the input unchanged when the Triton AWQ path is active (no
    repacking is performed there).
    """
    if envs.VLLM_USE_TRITON_AWQ:
        return qweight
    return torch.ops._C.awq_to_gptq_4bit(qweight)


# gptq
def gptq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_gptq_qzeros: torch.Tensor,
    b_gptq_scales: torch.Tensor,
    b_g_idx: torch.Tensor,
    use_exllama: bool,
    bit: int,
    group_size: int,
    perm_space: torch.Tensor,
    temp_space: torch.Tensor,
    dtype_bf16: bool,
) -> torch.Tensor:
    """GPTQ GEMM: thin pass-through wrapper around the custom _C kernel."""
    return torch.ops._C.gptq_gemm(
        a,
        b_q_weight,
        b_gptq_qzeros,
        b_gptq_scales,
        b_g_idx,
        use_exllama,
        bit,
        group_size,
        perm_space,
        temp_space,
        dtype_bf16,
    )


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
    """Shuffle GPTQ weights according to q_perm (custom _C kernel; in place)."""
    torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)


def fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    tileConfig: int,
) -> None:
    """Fused MoE GEMM (custom _moe_C kernel); the result is written into C."""
    torch.ops._moe_C.fused_moe_kernel(
        A,
        B,
        C,
        topk_weights,
        topk_ids,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        mul_routed_weight,
        top_k,
        tileConfig,
    )


# dsv32
def indexer_k_quant_and_cache(
    k: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    quant_block_size: int,
    kv_cache_dtype: str,
) -> None:
    """Write indexer K vectors into the paged cache.

    For bf16/fp16 inputs the unquantized cache kernel is used and
    quant_block_size / kv_cache_dtype are ignored; otherwise the quantizing
    kernel is invoked.
    """
    if k.dtype in (torch.bfloat16, torch.float16):
        torch.ops._C_cache_ops.indexer_k_cache(k, kv_cache, slot_mapping)
    else:
        torch.ops._C_cache_ops.indexer_k_quant_and_cache(
            k, kv_cache, slot_mapping, quant_block_size, kv_cache_dtype
        )


def cp_gather_indexer_k_quant_cache(
    kv_cache: torch.Tensor,
    dst_k: torch.Tensor,
    dst_scale: torch.Tensor,
    block_table: torch.Tensor,
    cu_seq_lens: torch.Tensor,
) -> None:
    """Gather indexer K entries from the paged cache into dst_k.

    The unquantized kernel is chosen when dst_k is bf16/fp16 or when no
    dst_scale buffer is provided; the quantized kernel fills dst_scale too.
    """
    if dst_k.dtype in (torch.bfloat16, torch.float16) or dst_scale is None:
        torch.ops._C_cache_ops.cp_gather_indexer_k_cache(
            kv_cache, dst_k, block_table, cu_seq_lens
        )
    else:
        torch.ops._C_cache_ops.cp_gather_indexer_k_quant_cache(
            kv_cache, dst_k, dst_scale, block_table, cu_seq_lens
        )


def top_k_per_row(
    logits: torch.Tensor,
    row_starts: torch.Tensor,
    row_ends: torch.Tensor,
    topk_indices: torch.Tensor,
    num_rows: int,
) -> None:
    """Per-row top-k over logits between row_starts/row_ends; selected
    indices are written into topk_indices (custom _C kernel)."""
    torch.ops._C.top_k_per_row(
        logits,
        row_starts,
        row_ends,
        topk_indices,
        num_rows,
        logits.stride(0),
        logits.stride(1),
    )


def top_k_per_row_decode(
    logits: torch.Tensor,
    next_n: int,
    seq_lens: torch.Tensor,
    topk_indices: torch.Tensor,
    num_rows: int,
) -> None:
    """Decode-phase variant of top_k_per_row driven by seq_lens and the
    speculative depth next_n (custom _C kernel)."""
    torch.ops._C.top_k_per_row_decode(
        logits,
        next_n,
        seq_lens,
        topk_indices,
        num_rows,
        logits.stride(0),
        logits.stride(1),
    )


# ===== new file: vllm_fl/attention/fl_utils.py =====
from vllm.logger import init_logger

logger = init_logger(__name__)


def patch_mm_encoder_attention():
    """
    Patch vllm.attention.layers.mm_encoder_attention.maybe_get_vit_flash_attn_backend
    to support OOT platforms.

    The original implementation imports flash_attn_varlen_func from fa_utils,
    which may not have it defined for OOT platforms. This patch changes the
    FLASH_ATTN branch to import directly from vllm.vllm_flash_attn with a
    fallback to flash_attn.
    """
    import vllm.attention.layers.mm_encoder_attention as mm_mod
    from vllm.attention.backends.registry import AttentionBackendEnum

    def _patched_maybe_get_vit_flash_attn_backend(attn_backend):
        # FLASH_ATTN: prefer the vllm-vendored kernel, fall back to the
        # upstream flash_attn package.
        if attn_backend == AttentionBackendEnum.FLASH_ATTN:
            try:
                from vllm.vllm_flash_attn import flash_attn_varlen_func

                logger.info_once("Using vllm.vllm_flash_attn for vit attention")
            except (ImportError, ModuleNotFoundError):
                from flash_attn import flash_attn_varlen_func

                logger.info_once("Using flash_attn for vit attention")
            return flash_attn_varlen_func
        elif attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
            from aiter import flash_attn_varlen_func

            return flash_attn_varlen_func
        else:
            # Unsupported backend: no varlen flash-attn function available.
            return None

    mm_mod.maybe_get_vit_flash_attn_backend = _patched_maybe_get_vit_flash_attn_backend


# ===== new file: vllm_fl/attention/ops/__init__.py (empty) =====
# ===== new file: vllm_fl/attention/ops/flashmla.py =====
# SPDX-License-Identifier: Apache-2.0
# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# /------------------------ Metax Modification -------------------------\
# Probe for the vendor flash_mla package only on out-of-tree platforms.
if current_platform.is_out_of_tree():
    try:
        import flash_mla  # noqa: F401

        _flashmla_AVAILABLE = True
    except ImportError:
        _flashmla_AVAILABLE = False
else:
    _flashmla_AVAILABLE = False
# \------------------------ Metax Modification -------------------------/


def _is_flashmla_available() -> tuple[bool, str | None]:
    """
    Return: is_supported_flag, unsupported_reason (optional).
    """
    if not _flashmla_AVAILABLE:
        return False, "flash_mla is not available"
    return True, None


def is_flashmla_dense_supported() -> tuple[bool, str | None]:
    """
    Return: is_supported_flag, unsupported_reason (optional).
    """
    # Fixed: local-variable typo "is_availble" -> "is_available"
    # (matching is_flashmla_sparse_supported below).
    is_available, maybe_reason = _is_flashmla_available()
    if not is_available:
        return False, maybe_reason
    return True, None


def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
    """
    Return: is_supported_flag, unsupported_reason (optional).
    """
    is_available, maybe_reason = _is_flashmla_available()
    if not is_available:
        return False, maybe_reason
    return True, None


def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_q_tokens_per_head_k: int,
    num_heads_k: int,
    num_heads_q: int | None = None,
    is_fp8_kvcache: bool = False,
    topk: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
    - cache_seqlens: (batch_size), dtype torch.int32.
    - num_q_tokens_per_head_k:
        Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
    - num_heads_k: The number of k heads.
    - num_heads_q:
        The number of q heads.
        This argument is optional when sparse attention is not enabled
    - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
    - topk: If not None, sparse attention will be enabled,
        and only tokens in the `indices` array
        passed to `flash_mla_with_kvcache_sm90` will be attended to.

    Returns:
    - tile_scheduler_metadata:
        (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
    - num_splits: (batch_size + 1), dtype torch.int32.

    NOTE(review): num_heads_q, is_fp8_kvcache and topk are accepted for
    signature parity with upstream but are NOT forwarded to the MetaX
    kernel below — confirm the vendor kernel does not need them.  Also,
    when the availability probe above failed, ``flash_mla`` is unbound and
    this raises NameError; callers must consult is_flashmla_*_supported()
    first.
    """
    # /------------------------ Metax Modification -------------------------\
    return flash_mla.flash_mla_interface.get_mla_metadata(
        cache_seqlens, num_q_tokens_per_head_k, num_heads_k
    )
    # \------------------------- Metax Modification -------------------------/


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: float | None = None,
    causal: bool = False,
    descale_q: torch.Tensor | None = None,
    descale_k: torch.Tensor | None = None,
    is_fp8_kvcache: bool = False,
    indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
    - q: (batch_size, seq_len_q, num_heads_q, head_dim).
    - k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
    - block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
    - cache_seqlens: (batch_size), torch.int32.
    - head_dim_v: Head dimension of v.
    - tile_scheduler_metadata:
        (num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
        returned by get_mla_metadata.
    - num_splits:
        (batch_size + 1), torch.int32, returned by get_mla_metadata.
    - softmax_scale: float.
        The scale of QK^T before applying softmax.
        Default to 1 / sqrt(head_dim).
    - causal: bool. Whether to apply causal attention mask.
    - descale_q: (batch_size),
        torch.float32. Descaling factors for Q, used for fp8 quantization.
    - descale_k: (batch_size),
        torch.float32. Descaling factors for K, used for fp8 quantization.
    - is_fp8_kvcache: bool.
        Whether the k_cache and v_cache are in fp8 format.
        For the format of FP8 KV cache, please refer to README.md
    - indices: (batch_size, seq_len_q, topk), torch.int32.
        If not None, sparse attention will be enabled,
        and only tokens in the `indices` array will be attended to.
        Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
        For details about how to set up `indices`, please refer to README.md.

    Returns:
    - out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
    - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

    NOTE(review): in the MetaX branch below, indices / descale_q /
    descale_k / is_fp8_kvcache are validated but NOT forwarded to the
    vendor kernel — confirm the MetaX flash_mla_with_kvcache handles (or
    does not need) them.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if indices is not None:
        # NOTE (zyongye): sparse attention is also causal
        # since it only attend to the tokens before
        # but here `causal` should not be specified
        assert not causal, "causal must be `false` if sparse attention is enabled."
    assert (descale_q is None) == (descale_k is None), (
        "descale_q and descale_k should be both None or both not None"
    )

    # /------------------------ Metax Modification -------------------------\
    if indices is None and q.element_size() == 1:
        raise NotImplementedError("flash_mla_with_kvcache does not support fp8 input. ")
    else:
        out, softmax_lse = flash_mla.flash_mla_interface.flash_mla_with_kvcache(
            q,
            k_cache,
            block_table,
            cache_seqlens,
            head_dim_v,
            tile_scheduler_metadata,
            num_splits,
            softmax_scale,
            causal,
        )
    # \------------------------- Metax Modification -------------------------/
    # Note(hc): need revisit when we support DCP with decode query_len > 1.
    return out, softmax_lse


# Metax: torch_ref
def torch_flash_mla_sparse_prefill(
    q: torch.Tensor, kv: torch.Tensor, indices: torch.Tensor, sm_scale: float
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pure-PyTorch reference for sparse MLA prefill (single batch, b == 1).

    Args:
    - q: [s_q, h_q, d_qk]
    - kv: [s_kv, h_kv, d_qk]; only head 0 of kv is used
    - indices: [s_q, h_kv, topk], int32; only head 0 is used; entries < 0 or
      >= s_kv are treated as invalid and masked out
    - sm_scale: softmax scale applied to QK^T

    Returns (output bf16 [s_q, h_q, d_v], max_logits [s_q, h_q],
    lse [s_q, h_q]) where lse is the 2-based log-sum-exp, matching
    flash_mla_sparse_prefill below.  Values are the first 512 channels of
    kv (the d_v == 512 convention).
    """
    import math

    def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
        # log2-sum-exp2 computed via the natural logsumexp with base changes.
        return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)

    assert len(q.shape) == len(kv.shape) == 3  # b == 1
    s_q, _, d_qk = q.shape
    s_kv, _, _ = kv.shape

    indices = indices[:, 0, :]  # [s_q, topk]
    invalid_indices_mask = (indices < 0) | (indices >= s_kv)
    qs = q[:, :, :].float()  # [s_q, h_q, d_qk]
    kvs = kv[:, 0, :].float()  # [s_kv, d_qk]

    _, topk = indices.shape

    # Gather the selected KV rows; invalid slots point at row 0 and are
    # masked to -inf below so they contribute nothing.
    kvs = torch.index_select(
        kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten()
    ).view(s_q, topk, d_qk)  # [s_q, topk, d_qk]
    attn_score = qs @ kvs.transpose(1, 2)  # [s_q, h_q, topk]
    attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float("-inf"))
    attn_score *= sm_scale * math.log2(math.e)
    max_logits = torch.max(attn_score, dim=-1)[0]  # [s_q, h_q]
    lse = log2sumexp2(attn_score, dim=-1)  # [s_q, h_q]
    attn_score = torch.exp2(attn_score - lse.unsqueeze(-1))  # [s_q, h_q, topk]
    result = attn_score @ kvs[:, :, :512]

    return (result.to(torch.bfloat16), max_logits, lse)


def flash_mla_sparse_prefill(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sparse attention prefill kernel

    Args:
    - q: [s_q, h_q, d_qk], bfloat16
    - kv: [s_kv, h_kv, d_qk], bfloat16
    - indices: [s_q, h_kv, topk], int32.
        Invalid indices should be set to -1 or numbers >= s_kv
    - sm_scale: float
    - d_v: The dimension of value vectors. Can only be 512

    Returns:
    - (output, max_logits, lse)
        About the definition of output,
        max_logits and lse, please refer to README.md
    - output: [s_q, h_q, d_v], bfloat16
    - max_logits: [s_q, h_q], float
    - lse: [s_q, h_q], float, 2-based log-sum-exp
    """
    # TODO: MetaX flash_mla support
    # /------------------------ Metax Modification -------------------------\
    # The vendor kernel takes a fast path when every index is valid.
    is_all_indices_valid = not (indices == -1).any()

    results = flash_mla.flash_mla_interface.flash_mla_sparse_fwd(
        q, kv, indices, sm_scale, d_v, is_all_indices_valid
    )
    # \------------------------- Metax Modification -------------------------/
    return results


#
# TODO: Add fake functions
#
# @register_fake("_flashmla_C::get_mla_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#


# ===== new file: vllm_fl/attention/ops/merge_attn_states.py =====
# SPDX-License-Identifier: Apache-2.0
# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.platforms import current_platform


def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: torch.Tensor | None = None,
) -> None:
    """Merge prefix/suffix partial attention results into ``output`` in
    place, dispatching to the custom CUDA kernel when supported and to the
    Triton implementation otherwise."""
    # NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel
    # is not support for FP8 dtype, fallback to use Triton kernel.
    def supported_dtypes(o: torch.Tensor) -> bool:
        return o.dtype in [torch.float32, torch.half, torch.bfloat16]

    # NOTE(DefTruth): Currently, custom merge_attn_states CUDA
    # kernel load/store 128b(16 bytes) per memory issue within
    # thread. Namely, the headsize(headdim) must be multiple of
    # pack_size (float32 -> 4, half/bfloat16 -> 8).
    def supported_headdim(o: torch.Tensor) -> bool:
        headdim = o.shape[2]  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
        if o.dtype == torch.float32:
            return headdim % 4 == 0
        return headdim % 8 == 0

    # /------------------------ Metax Modification -------------------------\
    if (
        current_platform.is_out_of_tree()
        and supported_dtypes(output)
        and supported_headdim(output)
    ):
        # \------------------------ Metax Modification -------------------------/
        from vllm._custom_ops import merge_attn_states

        return merge_attn_states(
            output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse
        )
    else:
        from vllm.attention.ops.triton_merge_attn_states import merge_attn_states

        return merge_attn_states(
            output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse
        )


# ===== new file: vllm_fl/attention/ops/triton_decode_attention.py =====
# SPDX-License-Identifier: Apache-2.0
# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py

# Changes:
# - Add support for page size >= 1.

# Copyright 2025 vLLM Team
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Memory-efficient attention for decoding.
It supports page size >= 1.
"""

import logging

from packaging import version

from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton

is_hip_ = current_platform.is_rocm()
# /------------------------ Metax Modification -------------------------\
# NOTE(review): unconditionally True — this copy of the kernel module
# appears to be shipped only in the MetaX build; confirm before reusing
# it on other platforms.
is_maca_ = True
# \------------------------ Metax Modification -------------------------/

logger = logging.getLogger(__name__)

# Only print the following warnings when triton version < 3.2.0.
# The issue won't affect performance or accuracy.
+if version.parse(triton.__version__) < version.parse("3.2.0"): + logger.warning( + "The following error message 'operation scheduled before its operands' " + "can be ignored." + ) + + +@triton.jit +def tanh(x): + # Tanh is just a scaled sigmoid + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def _fwd_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + Att_Out, + stride_req_to_tokens_b, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + split_kv_id = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = cur_batch + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + q = tl.load(Q + off_q, mask=mask_d, other=0.0) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_buf_k = ( + kv_loc[:, None] * stride_buf_kbs + + cur_kv_head * 
stride_buf_kh + + offs_d[None, :] + ) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]), + other=0.0, + ) + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(offs_n < split_kv_end, qk, float("-inf")) + + offs_buf_v = ( + kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) + + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max + + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv + ) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum, + mask=(mask_dv), + ) + + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + Lv + ) + + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + ) + + +def _decode_att_m_fwd( + q, + k_buffer, + v_buffer, + att_out, + Req_to_tokens, + B_Seqlen, + num_kv_splits, + sm_scale, + page_size, + logit_cap, +): + # /------------------------ Metax Modification -------------------------\ + BLOCK = 64 if not is_maca_ else 8 + # \------------------------- Metax Modification -------------------------/ + + NUM_KV_SPLITS = num_kv_splits + Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + batch, head_num = q.shape[0], q.shape[1] + + grid = (batch, head_num, NUM_KV_SPLITS) + kv_group_num = q.shape[1] // k_buffer.shape[-2] + + num_warps = 4 + if kv_group_num != 1: + # /------------------------ Metax Modification -------------------------\ + num_warps = 1 if is_maca_ else 2 + # \------------------------- Metax Modification -------------------------/ + + BLOCK_DMODEL = 
triton.next_power_of_2(Lk) + BLOCK_DV = triton.next_power_of_2(Lv) + + _fwd_kernel_stage1[grid]( + q, + k_buffer, + v_buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + att_out, + Req_to_tokens.stride(0), + q.stride(0), + q.stride(1), + k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK, + NUM_KV_SPLITS=NUM_KV_SPLITS, + PAGE_SIZE=page_size, + logit_cap=logit_cap, + num_warps=num_warps, + num_stages=2, + Lk=Lk, + Lv=Lv, + ) + + +@triton.jit +def _fwd_grouped_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + Att_Out, + stride_req_to_tokens_b, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + q_head_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head_id = tl.program_id(1) + cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) + split_kv_id = tl.program_id(2) + + if kv_group_num > BLOCK_H: + VALID_BLOCK_H: tl.constexpr = BLOCK_H + else: + VALID_BLOCK_H: tl.constexpr = kv_group_num + cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H + mask_h = mask_h & (cur_head < q_head_num) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = 
offs_dv < Lv + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = cur_batch + + offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] + q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + mask_dpe = offs_dpe < Lk + off_qpe = ( + cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] + ) + qpe = tl.load( + Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0 + ) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_buf_k = ( + kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[:, None] + ) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), + other=0.0, + ) + qk = tl.dot(q, k.to(q.dtype)) + if BLOCK_DPE > 0: + offs_buf_kpe = ( + kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None] + ) + kpe = tl.load( + K_Buffer + offs_buf_kpe, + mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), + other=0.0, + ) + qk += tl.dot(qpe, kpe.to(qpe.dtype)) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where( + mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, 
float("-inf") + ) + + offs_buf_v = ( + kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v.dtype), v) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head[:, None] * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv[None, :] + ) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum[:, None], + mask=(mask_h[:, None]) & (mask_dv[None, :]), + ) + + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + Lv + ) + + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + mask=mask_h, + ) + + +def _decode_grouped_att_m_fwd( + q, + k_buffer, + v_buffer, + att_out, + Req_to_tokens, + B_Seqlen, + num_kv_splits, + sm_scale, + page_size, + logit_cap, +): + BLOCK = 32 + Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + # [TODO] work around shmem limit on MI3xx + # /------------------------- Metax Modification -------------------------\ + if is_maca_: + BLOCK = 16 + # \------------------------- Metax Modification -------------------------/ + + if Lk == 576: + BLOCK_DMODEL = 512 + BLOCK_DPE = 64 + elif Lk == 288: + BLOCK_DMODEL = 256 + BLOCK_DPE = 32 + else: + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DPE = 0 + BLOCK_DV = triton.next_power_of_2(Lv) + + batch, head_num = q.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k_buffer.shape[-2] + + BLOCK_H = 16 + NUM_KV_SPLITS = num_kv_splits + grid = ( + batch, + triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), + NUM_KV_SPLITS, + ) + + extra_kargs = {} + num_stages = 2 + if is_maca_: + # 
https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization + # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py + extra_kargs = {"scenario": "mla"} + num_stages = 1 + + _fwd_grouped_kernel_stage1[grid]( + q, + k_buffer, + v_buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + att_out, + Req_to_tokens.stride(0), + q.stride(0), + q.stride(1), + k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + q_head_num=head_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DPE=BLOCK_DPE, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK, + BLOCK_H=BLOCK_H, + NUM_KV_SPLITS=NUM_KV_SPLITS, + PAGE_SIZE=page_size, + logit_cap=logit_cap, + num_warps=4, + num_stages=num_stages, + Lk=Lk, + Lv=Lv, + **extra_kargs, + ) + + +@triton.jit +def _fwd_kernel_stage2( + Mid_O, + o, + lse, + B_Seqlen, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + stride_lse_bs, + NUM_KV_SPLITS: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, 
cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load( + Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 + ) + tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + o + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + lse_val = e_max + tl.log(e_sum) + tl.store( + lse + cur_batch * stride_lse_bs + cur_head, + lse_val, + ) + + +def _decode_softmax_reducev_fwd( + logits, + q, + o, + lse, + v_buffer, + b_seq_len, + num_kv_splits, +): + batch, head_num = q.shape[0], q.shape[1] + Lv = v_buffer.shape[-1] + BLOCK_DV = triton.next_power_of_2(Lv) + + NUM_KV_SPLITS = num_kv_splits + + extra_kargs = {} + if is_hip_: + # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html + # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py + extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} + + grid = (batch, head_num) + _fwd_kernel_stage2[grid]( + logits, + o, + lse, + b_seq_len, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + lse.stride(0), + NUM_KV_SPLITS=NUM_KV_SPLITS, + BLOCK_DV=BLOCK_DV, + Lv=Lv, + num_warps=4, + num_stages=2, + **extra_kargs, + ) + + +def decode_attention_fwd_normal( + q, + k_buffer, + v_buffer, + o, + lse, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap=0.0, +): + _decode_att_m_fwd( + q, + k_buffer, + v_buffer, + attn_logits, + req_to_token, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) + _decode_softmax_reducev_fwd( + attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits + ) + + +def 
def decode_attention_fwd_grouped(
    q,
    k_buffer,
    v_buffer,
    o,
    lse,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    """GQA/MQA/MLA decode path: grouped stage-1 kernel, then stage-2 reduce."""
    _decode_grouped_att_m_fwd(
        q,
        k_buffer,
        v_buffer,
        attn_logits,
        req_to_token,
        b_seq_len,
        num_kv_splits,
        sm_scale,
        page_size,
        logit_cap,
    )
    _decode_softmax_reducev_fwd(
        attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
    )


def decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    lse,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size=1,
    logit_cap=0.0,
):
    """Dispatch decode attention to the MHA or grouped implementation.

    ``attn_logits`` must have ``num_kv_splits`` slots along dim 2; the
    query/KV head ratio decides between the per-head (MHA) and grouped
    (GQA/MQA/MLA) stage-1 kernels.
    """
    assert num_kv_splits == attn_logits.shape[2]
    kv_group_num = q.shape[1] // v_buffer.shape[-2]

    # kv_group_num == 1 -> plain MHA; otherwise GQA/MQA/MLA grouped path.
    impl = (
        decode_attention_fwd_normal
        if kv_group_num == 1
        else decode_attention_fwd_grouped
    )
    impl(
        q,
        k_buffer,
        v_buffer,
        o,
        lse,
        req_to_token,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
        page_size,
        logit_cap,
    )
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Authors:
#  - Burkhard Ringlein
#  - Jan van Lunteren
#  - Chih-Chieh Yang
#  - Thomas Parnell

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton

logger = init_logger(__name__)
# fp8 numeric limits of the current platform, used as constexpr defaults
# for the FP8 output-clamping paths of the kernels below.
float8_info = torch.finfo(current_platform.fp8_dtype())


@triton.jit
def cdiv_fn(x, y):
    """Ceiling division for device-side index math."""
    return (x + y - 1) // y


@triton.jit
def apply_softcap(S, x):
    """Logit soft-capping: returns ``x * tanh(S / x)`` written via exps."""
    Sdiv = S / x
    p1 = tl.exp(Sdiv)
    p2 = tl.exp(-Sdiv)
    return x * (p1 - p2) / (p1 + p2)


@triton.jit
def find_seq_idx(
    query_start_len_ptr,
    target_idx,
    num_seqs,
    BLOCK_Q: tl.constexpr,
    use_q_block_mode: tl.constexpr,
):
    """Binary-search the sequence that owns ``target_idx``.

    ``query_start_len_ptr`` holds cumulative query start offsets
    (length ``num_seqs + 1``). In q-block mode the comparison key is the
    sequence's first q-block index (``start // BLOCK_Q + seq``) instead
    of the raw token offset.
    """
    lo = 0  # Metax Modification
    hi = num_seqs
    while lo < hi:
        mid = (lo + hi) // 2
        start = tl.load(query_start_len_ptr + mid)
        key = start // BLOCK_Q + mid if use_q_block_mode else start
        if key <= target_idx:
            lo = mid + 1
        else:
            hi = mid
    return lo - 1
@triton.jit
def kernel_unified_attention_2d(
    output_ptr,  # [num_tokens, num_query_heads, head_size]
    query_ptr,  # [num_tokens, num_query_heads, head_size]
    key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
    value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
    sink_ptr,  # [num_query_heads]
    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
    seq_lens_ptr,  # [num_seqs]
    alibi_slopes_ptr,  # [num_query_heads]
    qq_bias_ptr,  # [num_query_tokens, num_query_tokens]
    scale,  # float32
    k_scale,  # float32
    v_scale,  # float32
    out_scale,  # float32
    softcap,  # float32
    num_query_heads: tl.constexpr,  # int
    num_queries_per_kv: tl.constexpr,  # int
    block_table_stride: tl.int64,  # int
    query_stride_0: tl.int64,  # int
    query_stride_1: tl.int64,  # int, should be equal to head_size
    output_stride_0: tl.int64,  # int
    output_stride_1: tl.int64,  # int, should be equal to head_size
    qq_bias_stride_0: tl.int64,  # int
    BLOCK_SIZE: tl.constexpr,  # int
    TILE_SIZE: tl.constexpr,  # int must be power of 2
    HEAD_SIZE: tl.constexpr,  # int
    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
    USE_ALIBI_SLOPES: tl.constexpr,  # bool
    USE_QQ_BIAS: tl.constexpr,  # bool
    USE_SOFTCAP: tl.constexpr,  # bool
    USE_SINKS: tl.constexpr,  # bool
    SLIDING_WINDOW: tl.constexpr,  # int
    USE_MM_PREFIX: tl.constexpr,  # bool
    MAX_MM_RANGES: tl.constexpr,  # int
    mm_prefix_range_ptr,  # [num_seqs] - prefix length for each sequence
    stride_k_cache_0: tl.int64,  # int
    stride_k_cache_1: tl.int64,  # int
    stride_k_cache_2: tl.int64,  # int
    stride_k_cache_3: tl.constexpr,  # int
    stride_v_cache_0: tl.int64,  # int
    stride_v_cache_1: tl.int64,  # int
    stride_v_cache_2: tl.int64,  # int
    stride_v_cache_3: tl.constexpr,  # int
    query_start_len_ptr,  # [num_seqs+1]
    BLOCK_Q: tl.constexpr,  # int
    num_seqs: tl.int32,
    BLOCK_M: tl.constexpr,  # int
    USE_FP8: tl.constexpr,  # bool
    FP8_MIN: tl.constexpr = float8_info.min,
    FP8_MAX: tl.constexpr = float8_info.max,
):
    """Single-pass (2D grid) unified prefill/decode attention kernel.

    One program handles one q-block of one sequence for one KV head,
    streaming KV tiles with an online-softmax accumulator. Supports
    causal masking, sliding window, ALiBi, softcap, sinks, q-q bias,
    multimodal prefix (bidirectional) ranges, and FP8 KV/output paths.
    """
    q_block_global_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)

    seq_idx = find_seq_idx(
        query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
    )

    q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx

    q_block_local_idx = q_block_global_idx - q_block_start_idx

    cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
    cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)

    cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index

    if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
        return

    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, HEAD_SIZE_PADDED)
    offs_t = tl.arange(0, TILE_SIZE)
    query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

    query_offset_0 = cur_batch_in_all_start_index + query_pos
    query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv
    query_offset = (
        query_offset_0[:, None] * query_stride_0
        + query_offset_1[:, None] * query_stride_1
        + offs_d[None, :]
    )

    dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
    query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
    query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)

    # Q : (BLOCK_M, HEAD_SIZE_PADDED)
    Q = tl.load(
        query_ptr + query_offset,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    # Running max M starts at the sink logits when sinks are enabled.
    if not USE_SINKS:
        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    else:
        M = tl.load(
            sink_ptr + query_offset_1,
            mask=query_mask_1,
            other=float("-inf"),
        ).to(dtype=tl.float32)

    L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # context length for this particular sequences
    context_len = seq_len - cur_batch_query_len

    # alibi slope for this head
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(
            alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0
        )

    # query-query attention bias
    if USE_QQ_BIAS:
        qq_bias_row_ptrs = (
            qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
        )  # shape: [BLOCK_M]

    # compute the length of the longest sequence prefix spanned by any
    # query token in the current q_block (q_block_local_idx)
    max_seq_prefix_len = (
        context_len
        + q_block_local_idx * BLOCK_Q
        + (BLOCK_M - 1) // num_queries_per_kv
        + 1
    )

    # adjust for potential padding in the last q_block by considering the
    # actual sequence length
    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)

    # causal masking allows skipping tiles beyond the longest prefix
    num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)

    # ---- Sliding-window tile pruning --------------------
    # Default: keep previous global behavior
    tile_start = 0
    tile_end = num_tiles
    if SLIDING_WINDOW > 0:
        # Query rows covered by this Q-block
        qpos_lo = q_block_local_idx * BLOCK_Q
        qpos_hi = tl.minimum(
            qpos_lo + (BLOCK_M - 1) // num_queries_per_kv,
            cur_batch_query_len - 1,
        )
        # Each query q attends keys in [context_len + q - SW + 1,
        # context_len + q]; take the union over the block and clamp
        # to valid tile indices.
        first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1
        last_allowed_key = context_len + qpos_hi
        tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE)
        tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles)

    # iterate through tiles (now limited to the sliding window range)
    for j in range(tile_start, tile_end):
        seq_offset = j * TILE_SIZE + offs_t
        tile_mask = seq_offset < max_seq_prefix_len

        physical_block_idx = tl.load(
            block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
        ).to(tl.int64)

        v_offset = (
            physical_block_idx[:, None] * stride_v_cache_0
            + kv_head_idx * stride_v_cache_2
            + offs_d[None, :] * stride_v_cache_3
            + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
        )

        k_offset = (
            physical_block_idx[None, :] * stride_k_cache_0
            + kv_head_idx * stride_k_cache_2
            + offs_d[:, None] * stride_k_cache_3
            + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
        )

        # K : (HEAD_SIZE, TILE_SIZE)
        K_load = tl.load(
            key_cache_ptr + k_offset,
            mask=dim_mask[:, None] & tile_mask[None, :],
            other=0.0,
        )

        # Dequantize fp8 KV only when Q itself is not fp8.
        if K_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                K = K_load
            else:
                K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (TILE_SIZE, HEAD_SIZE)
        V_load = tl.load(
            value_cache_ptr + v_offset,
            mask=dim_mask[None, :] & tile_mask[:, None],
            other=0.0,
        )

        if V_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                V = V_load
            else:
                V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        # Compute attention mask: causal by default (key <= query)
        query_abs_pos = context_len + query_pos[:, None]
        seq_mask = seq_offset[None, :] <= query_abs_pos

        # Apply sliding window to base mask BEFORE mm_prefix OR.
        # Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
        if SLIDING_WINDOW > 0:
            seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)

        # PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
        # Applied AFTER sliding window so mm_prefix ranges override SW restriction.
        if USE_MM_PREFIX:
            for i in range(MAX_MM_RANGES):
                range_start = tl.load(
                    mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
                )
                range_end = tl.load(
                    mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
                )

                is_valid = range_start < range_end
                q_in_range = (
                    (query_abs_pos >= range_start)
                    & (query_abs_pos <= range_end)
                    & is_valid
                )
                k_in_range = (
                    (seq_offset[None, :] >= range_start)
                    & (seq_offset[None, :] <= range_end)
                    & is_valid
                )
                seq_mask |= q_in_range & k_in_range

        # S : (BLOCK_M, TILE_SIZE)
        S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)

        S += scale * tl.dot(Q, K)

        if USE_SOFTCAP:
            S = apply_softcap(S, softcap)

        S = tl.where(
            query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
        )

        if USE_ALIBI_SLOPES:
            S += alibi_slope[:, None] * (seq_offset - context_len)

        if USE_QQ_BIAS:
            # compute key positions relative to query section
            key_rel_pos = seq_offset - context_len  # shape: [BLOCK_SIZE]
            # load bias only for keys that correspond to queries
            # NOTE(review): `and` between Triton tensors relies on the
            # compiler lowering BoolOp element-wise — confirm the target
            # Triton version accepts this (else use `&`).
            is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
            qq_bias = tl.load(
                qq_bias_row_ptrs + key_rel_pos[None, :],
                mask=is_query_key[None, :],  # avoid OOB for context keys
                other=0.0,
            )
            S += qq_bias

        # compute running maximum
        # m_j : (BLOCK_M,)
        m_j = tl.maximum(M, tl.max(S, axis=1))

        # For sliding window there's a chance the max is -inf due to masking of
        # the entire row. In this case we need to set m_j 0 to avoid NaN
        m_j = tl.where(m_j > float("-inf"), m_j, 0.0)

        # P : (BLOCK_M, TILE_SIZE)
        P = tl.exp(S - m_j[:, None])

        # l_j : (BLOCK_M,)
        l_j = tl.sum(P, axis=1)

        # alpha : (BLOCK_M, )
        alpha = tl.exp(M - m_j)

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc = acc * alpha[:, None]

        # update constants
        L = L * alpha + l_j
        M = m_j

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc += tl.dot(P.to(V.dtype), V)

    # epilogue
    acc = acc / L[:, None]
    if USE_FP8:
        acc = acc * tl.load(out_scale)
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)

    output_offset = (
        query_offset_0[:, None] * output_stride_0
        + query_offset_1[:, None] * output_stride_1
        + offs_d[None, :]
    )

    tl.store(
        output_ptr + output_offset,
        acc,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
    )
@triton.jit
def kernel_unified_attention_3d(
    segm_output_ptr,
    # [num_tokens, num_query_heads, num_segments, head_size_padded]
    segm_max_ptr,  # [num_tokens, num_query_heads, num_segments]
    segm_expsum_ptr,  # [num_tokens, num_query_heads, num_segments]
    query_ptr,  # [num_tokens, num_query_heads, head_size]
    key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
    value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
    sink_ptr,  # [num_query_heads]
    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
    seq_lens_ptr,  # [num_seqs]
    alibi_slopes_ptr,  # [num_query_heads]
    qq_bias_ptr,  # [num_query_tokens, num_query_tokens]
    scale,  # float32
    k_scale,  # float32
    v_scale,  # float32
    softcap,  # float32
    num_query_heads: tl.constexpr,  # int
    num_queries_per_kv: tl.constexpr,  # int
    block_table_stride: tl.int64,  # int
    query_stride_0: tl.int64,  # int
    query_stride_1: tl.int64,  # int, should be equal to head_size
    qq_bias_stride_0: tl.int64,  # int
    BLOCK_SIZE: tl.constexpr,  # int
    TILE_SIZE: tl.constexpr,  # int, must be power of 2
    HEAD_SIZE: tl.constexpr,  # int
    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
    USE_ALIBI_SLOPES: tl.constexpr,  # bool
    USE_QQ_BIAS: tl.constexpr,  # bool
    USE_SOFTCAP: tl.constexpr,  # bool
    USE_SINKS: tl.constexpr,  # bool
    SLIDING_WINDOW: tl.constexpr,  # int
    stride_k_cache_0: tl.int64,  # int
    stride_k_cache_1: tl.int64,  # int
    stride_k_cache_2: tl.int64,  # int
    stride_k_cache_3: tl.constexpr,  # int
    stride_v_cache_0: tl.int64,  # int
    stride_v_cache_1: tl.int64,  # int
    stride_v_cache_2: tl.int64,  # int
    stride_v_cache_3: tl.constexpr,  # int
    query_start_len_ptr,  # [num_seqs+1]
    BLOCK_Q: tl.constexpr,  # int
    num_seqs: tl.int32,
    BLOCK_M: tl.constexpr,  # int
    NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
    USE_MM_PREFIX: tl.constexpr,  # bool
    MAX_MM_RANGES: tl.constexpr,  # int
    mm_prefix_range_ptr,  # [num_seqs] - prefix length for each sequence
):
    """Split-KV (3D grid) unified attention kernel.

    Like ``kernel_unified_attention_2d`` but the KV range is additionally
    split into ``NUM_SEGMENTS_PER_SEQ`` segments along grid axis 2; each
    program emits a partial output plus running-max / exp-sum statistics
    that ``reduce_segments`` later combines.
    """
    q_block_global_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)
    segm_idx = tl.program_id(2)

    seq_idx = find_seq_idx(
        query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
    )

    q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx

    q_block_local_idx = q_block_global_idx - q_block_start_idx

    cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
    cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)

    cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index

    if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
        return

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # number of segments for this particular sequence
    num_segments = NUM_SEGMENTS_PER_SEQ
    tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)

    if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len:
        return

    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, HEAD_SIZE_PADDED)
    offs_t = tl.arange(0, TILE_SIZE)
    query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

    query_offset_0 = cur_batch_in_all_start_index + query_pos
    query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv
    query_offset = (
        query_offset_0[:, None] * query_stride_0
        + query_offset_1[:, None] * query_stride_1
        + offs_d[None, :]
    )

    dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
    query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
    query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)

    # Q : (BLOCK_M, HEAD_SIZE_PADDED)
    Q = tl.load(
        query_ptr + query_offset,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    # Sinks are folded into the running max of segment 0 only, so they are
    # counted exactly once after the cross-segment reduction.
    if USE_SINKS:
        if segm_idx == 0:
            M = tl.load(
                sink_ptr + query_offset_1,
                mask=query_mask_1,
                other=float("-inf"),
            ).to(dtype=tl.float32)
        else:
            M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    else:
        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)

    L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

    # context length for this particular sequences
    context_len = seq_len - cur_batch_query_len

    # alibi slope for this head
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(
            alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0
        )

    # query-query attention bias
    if USE_QQ_BIAS:
        qq_bias_row_ptrs = (
            qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
        )  # shape: [BLOCK_M]

    # compute the length of the longest sequence prefix spanned by any
    # query token in the current q_block (q_block_local_idx)
    max_seq_prefix_len = (
        context_len
        + q_block_local_idx * BLOCK_Q
        + (BLOCK_M - 1) // num_queries_per_kv
        + 1
    )

    # adjust for potential padding in the last q_block by considering the
    # actual sequence length
    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)

    # causal masking allows skipping tiles beyond the longest prefix
    num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)

    # iterate through tiles within current segment
    for j in range(
        segm_idx * tiles_per_segment,
        min((segm_idx + 1) * tiles_per_segment, num_tiles),
    ):
        seq_offset = j * TILE_SIZE + offs_t
        tile_mask = seq_offset < max_seq_prefix_len

        physical_block_idx = tl.load(
            block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
        ).to(tl.int64)

        v_offset = (
            physical_block_idx[:, None] * stride_v_cache_0
            + kv_head_idx * stride_v_cache_2
            + offs_d[None, :] * stride_v_cache_3
            + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
        )

        k_offset = (
            physical_block_idx[None, :] * stride_k_cache_0
            + kv_head_idx * stride_k_cache_2
            + offs_d[:, None] * stride_k_cache_3
            + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
        )

        # K : (HEAD_SIZE, TILE_SIZE)
        K_load = tl.load(
            key_cache_ptr + k_offset,
            mask=dim_mask[:, None] & tile_mask[None, :],
            other=0.0,
        )

        if K_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                K = K_load
            else:
                K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (TILE_SIZE, HEAD_SIZE)
        V_load = tl.load(
            value_cache_ptr + v_offset,
            mask=dim_mask[None, :] & tile_mask[:, None],
            other=0.0,
        )

        if V_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                V = V_load
            else:
                V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        # Compute attention mask: causal by default (key <= query)
        query_abs_pos = context_len + query_pos[:, None]
        seq_mask = seq_offset[None, :] <= query_abs_pos

        # Apply sliding window to base mask BEFORE mm_prefix OR.
        # Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
        if SLIDING_WINDOW > 0:
            seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)

        # PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
        # Applied AFTER sliding window so mm_prefix ranges override SW restriction.
        if USE_MM_PREFIX:
            for i in range(MAX_MM_RANGES):
                range_start = tl.load(
                    mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
                )
                range_end = tl.load(
                    mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
                )

                is_valid = range_start < range_end
                q_in_range = (
                    (query_abs_pos >= range_start)
                    & (query_abs_pos <= range_end)
                    & is_valid
                )
                k_in_range = (
                    (seq_offset[None, :] >= range_start)
                    & (seq_offset[None, :] <= range_end)
                    & is_valid
                )
                seq_mask |= q_in_range & k_in_range

        # S : (BLOCK_M, TILE_SIZE)
        S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
        S += scale * tl.dot(Q, K)

        if USE_SOFTCAP:
            S = apply_softcap(S, softcap)

        S = tl.where(
            query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
        )

        if USE_ALIBI_SLOPES:
            S += alibi_slope[:, None] * (seq_offset - context_len)

        if USE_QQ_BIAS:
            # compute key positions relative to query section
            key_rel_pos = seq_offset - context_len  # shape: [BLOCK_SIZE]
            # load bias only for keys that correspond to queries
            # NOTE(review): `and` between Triton tensors relies on the
            # compiler lowering BoolOp element-wise — confirm the target
            # Triton version accepts this (else use `&`).
            is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
            qq_bias = tl.load(
                qq_bias_row_ptrs + key_rel_pos[None, :],
                mask=is_query_key[None, :],  # avoid OOB for context keys
                other=0.0,
            )
            S += qq_bias

        # compute running maximum
        # m_j : (BLOCK_M,)
        m_j = tl.maximum(M, tl.max(S, axis=1))

        # For sliding window there's a chance the max is -inf due to masking of
        # the entire row. In this case we need to set m_j 0 to avoid NaN
        m_j = tl.where(m_j > float("-inf"), m_j, 0.0)

        # P : (BLOCK_M, TILE_SIZE,)
        P = tl.exp(S - m_j[:, None])

        # l_j : (BLOCK_M,)
        l_j = tl.sum(P, axis=1)

        # alpha : (BLOCK_M, )
        alpha = tl.exp(M - m_j)

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc = acc * alpha[:, None]

        # update constants
        L = L * alpha + l_j
        M = m_j

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc += tl.dot(P.to(V.dtype), V)

    segm_output_offset = (
        query_offset_0[:, None].to(tl.int64)
        * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
        + query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
        + segm_idx * HEAD_SIZE_PADDED
        + tl.arange(0, HEAD_SIZE_PADDED)[None, :]
    )
    tl.store(
        segm_output_ptr + segm_output_offset,
        acc,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
    )
    segm_offset = (
        query_offset_0.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ)
        + query_offset_1 * NUM_SEGMENTS_PER_SEQ
        + segm_idx
    )
    tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1)
    tl.store(segm_expsum_ptr + segm_offset, L, mask=query_mask_0 & query_mask_1)
@triton.jit
def reduce_segments(
    output_ptr,  # [num_tokens, num_query_heads, head_size]
    segm_output_ptr,
    # [num_tokens, num_query_heads, max_num_segments, head_size]
    segm_max_ptr,  # [num_tokens, num_query_heads, max_num_segments]
    segm_expsum_ptr,  # [num_tokens, num_query_heads, max_num_segments]
    seq_lens_ptr,  # [num_seqs]
    num_seqs,  # int
    num_query_heads: tl.constexpr,  # int
    out_scale_inv,  # float32
    output_stride_0: tl.int64,  # int
    output_stride_1: tl.int64,  # int, should be equal to head_size
    block_table_stride: tl.int64,  # int
    TILE_SIZE: tl.constexpr,  # int
    HEAD_SIZE: tl.constexpr,  # int, must be power of 2
    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
    query_start_len_ptr,  # [num_seqs+1]
    BLOCK_Q: tl.constexpr,  # int
    NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
    USE_FP8: tl.constexpr,  # bool
    FP8_MIN: tl.constexpr = float8_info.min,
    FP8_MAX: tl.constexpr = float8_info.max,
):
    """Combine per-segment partial results of the 3D kernel.

    One program per (query token, query head): rescales each segment's
    partial output by its max/exp-sum statistics and writes the final
    normalized attention output.
    """
    query_token_idx = tl.program_id(0)
    query_head_idx = tl.program_id(1)

    seq_idx = find_seq_idx(
        query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False
    )

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # number of segments for this particular sequence
    num_segments = NUM_SEGMENTS_PER_SEQ
    tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)

    # create masks for subsequent loads
    act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE)
    segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full(
        [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32
    )
    dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, 0).to(tl.int1)

    # load segment maxima
    segm_offset = (
        query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ)
        + query_head_idx * NUM_SEGMENTS_PER_SEQ
        + tl.arange(0, NUM_SEGMENTS_PER_SEQ)
    )
    segm_max = tl.load(segm_max_ptr + segm_offset, mask=segm_mask, other=float("-inf"))
    overall_max = tl.max(segm_max)

    # load and rescale segment exp sums
    segm_expsum = tl.load(segm_expsum_ptr + segm_offset, mask=segm_mask, other=0.0)
    segm_expsum = segm_expsum * tl.exp(segm_max - overall_max)
    overall_expsum = tl.sum(segm_expsum)

    # load, rescale, and add segment attention outputs
    segm_output_offset = (
        query_token_idx.to(tl.int64)
        * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
        + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
        + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED
        + tl.arange(0, HEAD_SIZE_PADDED)[None, :]
    )
    segm_output = tl.load(
        segm_output_ptr + segm_output_offset,
        mask=segm_mask[:, None] & dim_mask[None, :],
        other=0.0,
    )
    segm_output *= tl.exp(segm_max - overall_max)[:, None]
    acc_sum = tl.sum(segm_output, axis=0)
    # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0
    acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)

    if USE_FP8:
        acc = acc * tl.load(out_scale_inv)
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)

    # write result
    output_offset = (
        query_token_idx * output_stride_0
        + query_head_idx * output_stride_1
        + tl.arange(0, HEAD_SIZE_PADDED)
    )
    tl.store(output_ptr + output_offset, acc, mask=dim_mask)
divide by overall_expsum, returning 0.0 if overall_expsum is 0 + acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum) + + if USE_FP8: + acc = acc * tl.load(out_scale_inv) + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) + + # write result + output_offset = ( + query_token_idx * output_stride_0 + + query_head_idx * output_stride_1 + + tl.arange(0, HEAD_SIZE_PADDED) + ) + tl.store(output_ptr + output_offset, acc, mask=dim_mask) + + +def _is_gemma3_attention(head_size: int, sliding_window: int) -> bool: + """Detect Gemma3 models via unique (head_size, sliding_window) signature. + + Gemma3 models are the only ones using sliding_window=1024 with + head_size 128 (27B) or 256 (1B, 4B, 12B). Other SWA models use + different window sizes (Mistral=4096, Phi-3=2047). + """ + return sliding_window == 1024 and head_size in (128, 256) + + +def _get_tile_size( + head_size: int, + sliding_window: int, + element_size: int, + is_prefill: bool, +) -> int: + """Select tile size with Gemma3-specific optimization. + + For Gemma3, use 32 for both prefill and decode to better utilize + the larger head dimension (128/256). For other models, use + the default vLLM behavior. 
+ """ + if _is_gemma3_attention(head_size, sliding_window): + # Gemma3: use 32 for decode (default is 16) + return 32 + + # Default behavior + if is_prefill: + return 32 + return 16 if element_size >= 2 else 32 + + +def unified_attention( + q, + k, + v, + out, + cu_seqlens_q, + max_seqlen_q, + seqused_k, + max_seqlen_k, + softmax_scale, + causal, + window_size, + block_table, + softcap, + q_descale, + k_descale, + v_descale, + seq_threshold_3D=None, + num_par_softmax_segments=None, + softmax_segm_output=None, + softmax_segm_max=None, + softmax_segm_expsum=None, + alibi_slopes=None, + output_scale=None, + qq_bias=None, + # Optional tensor for sinks + sinks=None, + # Optional tensor for prefix lengths (PrefixLM support) + mm_prefix_range=None, +): + assert causal, "Only causal attention is supported" + assert q_descale is None, "Q scales not supported" + + if sinks is not None: + assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size" + + use_mm_prefix = False + max_mm_ranges = 0 + if mm_prefix_range is not None: + if mm_prefix_range.ndim == 3: + use_mm_prefix = True + max_mm_ranges = mm_prefix_range.shape[1] + else: + raise ValueError( + f"Unsupported mm_prefix_range shape: {mm_prefix_range.shape}" + ) + + use_alibi_slopes = alibi_slopes is not None + use_qq_bias = qq_bias is not None + + block_size = v.shape[1] + num_seqs = len(seqused_k) + num_query_heads = q.shape[1] + num_kv_heads = k.shape[2] + num_queries_per_kv = num_query_heads // num_kv_heads + head_size = q.shape[2] + + BLOCK_M = ( + 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv) + ) + BLOCK_Q = BLOCK_M // num_queries_per_kv + + # Ideally we would launch with kernel with: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks. + # However, it is slow to realize the query_lens on cpu. 
+ # Instead we use upper-bound: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] + # <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1] + # = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs + # <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs + # = floor(q.shape[0] / BLOCK_Q) + num_seqs + total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + + # Tile sizes for prefill and decode. Gemma3 models use optimized values. + # Note: tile size must be at least 32 for fp8 (element_size == 1). + sliding_window_val = 1 + window_size[0] if window_size[0] >= 0 else 0 + TILE_SIZE_PREFILL = _get_tile_size( + head_size, + sliding_window_val, + q.element_size(), + is_prefill=True, + ) + TILE_SIZE_DECODE = _get_tile_size( + head_size, + sliding_window_val, + q.element_size(), + is_prefill=False, + ) + + # Launch the 2D kernel if + # 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or + # 2. The batch includes at least one prefill request, or + # 3. The number of sequences exceeds the configured threshold + if ( + seq_threshold_3D is None + or num_par_softmax_segments is None + or softmax_segm_output is None + or softmax_segm_max is None + or softmax_segm_expsum is None + or max_seqlen_q > 1 + or num_seqs > seq_threshold_3D + ): + kernel_unified_attention_2d[ + ( + total_num_q_blocks, + num_kv_heads, + ) + ]( + output_ptr=out, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + sink_ptr=sinks, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + qq_bias_ptr=qq_bias, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + out_scale=1 / output_scale if output_scale is not None else 1.0, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, + 
BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_PREFILL, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_QQ_BIAS=use_qq_bias, + USE_SOFTCAP=(softcap > 0), + USE_SINKS=(sinks is not None), + USE_MM_PREFIX=use_mm_prefix, + MAX_MM_RANGES=max_mm_ranges, + mm_prefix_range_ptr=mm_prefix_range, + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + USE_FP8=output_scale is not None, + ) + else: + kernel_unified_attention_3d[ + (total_num_q_blocks, num_kv_heads, num_par_softmax_segments) + ]( + segm_output_ptr=softmax_segm_output, + segm_max_ptr=softmax_segm_max, + segm_expsum_ptr=softmax_segm_expsum, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + sink_ptr=sinks, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + qq_bias_ptr=qq_bias, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, + BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_DECODE, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_QQ_BIAS=use_qq_bias, + USE_SOFTCAP=(softcap > 0), + USE_SINKS=(sinks is not None), + USE_MM_PREFIX=use_mm_prefix, + MAX_MM_RANGES=max_mm_ranges, + mm_prefix_range_ptr=mm_prefix_range, + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + 
stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments, + ) + reduce_segments[(q.shape[0], num_query_heads)]( + output_ptr=out, + segm_output_ptr=softmax_segm_output, + segm_max_ptr=softmax_segm_max, + segm_expsum_ptr=softmax_segm_expsum, + seq_lens_ptr=seqused_k, + num_seqs=num_seqs, + num_query_heads=num_query_heads, + out_scale_inv=1 / output_scale if output_scale is not None else 1.0, + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + block_table_stride=block_table.stride(0), + TILE_SIZE=TILE_SIZE_DECODE, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments, + USE_FP8=output_scale is not None, + ) diff --git a/vllm_fl/attention/ops/vit_attn_wrappers.py b/vllm_fl/attention/ops/vit_attn_wrappers.py new file mode 100644 index 00000000..7e8a660b --- /dev/null +++ b/vllm_fl/attention/ops/vit_attn_wrappers.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. 
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This file contains ops for ViT attention to be compatible with torch.compile +as there are operations here not supported by torch.compile (for instance, +`.item()` in flash attention) + +Using these ops and wrapping vision blocks with `torch.compile` can speed up +throughput in vision models by ~5% relative on H100, and improve token +latencies by ~7% (see qwen2_5_vl for example usage) + +To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0) +""" + +# ----------------------------------- +# Use Maca version of flash attention + +import einops +import torch + +from vllm.utils.torch_utils import direct_register_custom_op + + +def flash_attn_maxseqlen_wrapper( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: torch.Tensor, + batch_size: int, + is_rocm_aiter: bool, +) -> torch.Tensor: + # ----------------------------------- + # Maca version of flash attention + from vllm_fl.attention.utils.fa_utils import flash_attn_varlen_func + + q, k, v = (einops.rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) + output = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen.item(), + max_seqlen_k=max_seqlen.item(), + dropout_p=0.0, + causal=False, + ) + context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size) + return context_layer + + +def flash_attn_maxseqlen_wrapper_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: torch.Tensor, + batch_size: int, + is_rocm_aiter: bool, +) -> torch.Tensor: + return torch.empty_like(q) + + +direct_register_custom_op( + op_name="maca_flash_attn_maxseqlen_wrapper", + op_func=flash_attn_maxseqlen_wrapper, + fake_impl=flash_attn_maxseqlen_wrapper_fake, +) + + +def vit_flash_attn_wrapper( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: torch.Tensor, + batch_size: int, + is_rocm_aiter: bool, +) -> torch.Tensor: + return torch.ops.vllm.maca_flash_attn_maxseqlen_wrapper( + q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter + ) diff --git a/vllm_fl/attention/utils/__init__.py b/vllm_fl/attention/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_fl/attention/utils/fa_utils.py b/vllm_fl/attention/utils/fa_utils.py new file mode 100644 index 00000000..56d2c787 --- /dev/null +++ b/vllm_fl/attention/utils/fa_utils.py @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: Apache-2.0 +# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. 
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.attention.utils.fa_utils import logger +from vllm.platforms import current_platform + + +if current_platform.is_out_of_tree(): + from vllm import _custom_ops as ops + + get_scheduler_metadata = None + reshape_and_cache_flash = ops.reshape_and_cache_flash + from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache # noqa: F401 + + get_scheduler_metadata = None + + +def get_flash_attn_version(requires_alibi: bool = False) -> int | None: + logger.info_once( + "Using Maca version of flash attention, which only supports version 2." + ) + + # Note: In maca this need to be None since + # metax flash_attn api does not have parameter + # for `fa_version`. + return None + + +def flash_attn_supports_fp8() -> bool: + logger.info_once( + "Using Maca version of flash attention, which does not support FP8" + ) + return False + + +def flash_attn_supports_sinks() -> bool: + # maca fa2 supports sinks + return True + + +def flash_attn_supports_mla(): + return False + + +def is_flash_attn_varlen_func_available() -> bool: + return True diff --git a/vllm_fl/compilation/graph.py b/vllm_fl/compilation/graph.py index e674ec90..ed93d173 100644 --- a/vllm_fl/compilation/graph.py +++ b/vllm_fl/compilation/graph.py @@ -156,7 +156,7 @@ def __call__(self, *args, **kwargs): # and disable gc for the rest of the graphs. stack.enter_context(patch("gc.collect", lambda: None)) stack.enter_context( - patch("vllm_fl.platform.PlatformFL.empty_cache", lambda: None) + patch("vllm_fl.device_platform.PlatformFL.empty_cache", lambda: None) ) set_graph_pool_id(self.graph_pool) diff --git a/vllm_fl/device_platform.py b/vllm_fl/device_platform.py new file mode 100644 index 00000000..0c092d6b --- /dev/null +++ b/vllm_fl/device_platform.py @@ -0,0 +1,647 @@ +# Copyright (c) 2025 BAAI. All rights reserved. 
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.11.0/vllm/platforms/cuda.py +# Below is the original copyright: +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from typing import TYPE_CHECKING, Optional, TypeVar +from typing_extensions import ParamSpec + +import torch + +# import custom ops, trigger op registration (CUDA only) +try: + import vllm._C # noqa +except (ImportError, OSError): + pass # NPU or other platforms may not have vllm._C + +from vllm.attention.backends.registry import AttentionBackendEnum, register_backend +from vllm.logger import init_logger +from vllm.platforms import Platform, PlatformEnum +from vllm.platforms.interface import DeviceCapability + +if TYPE_CHECKING: + from vllm.attention.selector import AttentionSelectorConfig + from vllm.config import VllmConfig + from vllm.config.cache import CacheDType +else: + VllmConfig = None + CacheDType = None + +from vllm_fl.utils import DeviceInfo, get_device_name, get_device_type + +logger = init_logger(__name__) + +_P = ParamSpec("_P") +_R = TypeVar("_R") + +dist_backend_dict = { + "npu": "hccl", + "cuda": "nccl", +} + + +# ============================================================================ +# MetaX (Maca) Platform Compatibility Additions +# ============================================================================ + +def _register_metax_attention_backends() -> None: + """ + Register MetaX-specific attention backends. + This mirrors the functionality in platform.py's register_attention_backends(). 
+ """ + try: + # Register FLASHMLA backend + register_backend( + AttentionBackendEnum.FLASHMLA, + class_path="vllm_fl.v1.attention.backends.mla.flashmla.MacaFlashMLABackend", + ) + # Register FLASHMLA_SPARSE backend - THIS IS THE KEY ONE FOR SPARSE ATTENTION + register_backend( + backend=AttentionBackendEnum.FLASHMLA_SPARSE, + class_path="vllm_fl.v1.attention.backends.mla.flashmla_sparse.MacaFlashMLASparseBackend", + ) + # Register TRITON_MLA backend + register_backend( + backend=AttentionBackendEnum.TRITON_MLA, + class_path="vllm_fl.v1.attention.backends.mla.triton_mla.MacaTritonMLABackend", + ) + # Register standard FLASH_ATTN backend + register_backend( + AttentionBackendEnum.FLASH_ATTN, + class_path="vllm_fl.v1.attention.backends.flash_attn.MacaFlashAttentionBackend", + ) + # Register FLASHINFER backend + register_backend( + backend=AttentionBackendEnum.FLASHINFER, + class_path="vllm_fl.v1.attention.backends.flashinfer.MacaFlashInferBackend", + ) + # Register TRITON_ATTN backend + register_backend( + backend=AttentionBackendEnum.TRITON_ATTN, + class_path="vllm_fl.v1.attention.backends.triton_attn.MacaTritonAttentionBackend", + ) + # Register TREE_ATTN backend + register_backend( + backend=AttentionBackendEnum.TREE_ATTN, + class_path="vllm_fl.v1.attention.backends.tree_attn.MacaTreeAttentionBackend", + ) + # Register FLEX_ATTENTION backend + register_backend( + backend=AttentionBackendEnum.FLEX_ATTENTION, + class_path="vllm_fl.v1.attention.backends.flex_attention.MacaFlexAttentionBackend", + ) + logger.info("Successfully registered MetaX attention backends") + except Exception as e: + logger.warning(f"Failed to register some MetaX attention backends: {e}") + + +def _get_metax_backend_priorities(use_mla: bool, use_sparse: bool = False) -> list[AttentionBackendEnum]: + """ + Get backend priorities for MetaX platform. + Mirrors _get_backend_priorities() in platform.py. 
+ """ + if use_mla: + if use_sparse: + # For sparse MLA, prioritize FLASHMLA_SPARSE + return [ + AttentionBackendEnum.FLASHMLA_SPARSE, + AttentionBackendEnum.FLASHMLA, + AttentionBackendEnum.TRITON_MLA, + AttentionBackendEnum.CUTLASS_MLA, + ] + else: + return [ + AttentionBackendEnum.FLASHMLA, + AttentionBackendEnum.TRITON_MLA, + AttentionBackendEnum.CUTLASS_MLA, + AttentionBackendEnum.FLASHMLA_SPARSE, + ] + else: + return [ + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.TRITON_ATTN, + AttentionBackendEnum.TREE_ATTN, + AttentionBackendEnum.FLEX_ATTENTION, + ] + + +# ============================================================================ +# End MetaX Compatibility Additions +# ============================================================================ + + +class PlatformFL(Platform): + _enum = PlatformEnum.OOT + device_info = DeviceInfo() + vendor_name = device_info.vendor_name + device_type = get_device_type(vendor_name) + device_name = get_device_name(vendor_name) + dispatch_key = device_info.dispatch_key + torch_device_fn = device_info.torch_device_fn + ray_device_key: str = "GPU" + dist_backend: str = ( + "flagcx" if "FLAGCX_PATH" in os.environ else dist_backend_dict.get(device_name, "nccl") + ) + ### TODO(lms): dispatch device_control_env_var + # device_control_env_var: str = "CUDA_VISIBLE_DEVICES" + + def is_cuda_alike(self) -> bool: + """Stateless version of [torch.cuda.is_available][].""" + if self.vendor_name == "iluvatar": + return False + return self.device_type == "cuda" + + def is_cuda(self) -> bool: + """Stateless version of [torch.cuda.is_available][].""" + return self.device_type == "cuda" and self.vendor_name == "nvidia" + + @property + def supported_dtypes(self) -> list[torch.dtype]: + return [torch.bfloat16, torch.float16, torch.float32] + + @classmethod + def check_if_supports_dtype(cls, torch_dtype: torch.dtype): + """ + Check if the dtype is supported by the current platform. 
+ """ + pass + + @classmethod + def get_current_memory_usage( + cls, device: Optional[torch.types.Device] = None + ) -> float: + cls.torch_device_fn.empty_cache() + cls.torch_device_fn.reset_peak_memory_stats(device) + return cls.torch_device_fn.max_memory_allocated(device) + + @classmethod + def set_device(cls, device: torch.device) -> None: + """ + Set the device for the current platform. + """ + cls.torch_device_fn.set_device(device) + + @classmethod + def empty_cache(cls) -> None: + cls.torch_device_fn.empty_cache() + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + return cls.device_name + + ### TODO(lms): change pin_memory depend device + @classmethod + def is_pin_memory_available(cls): + if cls.device_type in ["cuda", "xpu", "npu"]: + return True + return False + + @classmethod + def import_kernels(cls) -> None: + """Import device-specific kernels.""" + logger.info(f"current vendor_name is: {cls.vendor_name}") + if cls.vendor_name == "metax": + # Import MetaX custom kernels + try: + import mcoplib._C # noqa: F401 + logger.info("Successfully imported mcoplib._C") + except ImportError as e: + logger.warning(f"Failed to import mcoplib._C: {e}") + + try: + import mcoplib._moe_C # noqa: F401 + logger.info("Successfully imported mcoplib._moe_C") + except ImportError as e: + logger.warning(f"Failed to import mcoplib._moe_C: {e}") + + try: + import vllm_fl.dispatch.backends.vendor.metax.patches # noqa: F401 + except Exception as e: + logger.warning(f"Failed to import maca patches: {e}") + + # Register attention backends for MetaX + _register_metax_attention_backends() + + @classmethod + def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: + parallel_config = vllm_config.parallel_config + model_config = vllm_config.model_config + + parallel_config.worker_cls = "vllm_fl.worker.worker.WorkerFL" + + cache_config = vllm_config.cache_config + if cache_config and cache_config.block_size is None: + # Ascend NPU requires block_size to 
be a multiple of 128 + # CUDA can use smaller block sizes like 16 + if cls.device_type == "npu": + cache_config.block_size = 128 + logger.info("Setting kv cache block size to 128 for Ascend NPU.") + else: + cache_config.block_size = 16 + + # TODO(lucas): handle this more gracefully + # Note: model_config may be None during testing + # Note: block_size is initialized in + # HybridAttentionMambaModelConfig.verify_and_update_config + # for models with both attention and mamba, + # and doesn't need to be reinitialized here + if ( + model_config is not None + and model_config.use_mla + and cache_config.block_size is not None + ): + # MetaX (Maca) specific configuration + if cls.vendor_name == "metax": + use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk") + + # Force block_size to 64 for FlashMLA backends on MetaX + if use_sparse and cache_config.block_size != 64: + cache_config.block_size = 64 + logger.info( + "Forcing kv cache block size to 64 for FlashMLASparse backend." + ) + elif cache_config.block_size % 64 != 0: + cache_config.block_size = 64 + logger.info("Forcing kv cache block size to 64 for FlashMLA backend.") + + # Disable cascade attention for MetaX + model_config.disable_cascade_attn = True + else: + # Non-MetaX platforms + if cache_config.block_size % 64 != 0: + cache_config.block_size = 64 + logger.info("Forcing kv cache block size to 64 for FlagOSMLA backend.") + + # lazy import to avoid circular import + from vllm.config import CUDAGraphMode + + compilation_config = vllm_config.compilation_config + if compilation_config.compile_sizes is None: + compilation_config.compile_sizes = [] + + if ( + parallel_config.data_parallel_size > 1 + and compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ): + logger.info( + "WideEP: Disabling CUDA Graphs since DeepEP high-throughput " + "kernels are optimized for prefill and are incompatible with " + "CUDA Graphs. 
" + "In order to use CUDA Graphs for decode-optimized workloads, " + "use --all2all-backend with another option, such as " + "deepep_low_latency, pplx, or allgather_reducescatter." + ) + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + + # -------------------------------------------------------- + # maca specific config updates + if cls.vendor_name == "metax": + if model_config is not None: + model_config.disable_cascade_attn = True + if attention_config := vllm_config.attention_config: + attention_config.use_cudnn_prefill = False + attention_config.use_trtllm_ragged_deepseek_prefill = False + attention_config.use_trtllm_attention = False + attention_config.disable_flashinfer_prefill = True + + # Add custom attention ops for compilation (from platform.py) + if hasattr(compilation_config, '_attention_ops'): + compilation_config._attention_ops.append("vllm::mx_sparse_attn_indexer") + + @classmethod + def get_attn_backend_cls( + cls, + selected_backend: "AttentionBackendEnum", + attn_selector_config: "AttentionSelectorConfig", + ) -> list[str]: + """Get the attention backend class path using the dispatch mechanism.""" + + # MetaX (Maca) specific handling - use explicit backend registration + if cls.vendor_name == "metax": + return cls._get_metax_attn_backend_cls(selected_backend, attn_selector_config) + + # Default dispatch mechanism for other vendors + from vllm_fl.dispatch import call_op + + use_mla = attn_selector_config.use_mla + use_sparse = getattr(attn_selector_config, 'use_sparse', False) + + backend_path = call_op("attention_backend", use_mla=use_mla, use_sparse=use_sparse) + + logger.info_once( + "Using attention backend via dispatch (use_mla=%s, use_sparse=%s): %s", + use_mla, + use_sparse, + backend_path, + scope="local", + ) + logger.info( + "Using attention backend via dispatch (use_mla=%s): %s" + % (use_mla, backend_path) + ) + return backend_path + + @classmethod + def _get_metax_attn_backend_cls( + cls, + selected_backend: 
"AttentionBackendEnum", + attn_selector_config: "AttentionSelectorConfig", + ) -> str: + """ + MetaX-specific backend selection logic. + Mirrors MacaPlatform.get_attn_backend_cls() in platform.py + """ + use_mla = attn_selector_config.use_mla + use_sparse = getattr(attn_selector_config, 'use_sparse', False) or ( + hasattr(attn_selector_config, 'hf_config') and + hasattr(attn_selector_config.hf_config, 'index_topk') + ) + + # Ensure backends are registered + _register_metax_attention_backends() + + # Remove block_size from config for validation (as in platform.py) + attn_selector_config = attn_selector_config._replace(block_size=None) + + device_capability = cls.get_device_capability() + assert device_capability is not None + + # If a specific backend is selected, validate and use it + if selected_backend is not None: + try: + backend_class = selected_backend.get_class() + invalid_reasons = backend_class.validate_configuration( + device_capability=device_capability, + **attn_selector_config._asdict(), + ) + except ImportError as e: + invalid_reasons = [f"ImportError: {e}"] + + if invalid_reasons: + raise ValueError( + f"Selected backend {selected_backend} is not valid for " + f"this configuration. 
Reason: {invalid_reasons}" + ) + else: + logger.info("Using %s backend.", selected_backend) + return selected_backend.get_path() + + # No selected backend, find the best valid one + valid_backends_priorities = [] + invalid_reasons = {} + + backend_priorities = _get_metax_backend_priorities(use_mla, use_sparse) + + for priority, backend in enumerate(backend_priorities): + try: + backend_class = backend.get_class() + invalid_reasons_i = backend_class.validate_configuration( + device_capability=device_capability, + **attn_selector_config._asdict(), + ) + except ImportError as e: + invalid_reasons_i = [f"ImportError: {e}"] + + if invalid_reasons_i: + invalid_reasons[backend] = invalid_reasons_i + else: + valid_backends_priorities.append((backend, priority)) + + # Log invalid reasons + if invalid_reasons: + reasons_str = ( + "{" + + ", ".join( + f"{backend.name}: [{', '.join(reasons)}]" + for backend, reasons in invalid_reasons.items() + ) + + "}" + ) + config_str = attn_selector_config.__repr__() + logger.info_once( + f"Some attention backends are not valid for {cls.device_name} with " + f"{config_str}. Reasons: {reasons_str}." + ) + + if len(valid_backends_priorities) == 0: + reasons_str = ( + "{" + + ", ".join( + f"{backend.name}: [{', '.join(reasons)}]" + for backend, reasons in invalid_reasons.items() + ) + + "}" + ) + config_str = attn_selector_config.__repr__() + raise ValueError( + f"No valid attention backend found for {cls.device_name} " + f"with {config_str}. Reasons: {reasons_str}." 
+ ) + + # Select the highest priority valid backend + logger.info( + "Valid backends: %s", [b[0].name for b in valid_backends_priorities] + ) + sorted_indices = sorted( + range(len(valid_backends_priorities)), + key=lambda i: valid_backends_priorities[i][1], + ) + selected_index = sorted_indices[0] + selected_backend = valid_backends_priorities[selected_index][0] + logger.info_once( + "Using %s attention backend out of potential backends: %s", + selected_backend.name, + tuple(b[0].name for b in valid_backends_priorities), + scope="local", + ) + + return selected_backend.get_path() + + @classmethod + def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: + # MetaX can use the same backends as CUDA + if cls.vendor_name == "metax": + return [ + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + ] + return [ + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.FLASH_ATTN, + ] + + @classmethod + def get_vit_attn_backend( + cls, + head_size: int, + dtype: torch.dtype, + backend: Optional["AttentionBackendEnum"] = None, + ) -> list[str]: + # For MetaX, use FLASH_ATTN if available + if cls.vendor_name == "metax": + from vllm_fl.attention.fl_utils import patch_mm_encoder_attention + patch_mm_encoder_attention() + + if backend is not None: + assert backend in cls.get_supported_vit_attn_backends(), ( + f"Backend {backend} is not supported for vit attention. " + f"Supported backends are: {cls.get_supported_vit_attn_backends()}" + ) + logger.info_once(f"Using backend {backend} for vit attention") + return backend + + # Try FlashAttention first + try: + backend_class = AttentionBackendEnum.FLASH_ATTN.get_class() + if backend_class.supports_head_size(head_size) and backend_class.supports_dtype(dtype): + return AttentionBackendEnum.FLASH_ATTN + except ImportError: + pass + + logger.error( + "Fallback to Backend TORCH_SDPA as vit_attn_backend since head_size or dtype is " + "not supported on FLASH_ATTN." 
+ ) + return AttentionBackendEnum.TORCH_SDPA + + # Default behavior for other vendors + from vllm_fl.attention.fl_utils import patch_mm_encoder_attention + + patch_mm_encoder_attention() + if backend is not None: + assert backend in cls.get_supported_vit_attn_backends(), ( + f"Backend {backend} is not supported for vit attention. " + f"Supported backends are: {cls.get_supported_vit_attn_backends()}" + ) + logger.info_once(f"Using backend {backend} for vit attention") + return backend + + # Try FlashAttention first + if (cc := cls.get_device_capability()) and cc.major >= 8: + try: + backend_class = AttentionBackendEnum.FLASH_ATTN.get_class() + if backend_class.supports_head_size( + head_size + ) and backend_class.supports_dtype(dtype): + return AttentionBackendEnum.FLASH_ATTN + except ImportError: + pass + + return AttentionBackendEnum.TORCH_SDPA + + @classmethod + def get_punica_wrapper(cls) -> str: + # TODO(lms): support fl PunicaWrapper + return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU" + + @classmethod + def get_device_communicator_cls(cls) -> str: + if cls.dist_backend == "flagcx": + logger.info("Using CommunicatorFL for communication.") + return "vllm_fl.distributed.communicator.CommunicatorFL" # noqa + else: + logger.info("Using CudaCommunicator for communication.") + return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + + @classmethod + def get_static_graph_wrapper_cls(cls) -> str: + return "vllm_fl.compilation.graph.GraphWrapper" + + @classmethod + def support_static_graph_mode(cls) -> bool: + if cls.vendor_name in ["nvidia", "ascend", "metax"]: + return True + return False + + @classmethod + def insert_blocks_to_device( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + """Copy blocks from src_cache to dst_cache device .""" + _src_cache = src_cache[:, src_block_indices] + dst_cache[:, dst_block_indices] = 
_src_cache.to(dst_cache.device) + + @classmethod + def swap_out_blocks_to_host( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + """Copy blocks from device to host (CPU).""" + _src_cache = src_cache[:, src_block_indices] + dst_cache[:, dst_block_indices] = _src_cache.cpu() + + @classmethod + def support_hybrid_kv_cache(cls) -> bool: + return True + + ### NOTE(lms): will effect compile result + @classmethod + def opaque_attention_op(cls) -> bool: + return True + + @classmethod + def use_custom_allreduce(cls) -> bool: + if cls.dist_backend == "flagcx": + return False + return True + + @classmethod + def pre_register_and_update(cls, parser = None) -> None: + if cls.device_name == "npu": + import vllm_fl.dispatch.backends.vendor.ascend + + # Register MetaX backends during pre-registration + if cls.vendor_name == "metax": + _register_metax_attention_backends() + + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + # TODO(yxa): For NPU/Ascend devices, return None (no capability version like CUDA) + if cls.device_type == "npu": + return None + # For CUDA devices + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + def is_fully_connected(cls, physical_device_ids: list[int]) -> bool: + try: + import pynvml + + pynvml.nvmlInit() + """ + query if the set of gpus are fully connected by nvlink (1 hop) + """ + handles = [ + pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids + ] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + p2p_status = pynvml.nvmlDeviceGetP2PStatus( + handle, + peer_handle, + pynvml.NVML_P2P_CAPS_INDEX_NVLINK, + ) + if p2p_status != pynvml.NVML_P2P_STATUS_OK: + return False + except pynvml.NVMLError: + logger.exception( + "NVLink detection failed. 
This is normal if" + " your machine has no NVLink equipped." + ) + return False + return True + except: + return False \ No newline at end of file diff --git a/vllm_fl/models/deepseek_v2.py b/vllm_fl/models/deepseek_v2.py new file mode 100644 index 00000000..a2dd26c2 --- /dev/null +++ b/vllm_fl/models/deepseek_v2.py @@ -0,0 +1,1714 @@ +# SPDX-License-Identifier: Apache-2.0 +# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only DeepseekV2/DeepseekV3 model.""" + +import typing +from collections.abc import Callable, Iterable +from itertools import islice + +import torch +from torch import nn +from transformers import DeepseekV2Config, DeepseekV3Config + +from vllm._aiter_ops import rocm_aiter_ops +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.layer import Attention +from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, +) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from 
vllm.model_executor.models.utils import sequence_parallel_chunk +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm_fl.mx_utils.deep_gemm import bf16_mqa_logits, bf16_paged_mqa_logits +from vllm.utils.torch_utils import direct_register_custom_op +from vllm_fl.v1.attention.backends.mla.indexer import ( + MacaDeepseekV32IndexerBackend as DeepseekV32IndexerBackend, + DeepseekV32IndexerMetadata, +) +from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec +from vllm.v1.worker.workspace import current_workspace_manager + +from vllm.model_executor.models.interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP +from vllm.model_executor.models.utils import ( + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + +if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops + from vllm_fl import _custom_ops as mx_ops +elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops as ops + +logger = init_logger(__name__) + + +class DeepseekAttention(nn.Module): + """Normal MHA implementation used by Deepseek v1.""" + + def __init__( + self, + vllm_config: VllmConfig, + config: DeepseekV2Config | DeepseekV3Config, + hidden_size: int, + num_heads: int, + max_position_embeddings: int = 8192, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + **kwargs, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + max_position=max_position_embeddings, + rope_parameters=config.rope_parameters, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class DeepseekV2MLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: QuantizationConfig | None = None, + reduce_results: bool = True, + is_sequence_parallel=False, + prefix: str = "", + ) -> None: + super().__init__() + + # If is_sequence_parallel, the input and output tensors are sharded + # across the ranks within the tp_group. 
In this case the weights are + # replicated and no collective ops are needed. + # Otherwise we use standard TP with an allreduce at the end. + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + disable_tp=is_sequence_parallel, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + disable_tp=is_sequence_parallel, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DeepseekV2MoE(nn.Module): + def __init__( + self, + config: DeepseekV2Config | DeepseekV3Config, + parallel_config: ParallelConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) + + self.ep_group = get_ep_group().device_group + self.ep_rank = get_ep_group().rank_in_group + self.ep_size = self.ep_group.size() + self.n_routed_experts: int = config.n_routed_experts + self.n_shared_experts: int = config.n_shared_experts + + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + + if config.hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." 
+ ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + if getattr(config, "topk_method", None) == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts, dtype=torch.float32) + ) + else: + self.gate.e_score_correction_bias = None + + # Load balancing settings. + eplb_config = parallel_config.eplb_config + self.enable_eplb = parallel_config.enable_eplb + + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_logical_experts = self.n_routed_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + self.is_fusion_moe_shared_experts_enabled = ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + ) + if config.n_shared_experts is None or self.is_fusion_moe_shared_experts_enabled: + self.shared_experts = None + else: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + + self.shared_experts = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + is_sequence_parallel=self.is_sequence_parallel, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + gate=self.gate, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + 
num_expert_group=getattr(config, "n_group", 1), + topk_group=getattr(config, "topk_group", 1), + prefix=f"{prefix}.experts", + scoring_func=getattr(config, "scoring_func", "softmax"), + # we do scaling outside, set factor to 1.0 to avoid double mul + # aiter applies routed_scaling_factor internally + routed_scaling_factor=1.0 + if not self.is_rocm_aiter_moe_enabled + else self.routed_scaling_factor, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + n_shared_experts=config.n_shared_experts + if self.is_fusion_moe_shared_experts_enabled + else None, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + # Chunk the hidden states so they aren't replicated across TP ranks. + # This avoids duplicate computation in self.experts. + # TODO: We can replace the all_reduce at the end of attn with a + # reduce_scatter instead of chunking here. + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + + if self.experts.is_internal_router: + # In this case, the gate/router runs inside the FusedMoE class + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=hidden_states + ) + else: + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + shared_output, final_hidden_states = fused_moe_out + if self.shared_experts is None: + assert shared_output is None + + # Fix FP16 overflow + # See DeepseekV2DecoderLayer for more details. 
+        if hidden_states.dtype != torch.float16:
+            # Apply routed scaling here (aiter applies it internally, so skip).
+            if not self.is_rocm_aiter_moe_enabled:
+                final_hidden_states *= self.routed_scaling_factor
+        elif self.shared_experts is not None:
+            # FP16: divide the shared-expert output instead of multiplying the
+            # routed output, to avoid FP16 overflow.
+            assert shared_output is not None
+            shared_output *= 1.0 / self.routed_scaling_factor
+
+        if self.shared_experts is not None:
+            assert shared_output is not None
+            final_hidden_states += shared_output
+
+        if self.is_sequence_parallel:
+            # Re-gather the sequence-parallel chunks and drop any padding.
+            final_hidden_states = tensor_model_parallel_all_gather(
+                final_hidden_states, 0
+            )
+            final_hidden_states = final_hidden_states[:num_tokens]
+        elif self.tp_size > 1:
+            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
+                final_hidden_states
+            )
+
+        return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
+    """Return the YaRN attention magnitude scale: 0.1 * mscale * ln(scale) + 1.
+
+    Scales <= 1 (no context extension) yield 1.0.
+    """
+    import math
+
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
+
+
+def _get_llama_4_scaling(
+    original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor
+) -> torch.Tensor:
+    """Compute per-position llama-4-style attention scaling:
+    1 + beta * ln(1 + floor(pos / original_max_position_embeddings)).
+    """
+    scaling = 1 + scaling_beta * torch.log(
+        1 + torch.floor(positions / original_max_position_embeddings)
+    )
+    # Broadcast over num_heads and head_dim
+    return scaling[..., None, None]
+
+
+class DeepseekV2Attention(nn.Module):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        config: DeepseekV2Config | DeepseekV3Config,
+        hidden_size: int,
+        num_heads: int,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        v_head_dim: int,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        max_position_embeddings: int = 8192,
+        cache_config: CacheConfig | None = None,
+        quant_config: QuantizationConfig | None = None,
+        topk_indices_buffer: torch.Tensor | None = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.q_lora_rank = q_lora_rank
+
self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + self.scaling = self.qk_head_dim**-0.5 + self.max_position_embeddings = max_position_embeddings + assert topk_indices_buffer is None, ( + "topk_indices_buffer is not \ + supported for DeepseekV2Attention" + ) + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj", + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa", + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj", + ) + # O projection. 
+ self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + if config.rope_parameters["rope_type"] != "default": + config.rope_parameters["rope_type"] = ( + "deepseek_yarn" + if config.rope_parameters.get("apply_yarn_scaling", True) + else "deepseek_llama_scaling" + ) + + self.rotary_emb = get_rope( + qk_rope_head_dim, + max_position=max_position_embeddings, + rope_parameters=config.rope_parameters, + is_neox_style=False, + ) + + if ( + config.rope_parameters["rope_type"] != "default" + and config.rope_parameters["rope_type"] == "deepseek_yarn" + ): + mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False) + scaling_factor = config.rope_parameters["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.attn = Attention( + self.num_local_heads, + self.qk_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + llama_4_scaling: torch.Tensor | None, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view( + -1, self.num_local_heads, self.qk_head_dim + ) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_cache = latent_cache.unsqueeze(1) + kv_a = self.kv_a_layernorm(kv_a) + kv = self.kv_b_proj(kv_a)[0] + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv.split([self.qk_nope_head_dim, 
self.v_head_dim], dim=-1) + k_pe = latent_cache[:, :, self.kv_lora_rank :] + + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + + q[..., self.qk_nope_head_dim :] = q_pe + k = torch.empty_like(q) + k[..., : self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim :] = k_pe + + # Apply llama 4 scaling if provided + if llama_4_scaling is not None: + q *= llama_4_scaling + + # padding value to qk_head_dim for alignment + v = torch.nn.functional.pad( + v, [0, self.qk_head_dim - self.v_head_dim], value=0 + ).view(-1, self.num_local_heads * self.qk_head_dim) + attn_output = self.attn(q, k, v) + attn_output = attn_output.view(-1, self.num_local_heads, self.qk_head_dim)[ + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) + output, _ = self.o_proj(attn_output) + return output + + +class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase): + def __init__( + self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig + ): + super().__init__() + self.kv_cache = [torch.tensor([])] + self.head_dim = head_dim + self.prefix = prefix + self.cache_config = cache_config + self.dtype = dtype + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + return MLAAttentionSpec( # Only has one vector instead of K + V + block_size=self.cache_config.block_size, + num_kv_heads=1, + head_size=self.head_dim, + dtype=self.dtype, + ) + + def forward(self): ... 
+ + def get_attn_backend(self) -> AttentionBackend: + return DeepseekV32IndexerBackend + + +def sparse_attn_indexer( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_bf16: torch.Tensor, + k_bf16: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor | None, +) -> torch.Tensor: + # careful! this will be None in dummy run + attn_metadata = get_forward_context().attn_metadata + # assert isinstance(attn_metadata, dict) + if not isinstance(attn_metadata, dict): + # Reserve workspace for indexer during profiling run + current_workspace_manager().get_simultaneous( + ((total_seq_lens, head_dim), torch.bfloat16), + ) + + return sparse_attn_indexer_fake( + hidden_states, + k_cache_prefix, + kv_cache, + q_bf16, + k_bf16, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + head_dim, + max_model_len, + total_seq_lens, + topk_indices_buffer, + ) + attn_metadata = attn_metadata[k_cache_prefix] + assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) + slot_mapping = attn_metadata.slot_mapping + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + mx_ops.indexer_k_quant_and_cache( + k_bf16, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, + ) + + topk_indices_buffer[: hidden_states.shape[0]] = -1 + if has_prefill: + prefill_metadata = attn_metadata.prefill + + # Get the full shared workspace buffers once (will allocate on first use) + workspace_manager = current_workspace_manager() + k_bf16_full = workspace_manager.get_simultaneous( + ((total_seq_lens, head_dim), torch.bfloat16), + )[0] + + for chunk in prefill_metadata.chunks: + _k_bf16 = k_bf16_full[: chunk.total_seq_lens] + k_scale = None + mx_ops.cp_gather_indexer_k_quant_cache( + kv_cache, + _k_bf16, + k_scale, + 
chunk.block_table, + chunk.cu_seq_lens, + ) + logits = bf16_mqa_logits( + q_bf16[chunk.token_start : chunk.token_end], + _k_bf16, + weights[chunk.token_start : chunk.token_end], + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + ) + num_rows = logits.shape[0] + topk_indices = topk_indices_buffer[ + chunk.token_start : chunk.token_end, :topk_tokens + ] + mx_ops.top_k_per_row( + logits, + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + topk_indices, + num_rows, + ) + + if has_decode: + decode_metadata = attn_metadata.decode + # kv_cache size requirement [num_block, block_size, n_head, head_dim], + # we only have [num_block, block_size, head_dim], + kv_cache = kv_cache.unsqueeze(-2) + decode_lens = decode_metadata.decode_lens + if decode_metadata.requires_padding: + # pad in edge case where we have short chunked prefill length < + # decode_threshold since we unstrictly split + # prefill and decode by decode_threshold + # (currently set to 1 + speculative tokens) + padded_q_bf16_decode_tokens = pack_seq_triton( + q_bf16[:num_decode_tokens], decode_lens + ) + else: + padded_q_bf16_decode_tokens = q_bf16[:num_decode_tokens].reshape( + decode_lens.shape[0], -1, *q_bf16.shape[1:] + ) + # TODO: move and optimize below logic with triton kernels + batch_size = padded_q_bf16_decode_tokens.shape[0] + next_n = padded_q_bf16_decode_tokens.shape[1] + assert batch_size == decode_metadata.seq_lens.shape[0] + num_padded_tokens = batch_size * next_n + logits = bf16_paged_mqa_logits( + padded_q_bf16_decode_tokens, + kv_cache, + weights[:num_padded_tokens], + decode_metadata.seq_lens, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len, + ) + num_rows = logits.shape[0] + topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] + + mx_ops.top_k_per_row_decode( + logits, + next_n, + decode_metadata.seq_lens, + topk_indices, + num_rows, + ) + if decode_metadata.requires_padding: + # if padded, we need to unpack + # the topk indices removing padded tokens 
+ topk_indices = unpack_seq_triton( + topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), + decode_lens, + ) + topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( + topk_indices + ) + + return topk_indices_buffer + + +def sparse_attn_indexer_fake( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor | None, +) -> torch.Tensor: + return topk_indices_buffer + + +direct_register_custom_op( + op_name="mx_sparse_attn_indexer", + op_func=sparse_attn_indexer, + mutates_args=["topk_indices_buffer"], + fake_impl=sparse_attn_indexer_fake, + dispatch_key=current_platform.dispatch_key, +) + + +class Indexer(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + config: DeepseekV2Config | DeepseekV3Config, + hidden_size: int, + q_lora_rank: int, + quant_config: QuantizationConfig | None, + cache_config: CacheConfig | None, + topk_indices_buffer: torch.Tensor | None, + prefix: str = "", + ): + super().__init__() + self.vllm_config = vllm_config + self.config = config + # self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"] + self.topk_tokens = config.index_topk + self.n_head = config.index_n_heads # 64 + self.head_dim = config.index_head_dim # 128 + self.rope_dim = config.qk_rope_head_dim # 64 + self.q_lora_rank = q_lora_rank # 1536 + # no tensor parallel, just replicated + self.wq_b = ReplicatedLinear( + self.q_lora_rank, + self.head_dim * self.n_head, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wq_b", + ) + self.wk = ReplicatedLinear( + hidden_size, + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wk", + ) + self.k_norm = LayerNorm(self.head_dim, eps=1e-6) + self.weights_proj = ReplicatedLinear( + hidden_size, self.n_head, 
bias=False, quant_config=None, prefix=f"{prefix}.weights_proj" + ) + self.softmax_scale = self.head_dim**-0.5 + + # MetaX use bfloat16 + self.scale_fmt = "ue8m0" + self.quant_block_size = 128 # TODO: get from config + self.topk_indices_buffer = topk_indices_buffer + + # NOTE: (zyongye) we use fp8 naive cache, + # where we store value in fp8 and scale in fp32 + # per self.quant_block_size element + self.k_cache = DeepseekV32IndexerCache( + head_dim=self.head_dim, + dtype=torch.bfloat16, + prefix=f"{prefix}.k_cache", + cache_config=cache_config, + ) + self.max_model_len = vllm_config.model_config.max_model_len + self.prefix = prefix + from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size + + self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config) + + def forward( + self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb + ) -> torch.Tensor: + q, _ = self.wq_b(qr) + q = q.view(-1, self.n_head, self.head_dim) + q_pe, q_nope = torch.split( + q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 + ) + + k, _ = self.wk(hidden_states) + k = self.k_norm(k) + k_pe, k_nope = torch.split( + k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 + ) + + q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) + + + k_nope = k_nope.unsqueeze(1).expand(-1, self.n_head, -1) # [2048, 32, 64] + + k_pe = k_pe.expand(-1, self.n_head, -1) # [2048, 32, 64] + + q = torch.cat([q_pe, q_nope], dim=-1) # [2048, 32, 128] + k = torch.cat([k_pe, k_nope], dim=-1) # [2048, 32, 128] + + + # we only quant q here since k quant is fused with cache insertion + # q = q.view(-1, self.head_dim) + # q_fp8, q_scale = per_token_group_quant_fp8( + # q, + # self.quant_block_size, + # column_major_scales=False, + # use_ue8m0=self.scale_fmt is not None, + # ) + # q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim) + # q_scale = q_scale.view(-1, self.n_head, 1) + + q = q.view(-1, self.n_head, self.head_dim) + weights, _ = 
self.weights_proj(hidden_states) + weights = ( + weights.unsqueeze(-1) * self.softmax_scale * self.n_head**-0.5 + ) + weights = weights.squeeze(-1) + + assert q.dtype == torch.bfloat16 + assert k.dtype == torch.bfloat16 + return torch.ops.vllm.mx_sparse_attn_indexer( + hidden_states, + self.k_cache.prefix, + self.k_cache.kv_cache[0], + q, + k, + weights, + self.quant_block_size, + self.scale_fmt, + self.topk_tokens, + self.head_dim, + self.max_model_len, + self.max_total_seq_len, + self.topk_indices_buffer, + ) + + +class DeepseekV2MLAAttention(nn.Module): + """ + Main reference: DeepseekV2 paper, and FlashInfer Implementation + (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). + + For more info see MLACommonImpl in: + vllm/v1/attention/backends/mla/utils.py + """ + + def __init__( + self, + vllm_config: VllmConfig, + config: DeepseekV2Config | DeepseekV3Config, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int | None, + kv_lora_rank: int, + max_position_embeddings: int = 8192, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + topk_indices_buffer: torch.Tensor | None = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.fused_qkv_a_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.q_lora_rank, self.kv_lora_rank + 
self.qk_rope_head_dim], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.fused_qkv_a_proj", + disable_tp=True, + ) + else: + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa", + ) + + if self.q_lora_rank is not None: + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + self.q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj", + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + if config.rope_parameters["rope_type"] != "default": + config.rope_parameters["rope_type"] = ( + "deepseek_yarn" + if config.rope_parameters.get("apply_yarn_scaling", True) + else "deepseek_llama_scaling" + ) + + self.rotary_emb = get_rope( + qk_rope_head_dim, + max_position=max_position_embeddings, + rope_parameters=config.rope_parameters, + is_neox_style=False, + ) + + if ( + config.rope_parameters["rope_type"] != "default" + and config.rope_parameters["rope_type"] == "deepseek_yarn" + ): + mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False) + scaling_factor = config.rope_parameters["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.is_v32 
= hasattr(config, "index_topk") + + if self.is_v32: + self.indexer_rope_emb = get_rope( + qk_rope_head_dim, + max_position=max_position_embeddings, + rope_parameters=config.rope_parameters, + is_neox_style=True, + ) + self.indexer = Indexer( + vllm_config, + config, + hidden_size, + q_lora_rank, + quant_config, + cache_config, + topk_indices_buffer, + f"{prefix}.indexer", + ) + else: + self.indexer_rope_emb = None + self.indexer = None + + mla_modules = MLAModules( + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + rotary_emb=self.rotary_emb, + o_proj=self.o_proj, + fused_qkv_a_proj=self.fused_qkv_a_proj + if self.q_lora_rank is not None + else None, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa + if self.q_lora_rank is None + else None, + q_a_layernorm=self.q_a_layernorm if self.q_lora_rank is not None else None, + q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, + q_proj=self.q_proj if self.q_lora_rank is None else None, + indexer=self.indexer, + indexer_rotary_emb=self.indexer_rope_emb, + is_sparse=self.is_v32, + topk_indices_buffer=topk_indices_buffer, + ) + + self.mla_attn = MultiHeadLatentAttentionWrapper( + self.hidden_size, + self.num_local_heads, + self.scaling, + self.qk_nope_head_dim, + self.qk_rope_head_dim, + self.v_head_dim, + self.q_lora_rank, + self.kv_lora_rank, + mla_modules, + cache_config, + quant_config, + prefix, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + llama_4_scaling: torch.Tensor | None, + ) -> torch.Tensor: + return self.mla_attn(positions, hidden_states, llama_4_scaling) + + +class DeepseekV2DecoderLayer(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + prefix: str, + config: DeepseekV2Config | None = None, + topk_indices_buffer: torch.Tensor | None = None, + ) -> None: + super().__init__() + + if config is None: + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + 
quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + + self.hidden_size = config.hidden_size + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + moe_layer_freq = getattr(config, "moe_layer_freq", 1) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. + layer_idx = int(prefix.split(sep=".")[-1]) + self.layer_idx = layer_idx + + # verify MLA attention specific fields + qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0) + qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0) + v_head_dim = getattr(config, "v_head_dim", 0) + kv_lora_rank = getattr(config, "kv_lora_rank", 0) + use_mha = config.model_type == "deepseek" or all( + dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim) + ) + + self.use_mha = use_mha + + if use_mha: + attn_cls = DeepseekAttention + elif model_config.use_mla: + attn_cls = DeepseekV2MLAAttention + else: + attn_cls = DeepseekV2Attention + self.self_attn = attn_cls( + vllm_config=vllm_config, + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + v_head_dim=v_head_dim, + q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None, + kv_lora_rank=kv_lora_rank, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + topk_indices_buffer=topk_indices_buffer, + ) + + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % moe_layer_freq == 0 + ): + self.mlp = DeepseekV2MoE( + config=config, + parallel_config=parallel_config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + 
prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + llama_4_scaling: torch.Tensor | None = None, + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states.clone() + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + attn_kwargs = { + "positions": positions, + "hidden_states": hidden_states, + } + if not self.use_mha: + attn_kwargs["llama_4_scaling"] = llama_4_scaling + hidden_states = self.self_attn(**attn_kwargs) + + if ( + not isinstance(self.self_attn, DeepseekAttention) + and hidden_states.dtype == torch.float16 + ): + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1.0 / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1.0 / self.routed_scaling_factor + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. 
+ # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states *= 1.0 / self.routed_scaling_factor + + return hidden_states, residual + + +@support_torch_compile +class DeepseekV2Model(nn.Module): + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.device = current_platform.device_type + + # ----------------------------------------- + # Note: torch compile needs to eval the + # `getattr` ahead of `forward` + self._llama_4_scaling_config = getattr(config, "llama_4_scaling", None) + + self.vocab_size = config.vocab_size + self.is_v32 = hasattr(config, "index_topk") + if self.is_v32: + topk_tokens = config.index_topk + topk_indices_buffer = torch.empty( + vllm_config.scheduler_config.max_num_batched_tokens, + topk_tokens, + dtype=torch.int32, + device=self.device, + ) + else: + topk_indices_buffer = None + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: DeepseekV2DecoderLayer( + vllm_config, prefix, topk_indices_buffer=topk_indices_buffer + ), + prefix=f"{prefix}.layers", + ) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + 
intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_input_ids(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + # Compute llama 4 scaling once per forward pass if enabled + # ----------------------------------------------------- + # Note: it needs to be an explicit variable to make torch + # compile happy, do not make it like: + # + # llama_4_scaling_config = getattr(config, "llama_4_scaling", None) + # + # torch compile might omit the default value of + # `None` at eval-time. + llama_4_scaling_config = self._llama_4_scaling_config + llama_4_scaling: torch.Tensor | None + if llama_4_scaling_config is not None: + llama_4_scaling = _get_llama_4_scaling( + original_max_position_embeddings=llama_4_scaling_config[ + "original_max_position_embeddings" + ], + scaling_beta=llama_4_scaling_config["beta"], + positions=positions, + ) + else: + llama_4_scaling = None + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer( + positions, hidden_states, residual, llama_4_scaling + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class DeepseekV2MixtureOfExperts(MixtureOfExperts): + moe_mlp_layers: list[DeepseekV2MoE] + """ + List of MoE MLP layers in the model. 
+ """ + + def extract_moe_parameters(self, example_moe: DeepseekV2MoE | None): + if example_moe is None: + self.num_moe_layers = 0 + self.num_expert_groups = 0 + self.num_logical_experts = 0 + self.num_physical_experts = 0 + self.num_local_physical_experts = 0 + self.num_routed_experts = 0 + self.num_shared_experts = 0 + self.num_redundant_experts = 0 + logger.warning("DeepSeekV2: No DeepseekV2MoE layer found in model.layers.") + else: + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for moe in self.moe_mlp_layers: + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + +class DeepseekV2ForCausalLM( + nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle +): + packed_modules_mapping = { + "gate_up_proj": ["gate_proj", "up_proj"], + } + model_cls = DeepseekV2Model + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0) + qk_rope_head_dim = getattr(config, 
"qk_rope_head_dim", 0) + self.use_mha = config.model_type == "deepseek" or all( + dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim) + ) + + if self.use_mha: + self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"] + + # `packed_modules_mapping` needs to be modified before + # initializing DeepseekV2Model, as it is passed inplace to + # quantization config init and may be used to select the + # quant_method for relevant layers during initialization. + self.fuse_qkv_a_proj = ( + hasattr(config, "q_lora_rank") and config.q_lora_rank is not None + ) + if self.fuse_qkv_a_proj: + self.packed_modules_mapping["fused_qkv_a_proj"] = [ + "q_a_proj", + "kv_a_proj_with_mqa", + ] + + self.model = self.model_cls( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + # Set MoE hyperparameters + self.num_moe_layers = ( + self.config.num_hidden_layers - self.config.first_k_dense_replace + ) + self.set_moe_parameters() + + def set_moe_parameters(self): + self.expert_weights = [] + + self.num_expert_groups = getattr(self.config, "n_group", 1) + + self.moe_layers = [] + self.moe_mlp_layers = [] + example_moe = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, DeepseekV2DecoderLayer) + if isinstance(layer.mlp, DeepseekV2MoE): + # Pick last one layer since the first ones may be dense layers. 
+ example_moe = layer.mlp + self.moe_mlp_layers.append(layer.mlp) + self.moe_layers.append(layer.mlp.experts) + + self.extract_moe_parameters(example_moe) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return SharedFusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts, + num_redundant_experts=0, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + rocm_aiter_moe_shared_expert_enabled = ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + ) + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + mla_params_mapping = [ + ("fused_qkv_a_proj", "q_a_proj", 0), + ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), + ] + mha_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + if self.use_mha: + stacked_params_mapping.extend(mha_params_mapping) + else: + stacked_params_mapping.extend(mla_params_mapping) + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, 
expert_id, shard_id) + expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts + + ( + self.config.n_shared_experts + if rocm_aiter_moe_shared_expert_enabled + else 0 + ), + num_redundant_experts=self.num_redundant_experts, + ) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model + + is_fusion_moe_shared_experts_layer = ( + rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name) + ) + + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + if is_fusion_moe_shared_experts_layer: + continue + name_mapped = name.replace(weight_name, param_name) + + # QKV fusion is optional, fall back to normal + # weight loading if it's not enabled + # if go with fusion option, then update name + if ( + param_name == "fused_qkv_a_proj" + ) and name_mapped not in params_dict: + continue + else: + name = name_mapped + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False + + # Special handling: when AITER fusion_shared_experts is enabled, + # checkpoints may provide a single widened shared_experts tensor + # without explicit expert indices + # (e.g. ...mlp.shared_experts.gate_proj.weight). + # For models with multiple shared experts, split that tensor + # evenly into per-shared-expert slices and load them into + # appended expert slots mlp.experts.{n_routed_experts + j}.* + # accordingly. + num_chunks = 1 + if is_fusion_moe_shared_experts_layer: + num_chunks = getattr(self.config, "n_shared_experts", 1) or 1 + # Determine split axis based on op type + # gate/up: ColumnParallel → split along dim 0 + # down: RowParallel → split along dim 1 + split_dim = 1 if "down_proj.weight" in name else 0 + total = loaded_weight.shape[split_dim] + assert total % num_chunks == 0, ( + f"Shared expert weight dim {total} " + f"not divisible by num_chunks {num_chunks}" + ) + chunk_size = total // num_chunks + + for j in range(num_chunks): + chunk_name = name + weight_to_load = loaded_weight + + if is_fusion_moe_shared_experts_layer: + if split_dim == 0: + weight_to_load = loaded_weight[ + j * chunk_size : (j + 1) * chunk_size, : + ] + else: + weight_to_load = loaded_weight[ + :, j * chunk_size : (j + 1) * chunk_size + ] + # Synthesize an expert-style name so expert mapping + # can route it + chunk_name = name.replace( + "mlp.shared_experts", + f"mlp.experts.{self.config.n_routed_experts + j}", + ) + + # Use expert_params_mapping to locate the destination + # param and delegate to its expert-aware weight_loader + # with expert_id. 
+ for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in chunk_name: + continue + + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = chunk_name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + + param = params_dict[name_mapped] + # We should ask the weight loader to return success or + # not here since otherwise we may skip experts with + # other available replicas. + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + weight_to_load, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + if not is_fusion_moe_shared_experts_layer: + name = name_mapped + else: + loaded_params.add(name_mapped) + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + if not is_fusion_moe_shared_experts_layer: + loaded_params.add(name) + + return loaded_params + + +class DeepseekForCausalLM(DeepseekV2ForCausalLM): + pass + + +class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): + pass + + +# Compatibility with +# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py +def get_spec_layer_idx_from_weight_name( + config: DeepseekV2Config | DeepseekV3Config, weight_name: str +) -> int | None: + if ( + hasattr(config, "num_nextn_predict_layers") + and config.num_nextn_predict_layers > 0 + ): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if weight_name.startswith(f"model.layers.{layer_idx + i}."): + return layer_idx + i + return None \ No newline at end of file diff --git a/vllm_fl/models/glm_moe_dsa.py b/vllm_fl/models/glm_moe_dsa.py index 4b88c4d0..f9788b4d 100644 --- a/vllm_fl/models/glm_moe_dsa.py +++ b/vllm_fl/models/glm_moe_dsa.py @@ -9,7 +9,7 @@ handles MLA, MoE, the DSA Indexer, and MTP speculative decoding layers. """ -from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM +from vllm_fl.models.deepseek_v2 import DeepseekV2ForCausalLM class GlmMoeDsaForCausalLM(DeepseekV2ForCausalLM): diff --git a/vllm_fl/mx_utils/__init__.py b/vllm_fl/mx_utils/__init__.py new file mode 100644 index 00000000..e961dabd --- /dev/null +++ b/vllm_fl/mx_utils/__init__.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2024 The vLLM team. 
+ + +def import_pymxsml(): + """ + Historical comments: + + libnvml.so is the library behind nvidia-smi, and + pynvml is a Python wrapper around it. We use it to get GPU + status without initializing CUDA context in the current process. + Historically, there are two packages that provide pynvml: + - `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official + wrapper. It is a dependency of vLLM, and is installed when users + install vLLM. It provides a Python module named `pynvml`. + - `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper. + Prior to version 12.0, it also provides a Python module `pynvml`, + and therefore conflicts with the official one. What's worse, + the module is a Python package, and has higher priority than + the official one which is a standalone Python file. + This causes errors when both of them are installed. + Starting from version 12.0, it migrates to a new module + named `pynvml_utils` to avoid the conflict. + It is so confusing that many packages in the community use the + unofficial one by mistake, and we have to handle this case. + For example, `nvcr.io/nvidia/pytorch:24.12-py3` uses the unofficial + one, and it will cause errors, see the issue + https://github.com/vllm-project/vllm/issues/12847 for example. + After all the troubles, we decide to copy the official `pynvml` + module to our codebase, and use it directly. + """ + import vllm_metax.third_party.pymxsml as pymxsml + + return pymxsml + + +def vllm_version(): + import vllm + + return vllm.__version__ diff --git a/vllm_fl/mx_utils/deep_gemm.py b/vllm_fl/mx_utils/deep_gemm.py new file mode 100644 index 00000000..b1b7f1b7 --- /dev/null +++ b/vllm_fl/mx_utils/deep_gemm.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: Apache-2.0 +# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Compatibility wrapper for DeepGEMM API changes. 
+ +Users of vLLM should always import **only** these wrappers. +""" + +from __future__ import annotations + +import functools +import importlib +import os +from typing import Any, Callable, NoReturn + +import torch + +import vllm.envs as envs +from vllm.utils.import_utils import has_deep_gemm + + +@functools.cache +def is_deep_gemm_supported() -> bool: + """Return `True` if DeepGEMM is supported on the current platform. + Currently, only Hopper and Blackwell GPUs are supported. + """ + return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() + + +def _missing(*_: Any, **__: Any) -> NoReturn: + """Placeholder for unavailable DeepGEMM backend.""" + raise RuntimeError( + "DeepGEMM backend is not available. Please install the `deep_gemm` " + "package to enable BF16 kernels." + ) + + +_bf16_mqa_logits_impl: Callable[..., Any] | None = None +_bf16_paged_mqa_logits_impl: Callable[..., Any] | None = None + + +def _lazy_init() -> None: + """Import deep_gemm and resolve symbols on first use.""" + global _bf16_mqa_logits_impl, _bf16_paged_mqa_logits_impl + + if not has_deep_gemm(): + return + + # Set up deep_gemm cache path + DEEP_GEMM_JIT_CACHE_ENV_NAME = "DG_JIT_CACHE_DIR" + if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None): + os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join( + envs.VLLM_CACHE_ROOT, "deep_gemm" + ) + + _dg = importlib.import_module("deep_gemm") + + _bf16_mqa_logits_impl = getattr(_dg, "bf16_mqa_logits", None) + _bf16_paged_mqa_logits_impl = getattr(_dg, "bf16_paged_mqa_logits", None) + + +def bf16_mqa_logits( + q: torch.Tensor, + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + """Compute BF16 MQA logits for a single sequence without KV paging. + + Args: + q: Query tensor of shape [M, H, D]. Casted to + `torch.bfloat16` by caller (NOTE(review): dtype copied from the + FP8 variant's docstring said `torch.float8_e4m3fn`; confirm). 
+ kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with + dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or + [N, 1]) with dtype `torch.float32`. + weights: weights of shape [M, H], dtype `torch.float32`. + cu_seqlen_ks: Start indices (inclusive) for valid K per query position, + shape [M], dtype int32. + cu_seqlen_ke: End indices (exclusive) for valid K per query position, + shape [M], dtype int32. + + Returns: + Logits tensor of shape [M, N], dtype `torch.float32`. + """ + _lazy_init() + if _bf16_mqa_logits_impl is None: + return _missing() + return _bf16_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) + + +def bf16_paged_mqa_logits( + q_bf16: torch.Tensor, + kv_cache_bf16: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + schedule_metadata: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + """Compute BF16 MQA logits using paged KV-cache. + + Args: + q_bf16: Query tensor of shape [B, next_n, H, D]. Casted to + `torch.float16` by caller. + kv_cache_bf16: Paged KV-cache in packed BF16+scale layout with shape + [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last + 4 bytes per (block,pos) store the `float` dequant scale. + weights: Tensor of shape [B * next_n, H], dtype `torch.float32`. + context_lens: Tensor of shape [B], dtype int32; effective context length + for each batch element. + block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical + block indices to physical blocks in the paged cache. + schedule_metadata: Returned by `get_paged_mqa_logits_metadata`; + used to distribute work across SMs. + max_model_len: Maximum sequence length used to size the logits output. + + Returns: + Logits tensor of shape [B * next_n, max_model_len], dtype + `torch.float32`. 
+ """ + _lazy_init() + if _bf16_paged_mqa_logits_impl is None: + return _missing() + return _bf16_paged_mqa_logits_impl( + q_bf16, + kv_cache_bf16, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + clean_logits=True, + ) + + +__all__ = [ + "bf16_mqa_logits", + "bf16_paged_mqa_logits", + "is_deep_gemm_supported", +] diff --git a/vllm_fl/patches/glm_moe_dsa.py b/vllm_fl/patches/glm_moe_dsa.py index 4d205537..9dda83ff 100644 --- a/vllm_fl/patches/glm_moe_dsa.py +++ b/vllm_fl/patches/glm_moe_dsa.py @@ -78,7 +78,7 @@ def patch_fp8_mqa_logits_dim(): Upstream fix: https://github.com/vllm-project/vllm/pull/32652 We wrap fp8_mqa_logits to flatten k_scale before calling the native impl. """ - import vllm.utils.deep_gemm as dg_mod + import vllm_fl.mx_utils.deep_gemm as dg_mod dg_mod._lazy_init() _orig_impl = dg_mod._fp8_mqa_logits_impl diff --git a/vllm_fl/v1/__init__.py b/vllm_fl/v1/__init__.py new file mode 100644 index 00000000..353ab045 --- /dev/null +++ b/vllm_fl/v1/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. diff --git a/vllm_fl/v1/attention/__init__.py b/vllm_fl/v1/attention/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_fl/v1/attention/backends/__init__.py b/vllm_fl/v1/attention/backends/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_fl/v1/attention/backends/flash_attn.py b/vllm_fl/v1/attention/backends/flash_attn.py new file mode 100644 index 00000000..e4a4deb1 --- /dev/null +++ b/vllm_fl/v1/attention/backends/flash_attn.py @@ -0,0 +1,1202 @@ +# SPDX-License-Identifier: Apache-2.0 +# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. 
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention (Maca/MetaX variant)."""

from dataclasses import dataclass
from typing import ClassVar

import numpy as np
import torch

from vllm import envs
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionImpl,
    AttentionType,
    MultipleOf,
    is_quantized_kv_cache,
)
from vllm.attention.layer import Attention
from vllm.attention.ops.common import cp_lse_ag_out_rs

# --------------------------------------------------------------
# Note: use Maca's merge_attn_states to get the CUDA kernel invoked
# --------------------------------------------------------------
from vllm_fl.attention.ops.merge_attn_states import merge_attn_states
from vllm_fl.attention.utils.fa_utils import (
    flash_attn_supports_fp8,
    get_flash_attn_version,
    is_flash_attn_varlen_func_available,
)

if is_flash_attn_varlen_func_available():
    from vllm_fl.attention.utils.fa_utils import (
        flash_attn_supports_sinks,
        flash_attn_varlen_func,
        get_scheduler_metadata,
        reshape_and_cache_flash,
        flash_attn_with_kvcache,  # used for prefill/decode split
    )
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
    get_dcp_local_seq_lens,
    get_kv_cache_layout,
    split_decodes_and_prefills,  # used for prefill/decode split
    reshape_attn_output_for_spec_decode,  # used for prefill/decode split with MTP
    reshape_query_for_spec_decode,  # used for prefill/decode split with MTP
)
from vllm.attention.backends.registry import AttentionBackendEnum, register_backend
from vllm.v1.kv_cache_interface import AttentionSpec

# --------------------------------------------------------------
# Note: used for prefill/decode split with MTP on Maca
# --------------------------------------------------------------
from vllm_fl.v1.attention.backends.mla.common import QueryLenSupport

logger = init_logger(__name__)


@register_backend(AttentionBackendEnum.FLASH_ATTN)
class MacaFlashAttentionBackend(AttentionBackend):
    """FlashAttention backend registered for FLASH_ATTN on the Maca platform.

    Differs from the upstream vLLM backend in that FP8 KV cache is rejected
    outright (see :meth:`get_fp8_dtype_for_flashattn`).
    """

    # The impl may write directly into a caller-provided output buffer.
    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return KV-cache block sizes the kernels accept."""
        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        if (
            model_config
            and model_config.is_hybrid
            and (
                cache_config.mamba_ssm_cache_dtype == "float32"
                or cache_config.mamba_cache_dtype == "float32"
            )
        ):
            # NOTE(tdoublep): while in principle, FA supports
            # MultipleOf(16), these are the block sizes that do not
            # suffer from the NaN propagation problem described here:
            # https://github.com/Dao-AILab/flash-attention/issues/1974
            return [16, 32, 64]
        return [MultipleOf(16)]

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN"

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """FlashAttention supports all attention types."""
        return attn_type in (
            AttentionType.DECODER,
            AttentionType.ENCODER,
            AttentionType.ENCODER_ONLY,
            AttentionType.ENCODER_DECODER,
        )

    @staticmethod
    def get_impl_cls() -> type["FlashAttentionImpl"]:
        return FlashAttentionImpl

    @staticmethod
    def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
        return FlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """KV cache layout: (2, num_blocks, block_size, num_kv_heads, head_size)."""
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        # `stride_order` indicates the permutation that gets
        # us from `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD" and include_num_layers_dimension:
            # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
            return (2, 0, 1, 3, 4, 5)
        elif cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND" and include_num_layers_dimension:
            # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
            return (2, 4, 0, 1, 3, 5)
        elif cache_layout == "HND":
            stride_order = (0, 1, 3, 2, 4)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order

    @staticmethod
    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
        """Always raises: FP8 KV cache is not supported on Maca.

        The fp8 -> torch dtype mapping that followed this raise was
        unreachable dead code and has been removed.
        """
        raise NotImplementedError(
            "FP8 dtype is not supported for FlashAttention on Maca."
        )

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        return head_size % 8 == 0 and head_size <= 256

    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
        # NOTE(review): for fp8 this defers to flash_attn_supports_fp8();
        # get_fp8_dtype_for_flashattn() above always raises, so this is only
        # consistent if that helper returns False on Maca -- confirm.
        if kv_cache_dtype is None:
            return True
        if kv_cache_dtype.startswith("fp8"):
            return flash_attn_supports_fp8()
        return kv_cache_dtype in ["auto"]

    @classmethod
    def supports_sink(cls) -> bool:
        if not is_flash_attn_varlen_func_available():
            return False
        return flash_attn_supports_sinks()

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        return True

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: CacheDType | None,
        block_size: int,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: DeviceCapability,
    ) -> str | None:
        """Return a rejection reason, or None when the combination is OK."""
        if has_sink and device_capability < DeviceCapability(9, 0):
            return "sink not supported on compute capability < 9.0"
        return None
@dataclass
class FlashAttentionMetadata:
    """Per-step metadata consumed by FlashAttentionImpl.forward.

    NOTE(sang): Definition of context_len, query_len, and seq_len.
    |---------- N-1 iteration --------|
    |---------------- N iteration ---------------------|
    |- tokenA -|......................|-- newTokens ---|
    |---------- context_len ----------|
    |-------------------- seq_len ---------------------|
                                      |-- query_len ---|
    """

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # /------------------------ Metax Modification -------------------------\
    # Prefill/decode split: decode-side slices.
    # NOTE(review): the tensor-typed fields below are populated with None
    # when the corresponding count is zero -- confirm callers tolerate that.
    num_decodes: int
    num_decode_tokens: int
    decode_query_start_loc: torch.Tensor
    decode_max_seq_len: int
    decode_seq_lens: torch.Tensor
    decode_block_table: torch.Tensor

    # Prefill-side slices.
    num_prefills: int
    num_prefill_tokens: int
    prefill_query_start_loc: torch.Tensor
    prefill_max_seq_len: int
    prefill_seq_lens: torch.Tensor
    prefill_block_table: torch.Tensor
    # \------------------------- Metax Modification -------------------------/

    # Cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: torch.Tensor | None
    prefix_kv_lens: torch.Tensor | None
    suffix_kv_lens: torch.Tensor | None

    # GQA decode-context-parallel (DCP).
    max_dcp_context_kv_len: int | None = None
    dcp_context_kv_lens: torch.Tensor | None = None

    # Optional ahead-of-time (FA3) scheduling.
    scheduler_metadata: torch.Tensor | None = None
    prefix_scheduler_metadata: torch.Tensor | None = None
    max_num_splits: int = 0

    causal: bool = True


def _get_sliding_window_configs(
    vllm_config: VllmConfig,
) -> set[tuple[int, int] | None]:
    """Collect every distinct sliding-window config used in the model."""
    configs: set[tuple[int, int] | None] = set()
    for layer in get_layers_from_vllm_config(vllm_config, Attention).values():
        assert isinstance(layer.impl, FlashAttentionImpl)
        configs.add(layer.impl.sliding_window)
    return configs


class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetadata]):
    """Builds FlashAttentionMetadata; reorders the batch into decode and
    prefill runs (Metax modification)."""

    # /------------------------ Metax Modification -------------------------\
    _cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

    # Level of query-length support for this backend:
    # - SINGLE_ONLY: only single-token queries (no spec-decode support)
    # - UNIFORM: uniform multi-token queries (spec decode, uniform lengths)
    # - VARLEN: variable-length queries (spec decode, mixed lengths)
    # UNIFORM or VARLEN will raise `reorder_batch_threshold` when
    # speculative decoding is enabled.
    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM

    # Requests with query length <= this threshold are classified as decode
    # requests; small prefills are deliberately routed through the decode
    # pathway as well.
    reorder_batch_threshold: int = 128
    # \------------------------- Metax Modification -------------------------/

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config
        )
        self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config)
        self.kv_cache_dtype = kv_cache_spec.dtype
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size

        # 0 means no upper bound on the number of splits.
        self.max_num_splits = 0
        self.aot_schedule = get_flash_attn_version() == 3

        # /------------------------ Metax Modification -------------------------\
        supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
        self._init_reorder_batch_threshold(
            self.reorder_batch_threshold, supports_spec_decode
        )

        # Keep query_len_support and reorder_batch_threshold consistent.
        if self.query_len_support == QueryLenSupport.SINGLE_ONLY:
            assert self.reorder_batch_threshold == 1, (
                f"reorder_batch_threshold must be 1 when query_len_support is "
                f"SINGLE_ONLY, got {self.reorder_batch_threshold}"
            )
        # \------------------------- Metax Modification -------------------------/

        try:
            from vllm.distributed.parallel_state import get_dcp_group

            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing.
            self.dcp_world_size = 1
            self.dcp_rank = 0

        self.cp_kv_cache_interleave_size = (
            self.parallel_config.cp_kv_cache_interleave_size
        )

        self.use_full_cuda_graph = (
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        )
        self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size

        if self.use_full_cuda_graph and self.aot_schedule:
            self.scheduler_metadata = torch.zeros(
                vllm_config.scheduler_config.max_num_seqs + 1,
                dtype=torch.int32,
                device=self.device,
            )
            # With cuda graphs, bound the split count up front so large
            # enough intermediate buffers are pre-allocated during capture.
            # NOTE(review): self.attention_config is not assigned here;
            # presumably set by the base class -- confirm.
            self.max_num_splits = (
                self.attention_config.flash_attn_max_num_splits_for_cuda_graph
            )

        # Sliding-window size for the AOT scheduler; populated on the first
        # build() call, once the layers have been constructed.
        self.aot_sliding_window: tuple[int, int] | None = None
# NOTE(review): `build` is a method of FlashAttentionMetadataBuilder,
# reconstructed here from a line-fused diff; `self` is the builder instance.
def build(
    self,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> FlashAttentionMetadata:
    """Build per-step attention metadata.

    fast_build disables AOT scheduling; used when there will be few
    iterations, i.e. spec-decode.
    """
    num_reqs = common_attn_metadata.num_reqs
    num_actual_tokens = common_attn_metadata.num_actual_tokens
    max_query_len = common_attn_metadata.max_query_len
    max_seq_len = common_attn_metadata.max_seq_len
    query_start_loc = common_attn_metadata.query_start_loc
    seq_lens = common_attn_metadata.seq_lens
    block_table_tensor = common_attn_metadata.block_table_tensor
    slot_mapping = common_attn_metadata.slot_mapping
    causal = common_attn_metadata.causal

    # /------------------------ Metax Modification -------------------------\
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
        split_decodes_and_prefills(
            common_attn_metadata,
            decode_threshold=self.reorder_batch_threshold,
            require_uniform=(self.query_len_support != QueryLenSupport.VARLEN),
        )
    )

    assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
    assert num_decodes + num_prefills == num_reqs
    # \------------------------- Metax Modification -------------------------/

    # The overhead of the AOT schedule is not worth it for spec-decode.
    aot_schedule = self.aot_schedule and not fast_build

    if self.aot_sliding_window is None:
        self.aot_sliding_window = (-1, -1)
        # The AOT scheduler needs a sliding-window value that is constant
        # across all layers; the layers only exist after construction, so
        # this is populated on the first build() call, not in __init__.
        if aot_schedule:
            sliding_window_configs = _get_sliding_window_configs(self.vllm_config)
            if len(sliding_window_configs) == 1:
                sliding_window_config = sliding_window_configs.pop()
                if sliding_window_config is not None:
                    self.aot_sliding_window = sliding_window_config
            elif len(sliding_window_configs) > 1:
                # Mixed window sizes: AOT scheduling cannot be used at all.
                self.aot_schedule = False
                aot_schedule = False

    max_num_splits = 0  # 0 = use FA3's heuristics; not CG compatible.
    if self.use_full_cuda_graph and num_actual_tokens <= self.max_cudagraph_size:
        # NOTE(woosuk): Setting num_splits > 1 may increase memory usage,
        # because intermediate buffers of size [num_splits, num_heads,
        # num_tokens, head_size] are allocated. Therefore we only set
        # num_splits when using cuda graphs.
        max_num_splits = self.max_num_splits

    # /------------------------ Metax Modification -------------------------\
    # Slice the common metadata into decode-side and prefill-side views.
    if num_decodes > 0:
        decode_max_seq_len = int(
            common_attn_metadata.seq_lens_cpu[:num_decodes].max()
        )
        decode_query_start_loc = common_attn_metadata.query_start_loc[
            : num_decodes + 1
        ]
        decode_seq_lens = common_attn_metadata.seq_lens[:num_decodes]
        decode_block_table_tensor = common_attn_metadata.block_table_tensor[
            :num_decodes
        ]
    else:
        decode_max_seq_len = 0
        decode_query_start_loc = None
        decode_seq_lens = None
        decode_block_table_tensor = None

    if num_prefills > 0:
        prefill_max_seq_len = int(
            common_attn_metadata.seq_lens_cpu[num_decodes:num_reqs].max()
        )
        # Rebase the prefill cu-seqlens so they start at 0.
        prefill_query_start_loc = (
            common_attn_metadata.query_start_loc[num_decodes : num_reqs + 1]
            - common_attn_metadata.query_start_loc[num_decodes]
        )
        prefill_seq_lens = common_attn_metadata.seq_lens[num_decodes:num_reqs]
        prefill_block_table_tensor = common_attn_metadata.block_table_tensor[
            num_decodes:num_reqs
        ]
    else:
        prefill_max_seq_len = 0
        prefill_query_start_loc = None
        prefill_seq_lens = None
        prefill_block_table_tensor = None
    # \------------------------- Metax Modification -------------------------/

    if vllm_is_batch_invariant():
        max_num_splits = 1

    def schedule(
        batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
    ):
        # AOT scheduling helper; returns None when AOT is disabled.
        cache_dtype = self.cache_config.cache_dtype
        if cache_dtype.startswith("fp8"):
            qkv_dtype = MacaFlashAttentionBackend.get_fp8_dtype_for_flashattn(
                cache_dtype
            )
        else:
            qkv_dtype = self.kv_cache_dtype
        if aot_schedule:
            return get_scheduler_metadata(
                batch_size=batch_size,
                max_seqlen_q=max_query_len,
                max_seqlen_k=max_seq_len,
                num_heads_q=self.num_heads_q * self.dcp_world_size,
                num_heads_kv=self.num_heads_kv,
                headdim=self.headdim,
                cache_seqlens=seqlens,
                qkv_dtype=qkv_dtype,
                cu_seqlens_q=cu_query_lens,
                page_size=self.block_size,
                causal=causal,
                window_size=self.aot_sliding_window,
                num_splits=max_num_splits,
            )
        return None

    use_cascade = common_prefix_len > 0
    max_dcp_context_kv_len = 0
    dcp_context_kv_lens = None

    cu_prefix_query_lens = None
    prefix_kv_lens = None
    suffix_kv_lens = None
    prefix_scheduler_metadata = None

    if self.dcp_world_size > 1:
        query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]
        dcp_context_kv_lens = seq_lens - query_kv_lens

        dcp_context_kv_lens = get_dcp_local_seq_lens(
            dcp_context_kv_lens,
            self.dcp_world_size,
            self.dcp_rank,
            self.cp_kv_cache_interleave_size,
        )
        # After DCP distribution, the maximum number of tokens on any rank
        # is ceil(L / (N * I)) * I, where L is max_seq_len, N is
        # dcp_world_size, and I is cp_kv_cache_interleave_size. This
        # eliminates a GPU->CPU sync while minimizing workspace
        # over-allocation.
        num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size
        max_dcp_context_kv_len = (
            (max_seq_len + num_partitions - 1) // num_partitions
        ) * self.cp_kv_cache_interleave_size

        scheduler_metadata = schedule(
            batch_size=num_reqs,
            cu_query_lens=query_start_loc,
            max_query_len=max_query_len,
            seqlens=dcp_context_kv_lens,
            max_seq_len=max_dcp_context_kv_len,
            causal=False,
        )
    elif use_cascade:
        cu_prefix_query_lens = torch.tensor(
            [0, num_actual_tokens], dtype=torch.int32, device=self.device
        )
        prefix_kv_lens = torch.tensor(
            [common_prefix_len], dtype=torch.int32, device=self.device
        )
        # Use the GPU tensor directly -- no CPU sync needed.
        suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len
        prefix_scheduler_metadata = schedule(
            batch_size=1,
            cu_query_lens=cu_prefix_query_lens,
            max_query_len=num_actual_tokens,
            seqlens=prefix_kv_lens,
            max_seq_len=common_prefix_len,
            causal=False,
        )
        scheduler_metadata = schedule(
            batch_size=num_reqs,
            cu_query_lens=query_start_loc,
            max_query_len=max_query_len,
            seqlens=suffix_kv_lens,
            max_seq_len=max_seq_len - common_prefix_len,
            causal=True,
        )
    else:
        scheduler_metadata = schedule(
            batch_size=num_reqs,
            cu_query_lens=query_start_loc,
            max_query_len=max_query_len,
            seqlens=seq_lens,
            max_seq_len=max_seq_len,
            causal=causal,
        )

    # For FA3 + full cudagraph: copy into the persistent buffer.
    if self.use_full_cuda_graph and scheduler_metadata is not None:
        n = scheduler_metadata.shape[0]
        self.scheduler_metadata[:n] = scheduler_metadata
        # NOTE(woosuk): We should zero out the rest of the scheduler
        # metadata to guarantee correctness. Otherwise, some thread blocks
        # may use invalid scheduler metadata and overwrite the output
        # buffer.
        self.scheduler_metadata[n:] = 0
        scheduler_metadata = self.scheduler_metadata[:n]

    return FlashAttentionMetadata(
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        query_start_loc=query_start_loc,
        max_seq_len=max_seq_len,
        seq_lens=seq_lens,
        # /------------------------ Metax Modification -------------------------\
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
        decode_query_start_loc=decode_query_start_loc,
        decode_max_seq_len=decode_max_seq_len,
        decode_seq_lens=decode_seq_lens,
        decode_block_table=decode_block_table_tensor,
        num_prefills=num_prefills,
        num_prefill_tokens=num_prefill_tokens,
        prefill_query_start_loc=prefill_query_start_loc,
        prefill_max_seq_len=prefill_max_seq_len,
        prefill_seq_lens=prefill_seq_lens,
        prefill_block_table=prefill_block_table_tensor,
        # \------------------------- Metax Modification -------------------------/
        block_table=block_table_tensor,
        slot_mapping=slot_mapping,
        max_dcp_context_kv_len=max_dcp_context_kv_len,
        dcp_context_kv_lens=dcp_context_kv_lens,
        use_cascade=use_cascade,
        common_prefix_len=common_prefix_len,
        scheduler_metadata=scheduler_metadata,
        cu_prefix_query_lens=cu_prefix_query_lens,
        prefix_kv_lens=prefix_kv_lens,
        suffix_kv_lens=suffix_kv_lens,
        prefix_scheduler_metadata=prefix_scheduler_metadata,
        max_num_splits=max_num_splits,
        causal=causal,
    )
class FlashAttentionImpl(AttentionImpl):
    """FlashAttention implementation with a Maca-specific prefill/decode
    split path in forward()."""

    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        sinks: torch.Tensor | None = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        elif attn_type == AttentionType.ENCODER_ONLY:
            # Bidirectional window for encoder-only attention.
            self.sliding_window = (sliding_window - 1, sliding_window - 1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.attn_type = attn_type
        self.vllm_flash_attn_version = get_flash_attn_version()
        # Cache the batch-invariant result for use in forward passes.
        self.batch_invariant_enabled = vllm_is_batch_invariant()

        # FIX: forward() dereferences self.dcp_world_size, but it was never
        # initialized here -> AttributeError on the first decoder forward.
        # Mirror FlashAttentionMetadataBuilder.__init__: fall back to a
        # world size of 1 when the DCP group is not initialized.
        try:
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            self.dcp_world_size = 1
            self.dcp_rank = 0

        if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device."
            )

        self.sinks = sinks
        if self.sinks is not None:
            assert flash_attn_supports_sinks(), (
                "Sinks are only supported in FlashAttention 3"
            )
            assert self.sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                "heads in the layer"
            )

    def supports_quant_query_input(self) -> bool:
        return False

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        NOTE: For FP8 quantization, flash-attn expects the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values.
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for FlashAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        attn_type = self.attn_type

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed
        # in eager-mode PyTorch. Thus, we need to be careful about any CPU
        # overhead in this method. For example, `view` and `slice` (or
        # `[:n]`) operations are surprisingly slow even when they do not
        # invoke any GPU ops. Minimize the PyTorch ops in this method as much
        # as possible and benchmark any change.

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Encoder attention needs no KV cache: use Q, K, V directly.
        if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            return self._forward_encoder_attention(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
                value[:num_actual_tokens],
                output[:num_actual_tokens],
                attn_metadata,
                layer,
            )

        # For decoder and cross-attention, use the KV cache as before.
        key_cache, value_cache = kv_cache.unbind(0)

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the encoder output and then cached.
        if (
            self.kv_sharing_target_layer_name is None
            and key is not None
            and value is not None
        ):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): key and value are padded while slot_mapping is
            # not, but reshape_and_cache_flash uses slot_mapping's shape to
            # determine the number of actual tokens, so no slicing needed.
            reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if self.kv_cache_dtype.startswith("fp8"):
            # Queries are quantized in the attention layer.
            # NOTE: on Maca this raises NotImplementedError by design.
            dtype = MacaFlashAttentionBackend.get_fp8_dtype_for_flashattn(
                self.kv_cache_dtype
            )
            key_cache = key_cache.view(dtype)
            value_cache = value_cache.view(dtype)

        if not attn_metadata.use_cascade:
            cu_seqlens_q = attn_metadata.query_start_loc
            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

            if self.dcp_world_size > 1:
                self._forward_with_dcp(
                    query[:num_actual_tokens],
                    key[:num_actual_tokens],
                    value[:num_actual_tokens],
                    key_cache,
                    value_cache,
                    output[:num_actual_tokens],
                    attn_metadata,
                    q_descale=layer._q_scale.expand(descale_shape),
                    k_descale=layer._k_scale.expand(descale_shape),
                    v_descale=layer._v_scale.expand(descale_shape),
                )
                return output
            else:
                # ┌------------------------ Metax Modification -------------------------┐
                # Prefill/decode split: prefills go through the varlen path,
                # decodes (incl. small prefills, see reorder_batch_threshold)
                # through flash_attn_with_kvcache.
                num_decode_tokens = attn_metadata.num_decode_tokens
                if attn_metadata.num_prefills > 0:
                    # NOTE(review): .tolist() forces a GPU->CPU sync here;
                    # acceptable for prefills, but worth confirming.
                    cu_prefix_kv_lens = torch.tensor(
                        [0] + attn_metadata.prefill_seq_lens.tolist(),
                        device=attn_metadata.prefill_seq_lens.device,
                        dtype=torch.int32,
                    ).cumsum(dim=0, dtype=torch.int32)
                    output[num_decode_tokens:num_actual_tokens] = (
                        flash_attn_varlen_func(
                            q=query[num_decode_tokens:num_actual_tokens],
                            k=key_cache,
                            v=value_cache,
                            block_table=attn_metadata.prefill_block_table,
                            cu_seqlens_q=attn_metadata.prefill_query_start_loc,
                            cu_seqlens_k=cu_prefix_kv_lens,
                            max_seqlen_q=attn_metadata.max_query_len,
                            max_seqlen_k=attn_metadata.prefill_max_seq_len,
                            softmax_scale=self.scale,
                            causal=attn_metadata.causal,
                            window_size=self.sliding_window,
                            alibi_slopes=self.alibi_slopes,
                            softcap=self.logits_soft_cap,
                            s_aux=self.sinks,
                        )
                    )
                if attn_metadata.num_decodes > 0:
                    # Use flash_attn_with_kvcache for normal decoding.
                    decode_query = query[:num_decode_tokens]
                    decode_query = reshape_query_for_spec_decode(
                        decode_query, attn_metadata.num_decodes
                    )
                    output_unreshape = flash_attn_with_kvcache(
                        q=decode_query,
                        k_cache=key_cache,
                        v_cache=value_cache,
                        block_table=attn_metadata.decode_block_table,
                        cache_seqlens=attn_metadata.decode_seq_lens,
                        softmax_scale=self.scale,
                        causal=True,
                        window_size=self.sliding_window,
                        alibi_slopes=self.alibi_slopes,
                        softcap=self.logits_soft_cap,
                        s_aux=self.sinks,
                    )
                    output[:num_decode_tokens] = reshape_attn_output_for_spec_decode(
                        output_unreshape
                    )
                return output
                # └------------------------- Metax Modification -------------------------┘

        # Cascade attention (rare case).
        cascade_attention(
            output[:num_actual_tokens],
            query[:num_actual_tokens],
            key_cache,
            value_cache,
            cu_query_lens=attn_metadata.query_start_loc,
            max_query_len=attn_metadata.max_query_len,
            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
            prefix_kv_lens=attn_metadata.prefix_kv_lens,
            suffix_kv_lens=attn_metadata.suffix_kv_lens,
            max_kv_len=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
            max_num_splits=attn_metadata.max_num_splits,
            fa_version=self.vllm_flash_attn_version,
            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
            q_descale=layer._q_scale,
            k_descale=layer._k_scale,
            v_descale=layer._v_scale,
            s_aux=self.sinks,
        )
        return output

    def _forward_with_dcp(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        q_descale: torch.Tensor | None = None,
        k_descale: torch.Tensor | None = None,
        v_descale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Decode-context-parallel path: attend over the rank-local context,
        all-gather/reduce-scatter the partial results, then merge with the
        local query self-attention."""
        cu_seqlens_q = attn_metadata.query_start_loc
        max_seqlen_q = attn_metadata.max_query_len
        block_table = attn_metadata.block_table

        query = query.contiguous()
        query_across_dcp = get_dcp_group().all_gather(query, dim=1)
        # /------------------------ Metax Modification -------------------------\
        assert attn_metadata.dcp_context_kv_lens is not None
        cu_seqlens_k = torch.tensor(
            [0] + attn_metadata.dcp_context_kv_lens.tolist(),
            device=attn_metadata.dcp_context_kv_lens.device,
            dtype=torch.int32,
        ).cumsum(dim=0, dtype=torch.int32)

        context_attn_out, context_lse, _ = flash_attn_varlen_func(
            q=query_across_dcp,
            k=key_cache,
            v=value_cache,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
            softmax_scale=self.scale,
            causal=False,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            block_table=block_table,
            softcap=self.logits_soft_cap,
            return_attn_probs=True,
            # scheduler_metadata=attn_metadata.scheduler_metadata,
            # fa_version=self.vllm_flash_attn_version,
            # q_descale=q_descale,
            # k_descale=k_descale,
            # v_descale=v_descale,
            s_aux=self.sinks,
        )
        # \------------------------- Metax Modification -------------------------/
        # FA returns LSE in shape [H, B] but cp_lse_ag_out_rs wants [B, H].
        context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
            context_attn_out,
            context_lse.transpose(0, 1),
            get_dcp_group(),
            return_lse=True,
        )
        context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()

        # /------------------------ Metax Modification -------------------------\
        query_attn_out, query_lse, _ = flash_attn_varlen_func(
            q=query,
            k=key,
            v=value,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            cu_seqlens_k=cu_seqlens_q,
            max_seqlen_k=max_seqlen_q,
            softmax_scale=self.scale,
            causal=attn_metadata.causal,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            softcap=self.logits_soft_cap,
            return_attn_probs=True,
            # fa_version=self.vllm_flash_attn_version,
            # q_descale=q_descale,
            # k_descale=k_descale,
            # v_descale=v_descale,
            s_aux=self.sinks,
        )
        # \------------------------- Metax Modification -------------------------/

        assert context_attn_out_cor.shape == query_attn_out.shape
        assert context_lse_cor.shape == query_lse.shape
        merge_attn_states(
            output,
            context_attn_out_cor,
            context_lse_cor,
            query_attn_out,
            query_lse,
        )

    def _forward_encoder_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        layer: torch.nn.Module,
    ) -> torch.Tensor:
        """Forward pass for encoder attention without KV cache.

        Args:
            query: shape = [num_encoder_tokens, num_heads, head_size]
            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
            output: shape = [num_encoder_tokens, num_heads, head_size]
            attn_metadata: Encoder attention metadata
            layer: The attention layer
        """
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError(
                "quantization is not supported for encoder attention"
            )

        # Encoder attention uses the query positions for both sides.
        cu_seqlens_q = attn_metadata.query_start_loc
        cu_seqlens_k = attn_metadata.query_start_loc
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_query_len

        # Kept for the (currently disabled) descale arguments below.
        descale_shape = (
            cu_seqlens_q.shape[0] - 1,  # type: ignore[union-attr]
            self.num_kv_heads,
        )

        num_actual_tokens = attn_metadata.num_actual_tokens
        # Call flash attention directly on Q, K, V tensors.
        output[:num_actual_tokens] = flash_attn_varlen_func(
            q=query,
            k=key,
            v=value,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=self.scale,
            causal=False,  # Encoder attention is bidirectional
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            softcap=self.logits_soft_cap,
            # fa_version=self.vllm_flash_attn_version,
            # q_descale=layer._q_scale.expand(descale_shape),
            # k_descale=layer._k_scale.expand(descale_shape),
            # v_descale=layer._v_scale.expand(descale_shape),
            s_aux=self.sinks,
            # num_splits=1 if self.batch_invariant_enabled else 0,
        )

        return output
bool, + num_sms: int, + dcp_world_size: int, +) -> bool: + """Decide whether to use cascade attention. + + This function 1) checks whether cascade attention is supported with the + given configuration, and 2) heuristically decides whether using cascade + attention can improve performance. + """ + # Too short common prefix. Probably not worth using cascade attention. + # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold. + # NOTE(woosuk): This is the common case. We should return False as soon as + # possible to avoid any unnecessary computation. + if common_prefix_len < 256: + return False + # Cascade attention is currently not supported with these variants. + if use_alibi or use_sliding_window or use_local_attention: + return False + # Too few queries. Probably not worth using cascade attention. + # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold. + num_reqs = len(query_lens) + if num_reqs < 8: + return False + # disable cascade attention for DCP + if dcp_world_size > 1: + return False + + # Heuristics to decide whether using cascade attention is beneficial. + # 1. When FlashDecoding is not used for normal attention, cascade attention + # is likely to be faster since it saves memory bandwidth. + num_queries_per_kv = num_query_heads // num_kv_heads + # The criteria for using FlashDecoding can be found in the following link: + # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535 + use_flash_decoding = ( + num_queries_per_kv > 1 + and not use_sliding_window + and not use_alibi + and np.all(query_lens == 1) + ) + if not use_flash_decoding: + # Use cascade attention. + return True + + # 2. When FlashDecoding is used for normal attention, it is not clear + # whether cascade attention is beneficial, because FlashDecoding can + # launch more CTAs than cascade attention. + # We use a simple performance model to compare the two methods. 
+ # NOTE(woosuk): The performance model is very rough and may not be + # accurate. + num_tokens = num_reqs + # NOTE(woosuk): These are default tile sizes. flash-attn might use + # different tile sizes (e.g., 64 or 256) depending on the configuration. + q_tile_size = 128 + kv_tile_size = 128 + num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size) + + cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size) + cascade_waves = cdiv(cascade_ctas, num_sms) + cascade_time = cascade_waves * num_prefix_tiles + + flash_decoding_ctas = ( + num_reqs * num_kv_heads * cdiv(num_queries_per_kv, q_tile_size) + ) + flash_decoding_ctas *= num_prefix_tiles + flash_decoding_time = cdiv(flash_decoding_ctas, num_sms) + + # Use cascade attention if it is faster than FlashDecoding. + return cascade_time < flash_decoding_time + + +def cascade_attention( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + cu_query_lens: torch.Tensor, + max_query_len: int, + cu_prefix_query_lens: torch.Tensor, + prefix_kv_lens: torch.Tensor, + suffix_kv_lens: torch.Tensor, + max_kv_len: int, + softmax_scale: float, + alibi_slopes: torch.Tensor | None, + sliding_window: tuple[int, int], + logits_soft_cap: float, + block_table: torch.Tensor, + common_prefix_len: int, + max_num_splits: int, + fa_version: int, + prefix_scheduler_metadata: torch.Tensor | None = None, + suffix_scheduler_metadata: torch.Tensor | None = None, + q_descale: torch.Tensor | None = None, + k_descale: torch.Tensor | None = None, + v_descale: torch.Tensor | None = None, + s_aux: torch.Tensor | None = None, +) -> torch.Tensor: + assert alibi_slopes is None, "Cascade attention does not support ALiBi." + # TODO: Support sliding window. + assert sliding_window == (-1, -1), ( + "Cascade attention does not support sliding window." 
+ ) + + num_tokens = query.shape[0] + block_size = key_cache.shape[-3] + assert common_prefix_len % block_size == 0 + num_common_kv_blocks = common_prefix_len // block_size + assert num_common_kv_blocks > 0 + descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2]) + + # /------------------------ Metax Modification -------------------------\ + cu_prefix_kv_lens = torch.tensor( + [0] + prefix_kv_lens.tolist(), device=prefix_kv_lens.device, dtype=torch.int32 + ).cumsum(dim=0, dtype=torch.int32) + # \------------------------ Metax Modification -------------------------/ + + # Process shared prefix. + # /------------------------ Metax Modification -------------------------\ + prefix_output, prefix_lse, _ = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_prefix_query_lens, + cu_seqlens_k=cu_prefix_kv_lens, + max_seqlen_q=num_tokens, + max_seqlen_k=common_prefix_len, + softmax_scale=softmax_scale, + causal=False, + window_size=sliding_window, + block_table=block_table[:1], + softcap=logits_soft_cap, + s_aux=s_aux, + ) + # \------------------------- Metax Modification -------------------------/ + + descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) + # /------------------------ Metax Modification -------------------------\ + cu_suffix_kv_lens = torch.tensor( + [0] + suffix_kv_lens.tolist(), device=suffix_kv_lens.device, dtype=torch.int32 + ).cumsum(dim=0, dtype=torch.int32) + # \------------------------ Metax Modification -------------------------/ + + # Process suffix per query. 
+ suffix_output, suffix_lse, _ = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_query_lens, + cu_seqlens_k=cu_suffix_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len - common_prefix_len, + softmax_scale=softmax_scale, + causal=True, + window_size=sliding_window, + block_table=block_table[:, num_common_kv_blocks:], + softcap=logits_soft_cap, + return_attn_probs=True, + s_aux=s_aux, + ) + + # Merge prefix and suffix outputs, and store the result in output. + merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse) diff --git a/vllm_fl/v1/attention/backends/flashinfer.py b/vllm_fl/v1/attention/backends/flashinfer.py new file mode 100644 index 00000000..25b69cc1 --- /dev/null +++ b/vllm_fl/v1/attention/backends/flashinfer.py @@ -0,0 +1,1472 @@ +# SPDX-License-Identifier: Apache-2.0 +# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with FlashInfer.""" + +from dataclasses import dataclass +from typing import Any, ClassVar + +import numpy as np +import torch +from flashinfer import ( + BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, + BatchPrefillWithRaggedKVCacheWrapper, +) +from flashinfer.decode import _get_range_buf +from typing_extensions import override + +from vllm import envs +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionType, + MultipleOf, +) +from vllm.attention.ops.common import cp_lse_ag_out_rs + +# -------------------------------------------------------------- +# Note: use Maca's merge_attn_states to get cuda kernel invoked +# -------------------------------------------------------------- +from vllm_fl.attention.ops.merge_attn_states import merge_attn_states +from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config +from vllm.config.cache import CacheDType +from 
vllm.distributed.parallel_state import get_dcp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8StaticTensorSym, + kNvfp4Quant, +) +from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability +from vllm.triton_utils import tl, triton +from vllm.utils.flashinfer import ( + can_use_trtllm_attention, + flashinfer_disable_q_quantization, + use_trtllm_attention, +) +from vllm.utils.math_utils import cdiv +from vllm.utils.platform_utils import is_pin_memory_available +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + KVCacheLayoutType, + get_dcp_local_seq_lens, + get_kv_cache_layout, + get_per_layer_parameters, + infer_global_hyperparameters, + split_decodes_and_prefills, +) +from vllm.v1.kv_cache_interface import AttentionSpec + +from vllm.attention.backends.registry import AttentionBackendEnum, register_backend + +# /----------------------------------------------------------------------\ +# Note: Currently `flashinfer+metax` does not support: +# - MultiLevelCascadeAttentionWrapper +# - trtllm_batch_context_with_kv_cache +# - trtllm_batch_decode_with_kv_cache +# - FP4Tensor +# \----------------------------------------------------------------------/ + + +FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024 + +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 + +logger = init_logger(__name__) + +trtllm_gen_workspace_buffer = None + + +def _get_trtllm_gen_workspace_buffer(): + global trtllm_gen_workspace_buffer + if trtllm_gen_workspace_buffer is None: + trtllm_gen_workspace_buffer = torch.zeros( + envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda" + ) + return trtllm_gen_workspace_buffer + + +@triton.jit +def 
_trtllm_prefill_attn_kvfp8_dequant( + kv_cache_ptr, + block_tables_prefill_ptr, + block_table_stride, + mock_kv_cache_ptr, + k_scale_ptr, + v_scale_ptr, + K_CACHE_STRIDE: tl.constexpr, + KV_CACHE_STRIDE: tl.constexpr, +): + batch_idx = tl.program_id(0).to(tl.int64) + mock_block_table_idx = tl.program_id(1).to(tl.int64) + orig_page_num = tl.load( + block_tables_prefill_ptr + batch_idx * block_table_stride + mock_block_table_idx + ).to(tl.int64) + if orig_page_num <= 0: + return + dequant_dtype = mock_kv_cache_ptr.dtype.element_ty + + # Dequantize K + k_scale_val = tl.load(k_scale_ptr) + offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) + fp8_vals = tl.load(kv_cache_ptr + offset) + dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val + mock_cache_offset = ( + batch_idx * block_table_stride + mock_block_table_idx + 1 + ) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) + dequantized_vals = dequantized_vals.to(dequant_dtype) + tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals) + + # Dequantize V + v_scale_val = tl.load(v_scale_ptr) + offset = ( + orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) + ) + fp8_vals = tl.load(kv_cache_ptr + offset) + dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val + mock_cache_offset = ( + (batch_idx * block_table_stride + mock_block_table_idx + 1) * KV_CACHE_STRIDE + + K_CACHE_STRIDE + + tl.arange(0, K_CACHE_STRIDE) + ) + dequantized_vals = dequantized_vals.to(dequant_dtype) + tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals) + + +def trtllm_prefill_attn_kvfp8_dequant( + kv_cache: torch.Tensor, + block_tables_prefill: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + dequant_dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, num_of_page_per_token = block_tables_prefill.shape + s = kv_cache.shape + assert s[1] == 2 + assert dequant_dtype in (torch.bfloat16, torch.float16) + k_cache_stride = s[2] * s[3] * 
s[4] + kv_cache_stride = k_cache_stride * s[1] + new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4]) + # mock kv cache contains just the pages needed by this prefill + mock_kv_cache = torch.empty(new_s, dtype=dequant_dtype, device=kv_cache.device) + # we simply sequentially index the pages needed by this prefill + mock_block_table = torch.arange( + start=1, + end=batch_size * num_of_page_per_token + 1, + dtype=torch.int32, + device=block_tables_prefill.device, + ).reshape(batch_size, num_of_page_per_token) + grid = (batch_size, num_of_page_per_token) + _trtllm_prefill_attn_kvfp8_dequant[grid]( + kv_cache, + block_tables_prefill, + num_of_page_per_token, + mock_kv_cache, + k_scale, + v_scale, + k_cache_stride, + kv_cache_stride, + ) + return mock_kv_cache, mock_block_table + + +class BatchDCPPrefillWrapper: + def __init__( + self, + workspace_buffer: torch.Tensor | None = None, + ): + self._context = BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, get_kv_cache_layout() + ) + self._new_tokens = BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, get_kv_cache_layout() + ) + + def plan( + self, + qo_indptr_cpu: torch.Tensor, + paged_kv_indptr_cpu: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_last_page_len_cpu: torch.Tensor, + prefill_start: int, + page_size: int, + num_qo_heads: int, + dcp_world_size: int, + num_kv_heads: int, + head_dim: int, + sm_scale: float, + window_left: int, + logits_soft_cap: float | None, + q_data_type: torch.dtype, + kv_cache_dtype: torch.dtype, + prefill_fixed_split_size: int, + disable_split_kv: bool, + ): + """Plan the prefill operation with given parameters.""" + self._context.plan( + qo_indptr_cpu, + paged_kv_indptr_cpu, + paged_kv_indices, + paged_kv_last_page_len_cpu[prefill_start:], + num_qo_heads * dcp_world_size, + num_kv_heads, + head_dim, + page_size, + causal=False, # This is context run + sm_scale=sm_scale, + window_left=window_left, + logits_soft_cap=logits_soft_cap, + 
q_data_type=q_data_type, + kv_data_type=kv_cache_dtype, + # fixed_split_size=prefill_fixed_split_size, + # disable_split_kv=disable_split_kv, + ) + self._new_tokens.plan( + qo_indptr=qo_indptr_cpu, + kv_indptr=qo_indptr_cpu, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, + head_dim_vo=head_dim, + causal=True, # This is newtokens run + sm_scale=sm_scale, + window_left=window_left, + logits_soft_cap=logits_soft_cap, + q_data_type=q_data_type, + ) + + def run( + self, + layer: torch.nn.Module, + prefill_query: torch.Tensor, + kv_cache_permute: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + ): + prefill_query_across_dcp = get_dcp_group().all_gather( + prefill_query.contiguous(), dim=1 + ) + output_context_tmp, lse_context_tmp = self._context.run( + prefill_query_across_dcp, + kv_cache_permute, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + return_lse=True, + ) + output_context, lse_context = cp_lse_ag_out_rs( + output_context_tmp, + lse_context_tmp, + get_dcp_group(), + return_lse=True, + is_lse_base_on_e=False, + ) + lse_context = lse_context.transpose(0, 1).contiguous() + + output_query, lse_query = self._new_tokens.run( + prefill_query, + key, + value, + return_lse=True, + ) + lse_query = lse_query.transpose(0, 1).contiguous() + + merge_attn_states( + out, + output_context, + lse_context, + output_query, + lse_query, + ) + return out + + +@register_backend(AttentionBackendEnum.FLASHINFER) +class MacaFlashInferBackend(AttentionBackend): + accept_output_buffer: bool = True + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "bfloat16", + ] + + @staticmethod + def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: + # Note: Not sure for all platforms, but on Blackwell, + # only support a page size of 16, 32, 64. 
+ return [16, 32, 64] + + @staticmethod + def get_name() -> str: + return "FLASHINFER" + + @staticmethod + def get_impl_cls() -> type["FlashInferImpl"]: + return FlashInferImpl + + @staticmethod + def get_builder_cls() -> type["FlashInferMetadataBuilder"]: + return FlashInferMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + return (num_blocks, 2, block_size, num_kv_heads, head_size) + + @staticmethod + def get_kv_cache_stride_order( + include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: + # `stride_order` indicates the permutation that gets us from + # `get_kv_cache_shape` to the actual memory layout we want. + cache_layout = get_kv_cache_layout() + if cache_layout == "NHD" and include_num_layers_dimension: + # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size) + return (1, 0, 2, 3, 4, 5) + elif cache_layout == "NHD": + stride_order = (0, 1, 2, 3, 4) + elif cache_layout == "HND" and include_num_layers_dimension: + # (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size) + return (1, 2, 4, 0, 3, 5) + elif cache_layout == "HND": + stride_order = (0, 1, 3, 2, 4) + else: + raise ValueError(f"Unknown cache layout format {cache_layout}.") + return stride_order + + @staticmethod + def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: + raise NotImplementedError("Maca does not support FP8 FlashInfer.") + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + return torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + return torch.float8_e5m2 + else: + raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 + return [64, 128, 256] + + @classmethod + def 
supports_compute_capability(cls, capability: DeviceCapability) -> bool: + # On MACA it's 8.0 + return capability >= DeviceCapability(7, 5) and capability <= DeviceCapability( + 12, 1 + ) + + @classmethod + def supports_sink(cls) -> bool: + """FlashInfer supports sinks when TRTLLM attention is available (SM100).""" + from vllm.utils.flashinfer import ( + force_use_trtllm_attention, + supports_trtllm_attention, + ) + + # Respect explicit disable flag (e.g., VLLM_USE_TRTLLM_ATTENTION=0) + if force_use_trtllm_attention() is False: + return False + + # Check if TRTLLM is supported on this platform + return supports_trtllm_attention() + + @classmethod + def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None: + from vllm.platforms import current_platform + + capability = current_platform.get_device_capability() + if capability is not None and capability.major == 10: + return "HND" + return None + + +@dataclass +class FlashInferMetadata: + num_actual_tokens: int # Number of tokens excluding padding. + + # The data type of the query + q_data_type: torch.dtype + + slot_mapping: torch.Tensor + + # For flashinfer trtllm batch decode + max_q_len: int + max_q_len_prefill: int + max_seq_len: int + seq_lens: torch.Tensor + block_table_tensor: torch.Tensor + prefill_use_trtllm: bool + decode_use_trtllm: bool + + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + num_prefill_tokens: int + + # For cascade attention (CPU for planning). 
+ use_cascade: bool + + prefill_wrapper: ( + BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None + ) = None + decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None + # /------------------------ Metax Modification -------------------------\ + cascade_wrapper: Any = None + # \------------------------- Metax Modification -------------------------/ + + qo_indptr_gpu: torch.Tensor | None = None + paged_kv_indptr_gpu: torch.Tensor | None = None + + +class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): + reorder_batch_threshold: int = 1 + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.cache_config = vllm_config.cache_config + self.model_config = vllm_config.model_config + self.attention_config = vllm_config.attention_config + self._workspace_buffer = None + self._prefill_wrapper: ( + BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None + ) = None # Wrapper for prefill/append + self._decode_wrapper = None # Wrapper for decode (general shape) + + if vllm_is_batch_invariant(): + self.decode_fixed_split_size = 2048 + self.prefill_fixed_split_size = 4096 + self.disable_split_kv = True + else: + self.decode_fixed_split_size = -1 + self.prefill_fixed_split_size = -1 + self.disable_split_kv = False + + self.compilation_config = vllm_config.compilation_config + max_num_pages_per_req = cdiv( + self.model_config.max_model_len, self.kv_cache_spec.block_size + ) + max_num_reqs = vllm_config.scheduler_config.max_num_seqs + max_num_pages = max_num_reqs * max_num_pages_per_req + speculative_config = vllm_config.speculative_config + num_spec_tokens = ( + speculative_config.num_speculative_tokens + if speculative_config is not None + else 0 + ) + self.enable_cuda_graph = ( + self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + ) + if 
self.enable_cuda_graph: + # For full cudagraph capture, one `decode_wrapper` for each batch + # size is needed for FlashInfer. + self._decode_wrappers_cudagraph: dict[ + int, BatchDecodeWithPagedKVCacheWrapper + ] = {} + self._decode_cudagraph_max_bs = min( + (1 + num_spec_tokens) * max_num_reqs, + self.compilation_config.max_cudagraph_capture_size, + ) + + try: + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + self.dcp_kv_cache_interleave_size = ( + vllm_config.parallel_config.dcp_kv_cache_interleave_size + ) + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + self.dcp_kv_cache_interleave_size = 1 + + self.num_qo_heads = self.model_config.get_num_attention_heads( + self.vllm_config.parallel_config + ) + + self.num_kv_heads = self.kv_cache_spec.num_kv_heads + self.head_dim = self.kv_cache_spec.head_size + self.page_size = self.kv_cache_spec.block_size + + self.cache_dtype = self.cache_config.cache_dtype + if self.cache_dtype.startswith("fp8"): + self.kv_cache_dtype = MacaFlashInferBackend.get_fp8_dtype_for_flashinfer( + self.cache_dtype + ) + else: + assert self.kv_cache_spec.dtype == self.model_config.dtype + self.kv_cache_dtype = self.kv_cache_spec.dtype + + # Use model dtype as q dtype when TRTLLM attn is not supported, or + # --attention-config.disable_flashinfer_q_quantization is set to 1. Otherwise, + # try to use fp8 q if kv cache is fp8, and will fall back to model dtype + # if TRTLLM attention kernel is not used when building attn metadata + can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads) + if ( + can_use_trtllm + and not vllm_config.attention_config.disable_flashinfer_q_quantization + ): + self.q_data_type = self.kv_cache_dtype + else: + self.q_data_type = self.model_config.dtype + + # Prefer TRTLLM attention for decoding in all cases. + # This allows us to use AttentionCGSupport.UNIFORM_BATCH mode. 
+ self.use_trtllm_decode_attention = can_use_trtllm + self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm) + + self._cascade_wrapper = None # Wrapper for cascade attention + + # Global hyperparameters shared by all attention layers + # TODO: discard this for trtllm-gen backend + self.global_hyperparameters = infer_global_hyperparameters( + get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl) + ) + self.sm_scale = self.global_hyperparameters.sm_scale + self.window_left = self.global_hyperparameters.window_left + self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap + self.has_sinks = self.global_hyperparameters.has_sinks + if self.has_sinks and not can_use_trtllm: + raise NotImplementedError( + "FlashInfer backend currently does not support attention " + "sinks, please use trtllm on blackwell or flash attention on " + "earlier GPUs." + ) + # Preparing persistent buffers (device-side) + self.paged_kv_indptr = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device=self.device + ) + self.paged_kv_indices = torch.zeros( + max_num_pages, # max num pages possible + dtype=torch.int32, + device=self.device, + ) + self.paged_kv_last_page_len = torch.zeros( + max_num_reqs, dtype=torch.int32, device=self.device + ) + # host-side buffer + pin_memory = is_pin_memory_available() + self.paged_kv_indptr_cpu = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy() + self.paged_kv_indptr_buffer = torch.zeros_like( + self.paged_kv_indptr_cpu, pin_memory=pin_memory + ) + self.paged_kv_indices_cpu = torch.zeros( + max_num_pages, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self.paged_kv_last_page_len_cpu = torch.zeros( + max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy() + + if self.head_dim == 256: + # 
https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that + # head size 256 and block size 16 is not supported on blackwell. + assert kv_cache_spec.block_size != 16, ( + "There is a bug in FlashInfer " + "block_size 16 head size 256 support. Please avoid this combination by " + "passing --block-size 32 or --block-size 64." + ) + + @classmethod + def get_cudagraph_support( + cls: type["FlashInferMetadataBuilder"], + vllm_config: VllmConfig, + kv_cache_spec: AttentionSpec, + ) -> AttentionCGSupport: + has_trtllm_support = can_use_trtllm_attention( + num_qo_heads=vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config + ), + num_kv_heads=kv_cache_spec.num_kv_heads, + ) + if has_trtllm_support: + return AttentionCGSupport.UNIFORM_BATCH + else: + return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + + def _get_workspace_buffer(self): + if self._workspace_buffer is None: + buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE + if vllm_is_batch_invariant(): + buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT + self._workspace_buffer = torch.zeros( + buffer_size, dtype=torch.uint8, device=self.device + ) + return self._workspace_buffer + + def set_workspace_buffer(self, workspace_buffer: torch.Tensor): + self._workspace_buffer = workspace_buffer + + def _get_prefill_wrapper( + self, + ) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper: + if self._prefill_wrapper is None: + if self.dcp_world_size > 1: + self._prefill_wrapper = BatchDCPPrefillWrapper( + workspace_buffer=self._get_workspace_buffer(), + ) + else: + self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( + self._get_workspace_buffer(), get_kv_cache_layout() + ) + assert self._prefill_wrapper is not None + return self._prefill_wrapper + + def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False): + if use_cudagraph: + decode_wrapper = self._decode_wrappers_cudagraph.get(batch_size, None) + else: + decode_wrapper = 
self._decode_wrapper + + if decode_wrapper is None: + if use_cudagraph: + paged_kv_indptr = self.paged_kv_indptr[: batch_size + 1] + paged_kv_indices = self.paged_kv_indices + paged_kv_last_page_len = self.paged_kv_last_page_len[:batch_size] + else: + paged_kv_indptr = None + paged_kv_indices = None + paged_kv_last_page_len = None + decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self._get_workspace_buffer(), + get_kv_cache_layout(), + use_cuda_graph=use_cudagraph, + paged_kv_indptr_buffer=paged_kv_indptr, + paged_kv_indices_buffer=paged_kv_indices, + paged_kv_last_page_len_buffer=paged_kv_last_page_len, + # Tensor cores are enabled by default because the perf would be + # at least as good as cuda cores for all attention ops in latest + # gpus. + use_tensor_cores=True, + ) + + # save the decode wrapper + if use_cudagraph: + self._decode_wrappers_cudagraph[batch_size] = decode_wrapper + else: + self._decode_wrapper = decode_wrapper + + return decode_wrapper + + # /------------------------ Metax Modification -------------------------\ + # Note: MetaX won't support cascade attention for now + # \------------------------ Metax Modification -------------------------/ + def _get_cascade_wrapper(self): + raise NotImplementedError( + "Maca FlashInfer backend does not support cascade attention." 
+ ) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlashInferMetadata: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold, + require_uniform=True, + ) + ) + + page_size = self.page_size + max_q_len = common_attn_metadata.max_query_len + max_seq_len = common_attn_metadata.max_seq_len + seq_lens = common_attn_metadata.seq_lens + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + block_table_tensor = common_attn_metadata.block_table_tensor + qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu + + if self.dcp_world_size > 1: + if num_prefills > 0: + qo_indptr_prefill_cpu = ( + qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes] + ) + query_lens_prefill_cpu = ( + qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1] + ) + seq_lens_cpu[num_decodes:] = ( + seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu + ) + + seq_lens_cpu = get_dcp_local_seq_lens( + seq_lens_cpu, + self.dcp_world_size, + self.dcp_rank, + self.dcp_kv_cache_interleave_size, + ) + + seq_lens_np = seq_lens_cpu.numpy() + num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size + + # Metax does not support cascade attention for now + use_cascade = False and common_prefix_len > 0 + if use_cascade: + # Grab the blocks of the shared prefix from the first request. 
+ assert common_prefix_len % page_size == 0 + num_common_kv_blocks = common_prefix_len // page_size + + # Create CPU versions directly for cascade (no GPU versions needed) + shared_qo_indptr_cpu = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device="cpu" + ) + shared_kv_page_indptr_cpu = torch.tensor( + [0, num_common_kv_blocks], dtype=torch.int32, device="cpu" + ) + shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks] + shared_kv_last_page_len_cpu = torch.tensor( + [page_size], dtype=torch.int32, device="cpu" + ) + + # Remove the blocks of the shared prefix from all requests. + block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] + num_blocks_np -= num_common_kv_blocks + else: + shared_qo_indptr_cpu = None + shared_kv_page_indptr_cpu = None + shared_kv_page_indices_cpu = None + shared_kv_last_page_len_cpu = None + + # write self.paged_kv_indptr_cpu inplace (0-index is always 0) + np.cumsum( + num_blocks_np, + dtype=np.int32, + out=self.paged_kv_indptr_np[1 : num_reqs + 1], + ) + # NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified + # after this line (e.g., for cuda graphs), we need to copy the data to + # self.paged_kv_indptr_buffer to avoid race condition. 
+ self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[ + : num_reqs + 1 + ] + paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1] + paged_kv_indptr.copy_( + self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True + ) + + # write self.paged_kv_indices inplace + num_actual_pages = self.paged_kv_indptr_np[num_reqs] + paged_kv_indices = self.paged_kv_indices[:num_actual_pages] + _copy_page_indices_kernel[(num_reqs,)]( + paged_kv_indices, + block_table_tensor, + block_table_tensor.stride(0), + paged_kv_indptr, + BLOCK_SIZE=1024, + ) + + # write self.paged_kv_last_page_len_cpu inplace + paged_kv_last_page_len_np = seq_lens_np % page_size + self.paged_kv_last_page_len_np[:num_reqs] = np.where( + (paged_kv_last_page_len_np == 0) & (seq_lens_np != 0), + page_size, + paged_kv_last_page_len_np, + ) + + uses_spec_reorder = self.reorder_batch_threshold > 1 + prefill_use_trtllm = use_trtllm_attention( + self.num_qo_heads, + self.num_kv_heads, + num_prefill_tokens, + max_seq_len, + self.dcp_world_size, + self.cache_dtype, + self.q_data_type, + is_prefill=True, + force_use_trtllm=self.attention_config.use_trtllm_attention, + has_sinks=self.has_sinks, + has_spec=uses_spec_reorder, + ) + decode_use_trtllm = ( + self.use_trtllm_decode_attention and self.dcp_world_size <= 1 + ) + + if not (prefill_use_trtllm and decode_use_trtllm): + if self.has_sinks: + raise NotImplementedError( + "FlashInfer backend currently does not support attention " + "sinks, please use trtllm on blackwell or flash attention " + "on earlier GPUs." + ) + + if not self.global_hyperparameters.has_same_window_lefts: + raise ValueError( + "Window left is not the same for all layers. 
" + "One potential fix is to set disable_sliding_window=True" + ) + + assert self.global_hyperparameters.has_same_all_params, ( + "FlashInfer backend currently only supports models in which " + "all layers share the same values for the following " + "hyperparameters: `window_left`, `logits_soft_cap`, " + "`sm_scale`." + ) + + # The q quantization is not supported for non-trtllm attention, + # fall back to model dtype. + self.q_data_type = self.model_config.dtype + + attn_metadata = FlashInferMetadata( + num_actual_tokens=num_actual_tokens, + q_data_type=self.q_data_type, + slot_mapping=common_attn_metadata.slot_mapping, + max_q_len=max_q_len, + max_q_len_prefill=max_q_len, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table_tensor=block_table_tensor, + prefill_use_trtllm=prefill_use_trtllm, + decode_use_trtllm=decode_use_trtllm, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + use_cascade=use_cascade, + ) + + paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs] + paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs] + + if attn_metadata.use_cascade: + raise NotImplementedError( + "Maca FlashInfer backend does not support cascade attention." + ) + else: + # Regular attention (common case). + # Decodes are at the front and prefills are at the back. 
+ num_prefills = attn_metadata.num_prefills + num_decodes = attn_metadata.num_decodes + if num_prefills > 0: + # Decodes are first so prefills start after the last decode + prefill_start = num_decodes + attn_metadata.prefill_wrapper = self._get_prefill_wrapper() + assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 + assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 + assert ( + paged_kv_last_page_len_cpu[prefill_start:].shape[0] == num_prefills + ) + # Since prefill_wrapper.run() will be called with + # query[num_decode_tokens:] we need to adjust the qo_indptr + # to be relative to the start of the prefill queries. + qo_indptr_cpu = ( + qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start] + ) + paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:] + + # Recompute max_q_len for the slice of requests we are using + # for prefills. This can be different from max_q_len when + # we have a non-uniform batch with some short decodes offloaded + # to the prefill pathway + query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1] + attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item()) + + if not attn_metadata.prefill_use_trtllm: + if self.dcp_world_size > 1: + assert isinstance( + attn_metadata.prefill_wrapper, BatchDCPPrefillWrapper + ) + attn_metadata.prefill_wrapper.plan( + qo_indptr_cpu=qo_indptr_cpu, + paged_kv_indptr_cpu=paged_kv_indptr_cpu, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu, + prefill_start=prefill_start, + page_size=self.page_size, + num_qo_heads=self.num_qo_heads, + dcp_world_size=self.dcp_world_size, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_cache_dtype=self.kv_cache_dtype, + prefill_fixed_split_size=self.prefill_fixed_split_size, + disable_split_kv=self.disable_split_kv, + ) + else: + 
assert isinstance( + attn_metadata.prefill_wrapper, + BatchPrefillWithPagedKVCacheWrapper, + ) + attn_metadata.prefill_wrapper.plan( + qo_indptr_cpu, + paged_kv_indptr_cpu, + paged_kv_indices, + paged_kv_last_page_len_cpu[prefill_start:], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + fixed_split_size=self.prefill_fixed_split_size, + disable_split_kv=self.disable_split_kv, + ) + else: + attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to( + self.device, non_blocking=True + ) + attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to( + self.device, non_blocking=True + ) + + if num_decodes > 0: + pure_decode = num_prefills == 0 + use_cudagraph = ( + self.enable_cuda_graph + and pure_decode + and num_decode_tokens <= self._decode_cudagraph_max_bs + ) + num_input_tokens = num_decode_tokens + + attn_metadata.decode_wrapper = self._get_decode_wrapper( + num_input_tokens, use_cudagraph + ) + if not attn_metadata.decode_use_trtllm: + # Use the persistent buffer with padding length, + # instead of the same address but chunked version + # in atten_metadata when using cudagraph. + fast_plan_decode( + attn_metadata.decode_wrapper, + self.paged_kv_indptr_cpu[: num_input_tokens + 1], + paged_kv_indices, + self.paged_kv_last_page_len_cpu[:num_input_tokens], + seq_lens_cpu[:num_input_tokens], + self.num_qo_heads * self.dcp_world_size, + self.num_kv_heads, + self.head_dim, + self.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. 
+ pos_encoding_mode="NONE", + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + fixed_split_size=self.decode_fixed_split_size, + disable_split_kv=self.disable_split_kv, + ) + return attn_metadata + + def use_cascade_attention(self, *args, **kwargs) -> bool: + if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype: + # TODO: The cascade wrapper currently does not support setting + # kv cache dtype to something different from query dtype. + return False + # TODO: Cascade attention doesn't work, disable it for now + # return use_cascade_attention(*args, **kwargs) + return False + + +class FlashInferImpl(AttentionImpl): + can_return_lse_for_decode: bool = True + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: list[float] | None, + sliding_window: int | None, + kv_cache_dtype: str, + logits_soft_cap: float | None = None, + attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: int | None = None, + sinks: torch.Tensor | None = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) + self.window_left = ( + self.sliding_window[0] if self.sliding_window is not None else -1 + ) + self.kv_cache_dtype = kv_cache_dtype + self.logits_soft_cap = logits_soft_cap + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if attn_type != AttentionType.DECODER: + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are 
not implemented for " + "FlashInferImpl" + ) + + self.sinks: torch.Tensor | None = None + # if sinks is not None: + # if sinks.shape[0] != num_heads: + # raise ValueError( + # "Sinks must have the same number of heads as the number of " + # f"heads in the layer. Expected {num_heads}, but got " + # f"{sinks.shape[0]}." + # ) + # self.sinks = sinks + self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads) + vllm_config = get_current_vllm_config() + self.supports_quant_query_input = ( + self.support_trtllm_attn + and not vllm_config.attention_config.disable_flashinfer_q_quantization + ) + self.bmm1_scale: float | None = None + self.bmm2_scale: float | None = None + self.o_sf_scale: float | None = None + + def fused_output_quant_supported(self, quant_key: QuantKey): + return ( + self.support_trtllm_attn + and self.kv_cache_dtype.startswith("fp8") + and quant_key in (kFp8StaticTensorSym, kNvfp4Quant) + ) + + # FlashInfer requires attention sinks to be float32 + def process_weights_after_loading(self, act_dtype: torch.dtype): + if self.sinks is not None and self.sinks.dtype != torch.float32: + self.sinks = self.sinks.to(torch.float32) + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashInferMetadata, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward pass with FlashInfer. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache: KV cache tensor with different possible shapes: + - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] + - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] + attn_metadata: Metadata for attention. 
+ Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + + if attn_metadata is None: + # Profiling run. + return output.fill_(0) + + # Ensure query dtype matches the expected dtype from attention metadata + assert attn_metadata.q_data_type == query.dtype, ( + f"Query dtype mismatch: expected {attn_metadata.q_data_type}, " + f"got {query.dtype}" + ) + + if self.bmm1_scale is None: + self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale + + if self.bmm2_scale is None: + self.bmm2_scale = layer._v_scale_float + + # The attn+quant fusion happens when output_scale is provided. + if output_scale is None: + assert output_block_scale is None, ( + "output_block_scale is not supported when fusion has not happened" + ) + else: + assert attn_metadata.q_data_type == FP8_DTYPE, ( + "Query must be FP8 when attn+quant fusion happened." + ) + assert ( + attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm + ), "Must use TRT-LLM attn" + + if output.dtype == FP8_DTYPE: + assert output_block_scale is None, ( + "output_block_scale should not be provided for fp8 output" + ) + elif output.dtype == FP4_DTYPE: + assert output_block_scale is not None, ( + "output_block_scale is required for nvfp4 output" + ) + else: + raise ValueError(f"Unsupported output dtype: {output.dtype}") + + # TRTLLM attn kernel requires to scale to pass as a host scalar, + # store the o scale as a host scalar in warmup run with cuda graph + # not enabled + if layer._o_scale_float is None: + layer._o_scale_float = output_scale.cpu().item() + if output.dtype == FP8_DTYPE: + self.bmm2_scale = self.bmm2_scale / layer._o_scale_float + elif output.dtype == FP4_DTYPE: + self.o_sf_scale = layer._o_scale_float + + # IMPORTANT! + # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. 
For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + num_actual_tokens = attn_metadata.num_actual_tokens + + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] + # and value[:num_actual_tokens] because the reshape_and_cache_flash + # op uses the slot_mapping's shape to determine the number of + # actual tokens. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 + # to process the cache when the kv_cache_dtype is fp8 + if self.kv_cache_dtype.startswith("fp8"): + torch_dtype = MacaFlashInferBackend.get_fp8_dtype_for_flashinfer( + self.kv_cache_dtype + ) + kv_cache = kv_cache.view(torch_dtype) + + # Inputs and outputs may be padded for CUDA graphs + query = query[:num_actual_tokens] + key = key[:num_actual_tokens] + value = value[:num_actual_tokens] + output_padded = output + output = output[:num_actual_tokens] + + if attn_metadata.use_cascade: + # Cascade attention (rare case). + assert attn_metadata.cascade_wrapper is not None + output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache)) + return output + + # When using spec decoding, num_decodes can be < num_decode_tokens + # because some decode requests may have more than one query token. 
+ num_decodes = attn_metadata.num_decodes + num_decode_tokens = attn_metadata.num_decode_tokens + num_prefill_tokens = attn_metadata.num_prefill_tokens + + stride_order = MacaFlashInferBackend.get_kv_cache_stride_order() + kv_cache_permute = kv_cache.permute(*stride_order) + # Regular attention (common case). + # Decodes are at the front and prefills are at the back. + if num_prefill_tokens > 0: + prefill_wrapper = attn_metadata.prefill_wrapper + prefill_query = query[num_decode_tokens:] + assert prefill_query.shape[0] == num_prefill_tokens + assert prefill_wrapper is not None + + if not attn_metadata.prefill_use_trtllm: + if self.dcp_world_size > 1: + assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper) + assert prefill_wrapper._context._window_left == self.window_left + assert prefill_wrapper._context._logits_soft_cap == ( + self.logits_soft_cap or 0.0 + ) + assert prefill_wrapper._context._sm_scale == self.scale + assert not prefill_wrapper._context._causal + assert prefill_wrapper._new_tokens._window_left == self.window_left + assert prefill_wrapper._new_tokens._logits_soft_cap == ( + self.logits_soft_cap or 0.0 + ) + assert prefill_wrapper._new_tokens._sm_scale == self.scale + assert prefill_wrapper._new_tokens._causal + + prefill_wrapper.run( + layer, + prefill_query, + kv_cache_permute, + key[num_decode_tokens:], + value[num_decode_tokens:], + out=output[num_decode_tokens:], + ) + else: + assert isinstance( + prefill_wrapper, BatchPrefillWithPagedKVCacheWrapper + ) + assert prefill_wrapper._window_left == self.window_left + assert prefill_wrapper._logits_soft_cap == ( + self.logits_soft_cap or 0.0 + ) + assert prefill_wrapper._sm_scale == self.scale + assert prefill_wrapper._causal + prefill_wrapper.run( + prefill_query, + kv_cache_permute, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output[num_decode_tokens:], + ) + else: + # prefill_query may be non-contiguous + raise NotImplementedError( + "prefill_use_trtllm prefill 
attention is not implemented on MACA yet." + ) + + if num_decode_tokens > 0: + decode_wrapper = attn_metadata.decode_wrapper + decode_query = query[:num_decode_tokens] + assert decode_query.shape[0] == num_decode_tokens + assert decode_wrapper is not None + + if not attn_metadata.decode_use_trtllm: + assert decode_wrapper._window_left == self.window_left + assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) + assert decode_wrapper._sm_scale == self.scale + + if self.dcp_world_size > 1: + decode_query = get_dcp_group().all_gather( + decode_query.contiguous(), dim=-2 + ) + output_tmp = torch.empty_like(decode_query) + lse = torch.empty( + (decode_query.size(0), decode_query.size(1)), + dtype=torch.float32, + device=decode_query.device, + ) + decode_wrapper.run( + decode_query, + kv_cache_permute, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output_tmp, + lse=lse, + return_lse=True, + ) + output[:num_decode_tokens] = cp_lse_ag_out_rs( + output_tmp, + lse, + get_dcp_group(), + is_lse_base_on_e=False, + ) + else: + decode_wrapper.run( + decode_query, + kv_cache_permute, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output[:num_decode_tokens], + ) + else: + raise NotImplementedError( + "prefill_use_trtllm prefill attention is not implemented on MACAyet." 
+ ) + return output_padded + + +def fast_plan_decode( + self, # decode wrapper + indptr_cpu: torch.Tensor, + indices: torch.Tensor, + last_page_len_cpu: torch.Tensor, + seq_lens_cpu: torch.Tensor, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + page_size: int, + pos_encoding_mode: str = "NONE", + window_left: int = -1, + logits_soft_cap: float | None = None, + q_data_type: str | torch.dtype | None = "float16", + kv_data_type: str | torch.dtype | None = None, + data_type: str | torch.dtype | None = None, + sm_scale: float | None = None, + rope_scale: float | None = None, + rope_theta: float | None = None, + non_blocking: bool = True, + fixed_split_size: int = -1, + disable_split_kv: bool = False, +) -> None: + """ + A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for + cudagraph capture/replay, while the no cudagraph version turns back + to the original plan. + using original plan after passing host-side buffers: + - only host-to-device copy of indptr and last_page_len buffers + Modifications for cudagraph: + - only host-to-device copy of indptr and last_page_len buffers. + - avoid device-to-device copy of indices buffer. + + Part of the code get inspiration from the original plan from FlashInfer repo + and the implementation of fast_decode_plan for FlashInfer in SGlang repo. + """ + # Warm up with the original plan if it is first call, and always run the + # original plan if we run for dynamic shape. For fixed shape (cudagraph), + # this warm up is to generate the _cached_module for the decode wrapper. 
+ if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True): + self.plan( + indptr_cpu, + indices, + last_page_len_cpu, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + pos_encoding_mode, + window_left, + logits_soft_cap, + q_data_type, + kv_data_type, + data_type, + sm_scale, + rope_scale, + rope_theta, + non_blocking, + # --- maca does not support args below + # None, # block_tables + # None, # seq_lens + # fixed_split_size, + # disable_split_kv, + ) + self.vllm_first_call = False + return + + assert self.is_cuda_graph_enabled, "Should be cudagraph only here" + + batch_size = len(last_page_len_cpu) + if logits_soft_cap is None: + logits_soft_cap = 0.0 + + # Handle data types consistently + if data_type is not None: + if q_data_type is None: + q_data_type = data_type + if kv_data_type is None: + kv_data_type = data_type + elif q_data_type is None: + q_data_type = "float16" + + if kv_data_type is None: + kv_data_type = q_data_type + q_data_type = ( + getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type + ) + kv_data_type = ( + getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type + ) + + if batch_size != self._fixed_batch_size: + raise ValueError( + "The batch size should be fixed in cudagraph mode, the runtime " + "batch size {} mismatches the batch size set during " + "initialization {}".format(batch_size, self._fixed_batch_size) + ) + if len(indices) > len(self._paged_kv_indices_buf): + raise ValueError( + "The size of indices should be less than or equal to the allocated buffer" + ) + + # host-to-device copy for the indptr buffer + self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True) + # host-to-device copy for the last_page_len buffer + self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True) + + qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") + + try: + # Make sure we pass exactly 15 arguments for tensor core version + device = 
self._float_workspace_buffer.device
+        self._plan_info = self._cached_module.plan(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            qo_indptr_host,
+            indptr_cpu,
+            seq_lens_cpu,
+            batch_size,  # total_num_rows
+            batch_size,
+            num_qo_heads,
+            num_kv_heads,
+            page_size,
+            self.is_cuda_graph_enabled,
+            head_dim,  # head_dim_qk and head_dim_vo are passed the same value here
+            head_dim,
+            False,  # causal
+            torch.cuda.current_stream(device).cuda_stream,
+        )
+    except Exception as e:
+        raise RuntimeError(f"Error in tensor core plan: {e}") from e
+
+    self._pos_encoding_mode = pos_encoding_mode  # cache plan hyperparameters on the wrapper so run() sees them
+    self._window_left = window_left
+    self._logits_soft_cap = logits_soft_cap
+    self._sm_scale = sm_scale
+    self._rope_scale = rope_scale
+    self._rope_theta = rope_theta
+
+
+@triton.jit
+def _copy_page_indices_kernel(
+    page_indices,  # out: flattened page indices, segment per request per cu_num_blocks
+    block_table,  # in: [num_reqs, max_blocks] logical->physical block table
+    block_table_stride,  # row stride of block_table
+    cu_num_blocks,  # in: indptr-style cumulative block counts, leading 0, len num_reqs + 1
+    BLOCK_SIZE: tl.constexpr,
+):
+    req_idx = tl.program_id(0)  # one program per request row
+    row_ptr = block_table + req_idx * block_table_stride
+    start_idx = tl.load(cu_num_blocks + req_idx)
+    end_idx = tl.load(cu_num_blocks + req_idx + 1)
+    num_blocks = end_idx - start_idx
+
+    offset = tl.arange(0, BLOCK_SIZE)
+    for i in tl.range(0, num_blocks, BLOCK_SIZE):  # copy the row in BLOCK_SIZE chunks, masked at the tail
+        block_ids = tl.load(row_ptr + i + offset, mask=i + offset < num_blocks)
+        tl.store(
+            page_indices + start_idx + i + offset,
+            block_ids,
+            mask=i + offset < num_blocks,
+        )
diff --git a/vllm_fl/v1/attention/backends/flex_attention.py b/vllm_fl/v1/attention/backends/flex_attention.py
new file mode 100644
index 00000000..ea0994a2
--- /dev/null
+++ b/vllm_fl/v1/attention/backends/flex_attention.py
@@ -0,0 +1,1031 @@
+# SPDX-License-Identifier: Apache-2.0
+# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with FlexAttention.""" + +import math +from dataclasses import dataclass +from functools import cached_property +from typing import ClassVar + +import torch +import torch._dynamo.decorators +import torch.nn.functional as F +from torch.nn.attention.flex_attention import ( + BlockMask, + _mask_mod_signature, + _score_mod_signature, + and_masks, + create_block_mask, + flex_attention, + or_masks, +) + +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionType, + is_quantized_kv_cache, +) +from vllm.config import VllmConfig +from vllm.config.cache import CacheDType +from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) +from vllm.platforms import current_platform +from vllm.utils.math_utils import cdiv +from vllm.utils.torch_utils import is_torch_equal_or_newer +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, +) +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.attention.backends.registry import AttentionBackendEnum, register_backend + +logger = init_logger(__name__) + +torch._dynamo.config.recompile_limit = 16 +create_block_mask_compiled = torch.compile( + create_block_mask, fullgraph=True, mode="reduce-overhead" +) +flex_attention_compiled = torch.compile(flex_attention, fullgraph=True) + + +def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor: + device = offsets.device + counts = offsets[1:] - offsets[:-1] + return torch.repeat_interleave( + torch.arange(len(counts), device=device, dtype=torch.int32), counts + ) + + +def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): + difference = (multiple - (x.shape[dim] % multiple)) % multiple + if difference == 0: + return x + + dim = dim if dim >= 0 else x.ndim + dim + pad_list = [] + + for i in range(x.ndim - 1, dim - 1, -1): + if i 
== dim: + pad_list.extend([0, difference]) + else: + pad_list.extend([0, 0]) + + return F.pad(x, pad_list, mode="constant", value=0) + + +@register_backend(AttentionBackendEnum.FLEX_ATTENTION) +class MacaFlexAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True + supported_dtypes: ClassVar[list[torch.dtype]] = [ + torch.float16, + torch.bfloat16, + torch.float32, + ] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"] + + @staticmethod + def get_name() -> str: + return "FLEX_ATTENTION" + + @classmethod + def supports_attn_type(cls, attn_type: str) -> bool: + """FlexAttention supports both decoder and encoder-only attention.""" + return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY) + + @classmethod + def supports_mm_prefix(cls) -> bool: + """FlexAttention supports full attention for image tokens.""" + return True + + @staticmethod + def get_impl_cls() -> type["FlexAttentionImpl"]: + return FlexAttentionImpl + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_builder_cls() -> type["FlexAttentionMetadataBuilder"]: + return FlexAttentionMetadataBuilder + + @staticmethod + def use_cascade_attention(*args, **kwargs) -> bool: + return False + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [] + + +# @torch.compile(fullgraph=True, mode="reduce-overhead") +def physical_to_logical_mapping( + block_table: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + total_blocks: int, +) -> torch.Tensor: + """ + Creates an inverse mapping from physical block locations to logical indices. 
+ + The original block_table maps from logical blocks to physical locations: + + Logical to Physical (Original block_table): + ┌───────────────────────────────────────────┐ + │ Request 0: │ + │ │ + │ Logical Blocks: 0 1 2 3 4 5 6 7 │ + │ │ │ │ │ │ │ │ │ │ + │ v v v v v v v v │ + │ Physical Blocks: 3 5 1 7 4 2 0 6 │ + └───────────────────────────────────────────┘ + + This function creates the inverse mapping: + + Physical to Logical (Inverse mapping): + ┌───────────────────────────────────────────┐ + │ Request 0: │ + │ │ + │ Physical Blocks: 0 1 2 3 4 5 6 7 │ + │ │ │ │ │ │ │ │ │ │ + │ v v v v v v v v │ + │ Logical Blocks: 6 2 5 0 4 1 7 3 │ + └───────────────────────────────────────────┘ + + If multiple logical blocks map to the same physical block, + this function returns the latest (maximum) logical block index. + + If a physical block is not mapped to by any logical block, + its value in the result will be -1. + + IMPORTANT: Garbage Value Protection + ──────────────────────────────────── + The block_table tensor may contain garbage values in unused positions + (beyond the actual sequence length). For example, if a sequence only + needs 3 blocks but the table has space for 8: + + block_table[0] = [10, 25, 7, 999, 1234, 888, ...] + ^^^^^^^^^^^^^^^^^^^^ + garbage values + + These garbage values can cause issues because: + 1. They may map to valid physical blocks by coincidence + 2. The scatter_ operation will assign them logical indices + 3. Later attention computations may incorrectly access these blocks + + To prevent this, we use seq_lens and block_size to mask out unused + entries, ensuring only valid block references are processed. + + IMPORTANT: Reused physical blocks (sliding-window / hybrid attention) + ──────────────────────────────────────────────────────────────────── + For some attention types, physical cache blocks can be reused over time. 
+ This can cause the same physical block id to appear multiple times in a row + of `block_table` at different logical block indices. In that case, only the + latest logical block index corresponds to the current contents of that + physical block. Therefore, the inverse mapping must pick the maximum logical + block index for each physical block id. + + Args: + block_table: Tensor of shape [max_reqs, max_num_blocks] + mapping logical blocks to physical locations. May contain + garbage values in unused positions. + seq_lens: Tensor of sequence lengths for each request. Used to + determine how many blocks are actually needed per sequence. + block_size: Size of each block in tokens. Used with seq_lens to + compute the number of valid blocks per sequence. + total_blocks: Total number of physical blocks available + + Returns: + A tensor of shape [max_reqs, total_blocks] where each entry + physical_to_logical[req_id, physical_block] contains the logical + block index for that physical block, or -1 if unused. 
+ """ + max_reqs, max_num_blocks = block_table.shape + device = block_table.device + + physical_to_logical = torch.full( + (max_reqs, total_blocks), -1, dtype=torch.long, device=device + ) + + # Only process valid blocks to avoid garbage values + num_blocks_per_seq = cdiv(seq_lens, block_size) + mask = ( + torch.arange(max_num_blocks, device=device)[None, :] + < num_blocks_per_seq[:, None] + ) + + valid_block_table = torch.where(mask, block_table, 0) + valid_logical_indices = torch.where( + mask, torch.arange(max_num_blocks, device=device)[None, :], 0 + ) + + physical_to_logical.scatter_reduce_( + -1, valid_block_table.to(torch.int64), valid_logical_indices, reduce="amax" + ) + # NB - Seems like block 0 is always empty so we reset it manually + physical_to_logical[:, 0] = -1 + return physical_to_logical + + +def unique_static_unsorted( + x: torch.Tensor, + *, + M: int, # maximum positive value (0 is “skip me”) + dim: int = -1, # axis along which to deduplicate + ignored_val: int = 0, # value to ignore + pad_val: int = -1, # sentinel for unused slots +) -> torch.Tensor: + """ + - Keeps the first occurrence of each non-zero value while preserving order, + then left-packs those uniques and fills the rest with `pad_val`. + - Returns (packed, keep_mask) with the *same shape* as `x`. + - Requires that all values be in the range [0, M] + - Skips ignored_val + + Works on CPU or GPU, no Python loops, O(B·N) time / O(B·M) memory. 
+ + Example: + x =[3, 1, 0, 1, 2], M=3, ignored_val=0 => [3, 1, 2, -1, -1] + """ + if not (-1 <= pad_val <= M): + raise ValueError("`pad_val` must lie in [-1, M]") + + # ── move `dim` to the end so we can treat tensor as [B, N] ────────── + dim = dim % x.ndim + x_perm = x.movedim(dim, -1) # shape [..., N] + B, N = x_perm.numel() // x_perm.shape[-1], x_perm.shape[-1] + x_flat = x_perm.reshape(B, N) # [B, N] + + device = x.device + idx = torch.arange(N, device=device).expand(B, N) # per-row indices + + # ── build first-occurrence table for every v ∈ [0, M] ─────────────── + first_idx = torch.full((B, M + 1), N, device=device) # “∞” + # scatter_reduce_: first_idx[b, v] = min(first_idx[b, v], i) for each i + first_idx.scatter_reduce_(1, x_flat, idx, reduce="amin") + + # ── keep mask: first occurrence *and* value ≠ 0 ───────────────────── + keep = (x_flat != ignored_val) & (idx == first_idx.gather(1, x_flat)) # [B, N] + + # ── left-pack uniques into a fresh tensor ─────────────────────────── + dest_pos = torch.cumsum(keep.to(torch.long), dim=1) - 1 # where to go + packed_flat = torch.full_like(x_flat, pad_val) + + rows, src_cols = torch.nonzero(keep, as_tuple=True) + packed_flat[rows, dest_pos[rows, src_cols]] = x_flat[rows, src_cols] + + # ── restore original layout ───────────────────────────────────────── + packed = packed_flat.reshape(x_perm.shape).movedim(-1, dim) + return packed + + +def causal_mask_mod( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor +): + return q_idx >= kv_idx + + +@dataclass +class FlexAttentionMetadata: + causal: bool + num_actual_tokens: int # Number of tokens excluding padding. 
+ max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + use_cascade: bool + common_prefix_len: int + cu_prefix_query_lens: torch.Tensor | None + prefix_kv_lens: torch.Tensor | None + suffix_kv_lens: torch.Tensor | None + + # Block info + total_cache_tokens: int + block_size: int + max_possible_sequence_length: int + num_reqs: int + physical_to_logical: torch.Tensor + decode_offset: torch.Tensor + num_blocks_per_seq: torch.Tensor + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. + + # Flex Metadata + num_blocks = 0 + block_mask: BlockMask | None = None + score_mod: _score_mod_signature | None = None + logical_mask_mod: _mask_mod_signature = causal_mask_mod + doc_ids: torch.Tensor | None = None + direct_build: bool = True + q_block_size: int = 16 + kv_block_size: int = 16 + transformed_score_mod: _score_mod_signature | None = None + sliding_window: int | None = None + mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None + + @cached_property + def logical_block_ids(self): + return torch.arange( + cdiv(self.max_seq_len, self.block_size), + device=self.block_table.device, + dtype=torch.long, + ) + + def _convert_physical_to_logical( + self, + request_lookup: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert physical indices to logical indices for both query and kv. + + NB is_within_lower_bound: do sequences start on block_boundaries? 
+
+        Returns:
+            tuple of (is_valid, logical_q_idx, logical_kv_idx)
+        """
+        # Map query indices to corresponding request indices
+        q_req = request_lookup[q_idx]
+
+        # Convert physical KV indices to logical indices
+        physical_kv_block = physical_kv_idx // self.block_size
+        physical_kv_offset = physical_kv_idx % self.block_size
+        logical_block_idx = self.physical_to_logical[q_req, physical_kv_block]
+        logical_kv_idx = logical_block_idx * self.block_size + physical_kv_offset
+
+        # Determine valid kv indices
+        live_block = logical_block_idx >= 0
+        within_upper_bound = logical_kv_idx < self.seq_lens[q_req]
+        within_lower_bound = logical_kv_idx >= 0
+        is_valid = live_block & within_upper_bound & within_lower_bound
+
+        # Convert physical query indices to logical indices
+        local_q_idx = q_idx - self.query_start_loc[q_req]
+        logical_q_idx = local_q_idx + self.decode_offset[q_req]
+
+        return is_valid, logical_q_idx, logical_kv_idx
+
+    def get_causal_mask_mod(self) -> _mask_mod_signature:
+        """Creates the mask_mod function for FlexAttention.
+
+        This function creates the combined mask mod function that handles:
+        1. The paged attention block mapping
+        2. The mapping from packed query sequences to logical query entries
+
+        It also by default adds the decoding offset to the query indices.
+        With this info we create the "logical" indices that are passed to
+        mask_mod functions. This allows mask mod functions to be agnostic to
+        layout of the query and key/value tensors.
+ """ + assert self.doc_ids is not None + + def final_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ) -> torch.Tensor: + (is_valid, logical_q_idx, logical_kv_idx) = ( + self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx) + ) + # Apply mask modification only for valid indices + return torch.where( + is_valid, + self.logical_mask_mod(b, h, logical_q_idx, logical_kv_idx), + False, + ) + + return final_mask_mod + + def get_bidirectional_mask_mod(self) -> _mask_mod_signature: + """Creates the encoder mask_mod function for FlexAttention. + + Since the encoder bidirectional attention doesn't run with + KV cache, this function creates a mask based on the + packed query sequences. + """ + # Create a lookup mapping from query indices -> request number + request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc) + + def final_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + ) -> torch.Tensor: + return request_lookup[q_idx] == request_lookup[kv_idx] + + return final_mask_mod + + def get_sliding_window_mask_mod(self) -> _mask_mod_signature: + """Creates the sliding window mask_mod function for FlexAttention. + + Note that the sliding window mask here is bidirectional, we need + to mask it with the bidirectional/causal mask for encoder/decoder. 
+ """ + + if self.sliding_window is None: + raise ValueError("sliding_window must be set for sliding window attention") + + def sliding_window_mask_mod( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor + ): + return torch.abs(q_idx - kv_idx) < self.sliding_window + + def final_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ) -> torch.Tensor: + (is_valid, logical_q_idx, logical_kv_idx) = ( + self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx) + ) + return torch.where( + is_valid, + sliding_window_mask_mod(b, h, logical_q_idx, logical_kv_idx), + False, + ) + + return final_mask_mod if self.causal else sliding_window_mask_mod + + def get_prefix_lm_mask_mod(self) -> _mask_mod_signature: + """Creates the prefix LM mask_mod function for FlexAttention.""" + + assert self.doc_ids is not None + request_lookup = self.doc_ids + + def prefix_lm_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + cu_q_idx: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + ): + mask = torch.zeros_like(q_idx, dtype=torch.bool) + for req, doc_range_lst in (self.mm_prefix_range or {}).items(): + req_mask = request_lookup[cu_q_idx] == req + for start, end in doc_range_lst: + doc_mask_q = (q_idx >= start) & (q_idx <= end) + doc_mask_kv = (kv_idx >= start) & (kv_idx <= end) + mask = mask | (req_mask & doc_mask_q & doc_mask_kv) + return mask + + def final_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ) -> torch.Tensor: + (is_valid, logical_q_idx, logical_kv_idx) = ( + self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx) + ) + return torch.where( + is_valid, + prefix_lm_mask_mod(b, h, q_idx, logical_q_idx, logical_kv_idx), + False, + ) + + return final_mask_mod + + def get_mask_mod(self): + # Stage-1: initialize the base mask_mod + # (causal mask for decoder or bidirectional mask for encoder) + if 
self.causal: + mask_mod = self.get_causal_mask_mod() + else: + mask_mod = self.get_bidirectional_mask_mod() + # stage-2: add external mask_mod for special attention during + # forwarding runtime to create the combined mask_mod. + if self.sliding_window is not None: + # Add sliding window mask for sliding window attention + sliding_window_mask_mod = self.get_sliding_window_mask_mod() + mask_mod = and_masks(mask_mod, sliding_window_mask_mod) + if self.mm_prefix_range: + # Add prefix LM mask for vision-language prefix LM attention + prefix_lm_mask_mod = self.get_prefix_lm_mask_mod() + mask_mod = or_masks(mask_mod, prefix_lm_mask_mod) + return mask_mod + + def get_transformed_score_mod(self) -> _score_mod_signature | None: + """Creates the transformed score_mod function for FlexAttention. + + This function wraps the user's score_mod to handle physical-to-logical + index conversion, similar to how get_mask_mod works for mask functions. + """ + if self.score_mod is None: + return None + + # Create a lookup mapping from query indices -> request number + request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc) + user_score_mod = self.score_mod + + def transformed_score_mod( + score: torch.Tensor, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ) -> torch.Tensor: + (is_valid, logical_q_idx, logical_kv_idx) = ( + self._convert_physical_to_logical( + request_lookup, q_idx, physical_kv_idx + ) + ) + + return torch.where( + is_valid, + user_score_mod( + score, b, h, logical_q_idx, logical_kv_idx, physical_q=q_idx + ), + -float("inf"), + ) + + return transformed_score_mod + + def _build_block_mask_direct(self) -> BlockMask: + """Direct block mask construction for standard causal attention. + + This method constructs the block mask directly using + BlockMask.from_kv_blocks which is much more efficient than the + generic create_block_mask approach. + + The direct path works as follows: + 1. 
For each query token, fetch blocks from block_table using max_seq_len + and exclude out of sliding window blocks if needed. + (this fetches more blocks than needed for shorter sequences) + 2. Group query tokens into chunks of q_block_size + 3. For each group, deduplicate the blocks using unique_static_unsorted + 4. Create BlockMask using the deduplicated block indices + + Over-estimation occurs when a group of q_block_size tokens contains + multiple sequence IDs (doc_ids). In this case, we fetch ALL blocks for + each sequence represented in the group, even though individual query + tokens may only need a subset of those blocks based on causal masking + and their position. + + """ + page_to_block_ratio = self.kv_block_size // self.block_size + if page_to_block_ratio != 1: + raise ValueError( + f"FlexAttention currently requires the cache block size " + f"({self.block_size}) to be equal to the kv_block_size " + f"({self.kv_block_size}). Please check your model's " + f"configuration." + ) + + used_pages = self.block_table[ + self.doc_ids, : cdiv(self.max_seq_len, self.block_size) + ] + + if self.sliding_window and self.causal: + device = used_pages.device + assert self.doc_ids is not None + token_indices = torch.arange( + self.doc_ids.shape[0], device=device, dtype=torch.long + ) + logical_q_idx = ( + token_indices + - self.query_start_loc[self.doc_ids] + + self.decode_offset[self.doc_ids] + ) + min_kv_idx = torch.clamp(logical_q_idx - (self.sliding_window - 1), min=0) + min_block_idx = min_kv_idx // self.block_size + sliding_mask = self.logical_block_ids >= min_block_idx[:, None] + used_pages.masked_fill_(~sliding_mask, 0) + + used_pages_padded = pad_to_multiple( + used_pages, multiple=self.q_block_size, dim=0 + ) + used_pages_padded = used_pages_padded.reshape( + used_pages_padded.shape[0] // self.q_block_size, -1 + ) + used_pages_padded = used_pages_padded // page_to_block_ratio + kv_indices = unique_static_unsorted( + (used_pages_padded.long()), M=self.num_blocks 
+ ).to(torch.int32) + + kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32) + block_mask_kwargs = { + "seq_lengths": (self.num_actual_tokens, self.total_cache_tokens), + "kv_num_blocks": kv_num_blocks[None, None], + "kv_indices": kv_indices[None, None], + "full_kv_num_blocks": None, + "full_kv_indices": None, + "BLOCK_SIZE": (self.q_block_size, self.kv_block_size), + "mask_mod": self.mask_mod, + } + + # compute_q_blocks parameter is available in PyTorch 2.9+ + if is_torch_equal_or_newer("2.9.0.dev0"): + block_mask_kwargs["compute_q_blocks"] = False + return BlockMask.from_kv_blocks(**block_mask_kwargs) + + def build_block_mask(self) -> BlockMask: + mask_mod = self.get_mask_mod() + kv_len = self.total_cache_tokens if self.causal else self.num_actual_tokens + return create_block_mask_compiled( + mask_mod, + None, + None, + self.num_actual_tokens, + kv_len, + device=self.block_table.device, + BLOCK_SIZE=(self.q_block_size, self.kv_block_size), + ) + + def __post_init__(self): + assert self.use_cascade is False, "Not implemented yet." + assert self.common_prefix_len == 0, "Not implemented yet." + assert self.cu_prefix_query_lens is None, "Not implemented yet." + assert self.prefix_kv_lens is None, "Not implemented yet." + assert self.suffix_kv_lens is None, "Not implemented yet." 
+ # Create a lookup mapping from query indices -> request number + self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc) + self.num_blocks = self.total_cache_tokens // self.block_size + + self.mask_mod = self.get_mask_mod() + self.transformed_score_mod = self.get_transformed_score_mod() + + if self.direct_build and self.causal: + self.block_mask = self._build_block_mask_direct() + else: + self.block_mask = self.build_block_mask() + + +class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + + self.model_config = vllm_config.model_config + self.parallel_config = vllm_config.parallel_config + self.cache_config = vllm_config.cache_config + + self.num_heads_q = self.model_config.get_num_attention_heads( + self.parallel_config + ) + self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) + self.headdim = self.model_config.get_head_size() + self.block_size = kv_cache_spec.block_size + self.kv_cache_spec = kv_cache_spec + supports_small_blocks = is_torch_equal_or_newer("2.9.0.dev0") + self.direct_build: bool = supports_small_blocks + self.q_block_size: int = 16 if supports_small_blocks else 128 + self.kv_block_size: int = self.block_size if supports_small_blocks else 128 + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlexAttentionMetadata: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + + max_seq_len = common_attn_metadata.max_seq_len + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = 
common_attn_metadata.slot_mapping + num_blocks_per_seq = cdiv(seq_lens, self.block_size) + + use_cascade = common_prefix_len > 0 + cu_prefix_query_lens = None + prefix_kv_lens = None + suffix_kv_lens = None + if use_cascade: + raise NotImplementedError("Not yet my friend") + + block_size = self.kv_cache_spec.block_size + max_possible_seq_len = self.model_config.max_model_len + num_gpu_blocks = self.cache_config.num_gpu_blocks + + assert num_gpu_blocks is not None, ( + "FlexAttention requires num_gpu_blocks to be set" + ) + total_cache_tokens = num_gpu_blocks * block_size + + inverse_block_table = physical_to_logical_mapping( + block_table_tensor, seq_lens, block_size, num_gpu_blocks + ) + + offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to( + self.device, non_blocking=True + ) + + out = FlexAttentionMetadata( + causal=common_attn_metadata.causal, + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table_tensor, + slot_mapping=slot_mapping, + use_cascade=use_cascade, + common_prefix_len=common_prefix_len, + cu_prefix_query_lens=cu_prefix_query_lens, + prefix_kv_lens=prefix_kv_lens, + suffix_kv_lens=suffix_kv_lens, + block_size=block_size, + max_possible_sequence_length=max_possible_seq_len, + num_reqs=num_reqs, + physical_to_logical=inverse_block_table, + total_cache_tokens=total_cache_tokens, + decode_offset=offset_tensor, + num_blocks_per_seq=num_blocks_per_seq, + # FIXME(Isotr0py): direct build has issue to build bidirectional + # attention block mask for encoder-only models, disable it temporarily. 
+ # see: https://github.com/vllm-project/vllm/pull/27329#issuecomment-3431484053 + direct_build=(self.direct_build and common_attn_metadata.causal), + q_block_size=self.q_block_size, + kv_block_size=self.kv_block_size, + ) + return out + + def use_cascade_attention(self, *args, **kwargs) -> bool: + return False + + +class FlexAttentionImpl(AttentionImpl): + sliding_window: int | None + alibi_slopes: torch.Tensor | None + logits_soft_cap: float | None + mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: list[float] | None, + sliding_window: int | None, + kv_cache_dtype: str, + logits_soft_cap: float | None = None, + attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: str | None = None, + **kwargs, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.attn_type = attn_type + + if attn_type not in (AttentionType.ENCODER_ONLY, AttentionType.DECODER): + raise NotImplementedError( + f"FlexAttention does not support {attn_type} attention" + ) + + if alibi_slopes is not None: + raise NotImplementedError( + "FlexAttention does not support alibi slopes yet." + ) + else: + self.alibi_slopes = None + + self.sliding_window = sliding_window + + self.kv_cache_dtype = kv_cache_dtype + self.logits_soft_cap = logits_soft_cap + if self.logits_soft_cap is not None: + raise NotImplementedError( + "FlexAttention does not support logits soft cap yet." + ) + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("FlexAttention does not support kv sharing yet.") + + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "FlexAttention does not support quantized kv-cache. 
Yet"
+            )
+
+    @staticmethod
+    def view_as_4d(tensor: torch.Tensor) -> torch.Tensor:
+        """View a 3d tensor as 4D."""
+        if tensor.ndim == 4:
+            return tensor
+        assert tensor.ndim == 3
+        return tensor[None, :, :, :]
+
+    def forward(
+        self,
+        layer: torch.nn.Module,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: FlexAttentionMetadata,
+        output: torch.Tensor | None = None,
+        output_scale: torch.Tensor | None = None,
+        output_block_scale: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """Forward pass with FlexAttention.
+
+        Args:
+            query: shape = [num_tokens, num_heads, head_size]
+            key: shape = [num_tokens, num_kv_heads, head_size]
+            value: shape = [num_tokens, num_kv_heads, head_size]
+            kv_cache: shape =
+                [2, num_blocks, block_size, num_kv_heads, head_size]
+            attn_metadata: Metadata for attention.
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """
+        assert output is not None, "Output tensor must be provided."
+        if output_scale is not None or output_block_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported for FlexAttentionImpl"
+            )
+
+        enable_gqa = self.num_kv_heads != self.num_heads
+
+        if attn_metadata is None:
+            # Profiling run.
+ return output.fill_(0) + # query = self.view_as_4d(query).permute(0, 2, 1, 3) + # return torch.empty_like(query) + + num_actual_tokens = attn_metadata.num_actual_tokens + + needs_rebuild_block_mask = False + if attn_metadata.sliding_window != self.sliding_window: + attn_metadata.sliding_window = self.sliding_window + if attn_metadata.direct_build: + # update mask mod in attention metadata + attn_metadata.mask_mod = attn_metadata.get_mask_mod() + needs_rebuild_block_mask = True + + if self.mm_prefix_range != getattr(attn_metadata, "mm_prefix_range", None): + self.mm_prefix_range = attn_metadata.mm_prefix_range + attn_metadata.mask_mod = attn_metadata.get_mask_mod() + needs_rebuild_block_mask = True + + if needs_rebuild_block_mask: + if attn_metadata.direct_build and attn_metadata.causal: + attn_metadata.block_mask = attn_metadata._build_block_mask_direct() + else: + attn_metadata.block_mask = attn_metadata.build_block_mask() + + if not attn_metadata.causal: + assert self.attn_type == AttentionType.ENCODER_ONLY + + query, key_tensor, value_tensor = map( + lambda x: self.view_as_4d(x).permute(0, 2, 1, 3), + (query, key, value), + ) + + query = query[:, :, :num_actual_tokens, :] + if (key_tensor.size(-2) > num_actual_tokens) or ( + value_tensor.size(-2) > num_actual_tokens + ): + # In the encoder-only model with torch.compile, + # qkv might be padded, which might cause exception. 
+ # see: https://github.com/vllm-project/vllm/pull/24872#discussion_r2353252290 + key_tensor = key_tensor[:, :, :num_actual_tokens, :] + value_tensor = value_tensor[:, :, :num_actual_tokens, :] + + else: + assert self.attn_type == AttentionType.DECODER + key_cache, value_cache = kv_cache.unbind(0) + + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + # View out the block_size dim + key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size) + value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size) + query, key_tensor, value_tensor = map( + lambda x: self.view_as_4d(x).permute(0, 2, 1, 3), + (query, key_cache, value_cache), + ) + + query = query[:, :, :num_actual_tokens, :] + + # Doesn't work for now -> constraint violation + # torch._dynamo.try_mark_dynamic(query, 2) + + assert attn_metadata.block_mask is not None + block_m, block_n = attn_metadata.block_mask.BLOCK_SIZE + + kernel_options = get_kernel_options( + query, block_m, block_n, attn_metadata.direct_build + ) + out = flex_attention_compiled( + query, + key_tensor, + value_tensor, + attn_metadata.transformed_score_mod, + attn_metadata.block_mask, + self.scale, + enable_gqa=enable_gqa, + kernel_options=kernel_options, + ) + + # Flex doesn't have an out variant today, rely on epilogue fusion + out = out.permute(0, 2, 1, 3).squeeze(0) + output[:num_actual_tokens, :, :].copy_(out) + return output + + +def get_kernel_options( + query, block_m, block_n, use_direct_build: bool +) -> dict[str, int | bool]: + kernel_options: dict[str, int | bool] = { + "FORCE_USE_FLEX_ATTENTION": True, + } + + def ensure_divisible(candidate: int, block_size: int) -> int: + """Pick a kernel block size that divides the logical block.""" + if block_size <= 0: + return candidate + candidate = min(candidate, block_size) + if candidate <= 0: + return block_size + if block_size % 
candidate == 0: + return candidate + + candidate = math.gcd(candidate, block_size) + if candidate <= 1: + return block_size + return candidate + + if vllm_is_batch_invariant(): + kernel_options["BLOCK_M"] = 16 + kernel_options["BLOCK_N"] = 16 + kernel_options["IS_DIVISIBLE"] = False + return kernel_options + if use_direct_build: + kernel_options["BLOCK_M"] = block_m + kernel_options["BLOCK_N"] = block_n + return kernel_options + else: + preferred_block = 32 if query.dtype == torch.float32 else 64 + block_lower_bound = 16 + + block_m_candidate = ensure_divisible(preferred_block, block_m) + block_n_candidate = ensure_divisible(preferred_block, block_n) + + if torch.cuda.is_available(): + device_props = torch.cuda.get_device_properties() + # ROCm doesn't expose shared_memory_per_block_optin attribute + # AMD GPUs typically have 64KB LDS (Local Data Share) per workgroup + if hasattr(device_props, "shared_memory_per_block_optin"): + max_shared_memory = device_props.shared_memory_per_block_optin + elif current_platform.is_out_of_tree(): + # ROCm fallback: use 64KB + max_shared_memory = 65536 + else: + raise RuntimeError( + "Unable to determine shared memory size on this hardware." 
+ ) + + if max_shared_memory < 144 * 1024: + block_m_candidate = ensure_divisible( + max(1, block_m_candidate // 2), block_m + ) + block_n_candidate = ensure_divisible( + max(1, block_n_candidate // 2), block_n + ) + + block_m_candidate = max(block_m_candidate, block_lower_bound) + block_n_candidate = max(block_n_candidate, block_lower_bound) + + kernel_options["BLOCK_M"] = block_m_candidate + kernel_options["BLOCK_N"] = block_n_candidate + + return kernel_options diff --git a/vllm_fl/v1/attention/backends/mla/__init__.py b/vllm_fl/v1/attention/backends/mla/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_fl/v1/attention/backends/mla/common.py b/vllm_fl/v1/attention/backends/mla/common.py new file mode 100644 index 00000000..5f9914c0 --- /dev/null +++ b/vllm_fl/v1/attention/backends/mla/common.py @@ -0,0 +1,1939 @@ +# SPDX-License-Identifier: Apache-2.0 +# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +# MLA Common Components + +This file implements common components for MLA implementations. + +First we define: + +Sq as Q sequence length +Skv as KV sequence length + +MLA has two possible ways of computing, a data-movement friendly approach and a +compute friendly approach, we generally want to use the compute friendly +approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1) +and the data-movement friendly approach for "decode" (i.e. the ratio +Sq / Skv is "large"). + +NOTE what we deem small and large is currently determined by if its labelled +prefill or decode by the scheduler, but this is something we should probably +tune. + +Main reference: DeepseekV2 paper, and FlashInfer Implementation +(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). 
+ +Deepseek's MLA attention works the following way: +* Use a single latent vector to represent the per-token entry of the KV cache. +* For decode (i.e. the memory friendly approach) the attention "simulates" a +multi-head attention, while the compute is similar to multi-query attention. + +Below is example of both paths assuming batchsize = 1 + +## More Extent Definitions: + +C Context length, `Skv - Sq` +H hidden size +N number of attention heads +Lq latent dimension for Q 1536 in DSV3 +Lkv latent dimension for K/V 512 in DSV3 +P nope dimension, no rope. 128 in DSV3 +R rope dimension, goes through rope. 64 in DSV3 +V V head dim. 128 in DSV3 + +## Vector/Matrix Definitions + +h_t hidden states (input to attention) shape [Sq, H] +q_c latent/compressed Q shape [Sq, Lq] +q_nope uncompressed Q (no-rope) shape [Sq, N, P] +q_pe uncompressed Q (rope) shape [Sq, N, R] +kv_c latent/compressed KV shape [Skv, Lkv] +k_pe decoupled k position embeddings shape [Skv, R] +new_kv_c new kv_c from current iter shape [Sq, Lkv] +new_k_pe new k_pe from current iter shape [Sq, R] +cache_kv_c cached k_c from previous iters shape [C, Lkv] +cache_k_pe cached k_pe from previous iters shape [C, R] +W_DQ project h_t to q_c shape [H, Lq] +W_UQ project q_c to q_nope shape [Lq, N * P] +W_QR project q_c to q_pe shape [Lq, N * R] +W_DKV project h_t to kv_c shape [H, Lkv] +W_UK project kv_c to k_nope shape [Lkv, N, P] +W_KR project h_t to k_pe shape [H, R] +W_UV project kv_c to v shape [Lkv, N, V] +W_O project v to h_t shape [N * V, H] + + +## Compute Friendly Approach (i.e. 
"_forward_prefill"): + +q_c = h_t @ W_DQ +q_nope = (q_c @ W_UQ).view(Sq, N, P) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) +k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) +k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P) +v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V) + +// MHA with QK headdim = P + R +// V headdim = V +// spda_o shape [Sq, N, V] +spda_o = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), + v +) +return spda_o @ W_O + +NOTE: in the actual code, + `kv_b_proj` is [W_UK; W_UV] concatenated per head + `q_b_proj` is [W_UQ; W_QR] concatenated per head + `out_proj` is W_O + + +## Data-Movement Friendly Approach (i.e. "_forward_decode"): + +Runtime +q_c = h_t @ W_DQ +q_nope = (q_c @ W_UQ).view(-1, N, P) +ql_nope = einsum("snh,lnh->snl", q, W_UK) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) +k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) + +// MQA with QK headdim = Lkv + R +// V headdim = Lkv +// spda_o shape [Sq, N, Lkv] +// NOTE: this is less compute-friendly since Lkv > P +// but is more data-movement friendly since its MQA vs MHA +spda_o = scaled_dot_product_attention( + torch.cat([ql_nope, q_pe], dim=-1), + torch.cat([kv_c, k_pe], dim=-1), + kv_c +) + +o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV) +return o.view(-1, N * V) @ self.num_heads @ W_O + + +## Chunked Prefill + +For chunked prefill we want to use the compute friendly algorithm. We are +assuming sufficiently large Sq / Skv ratio, in the future may want to switch to +the data-movement friendly approach if the chunk (i.e. `Sq`) is small. 
+
+However, the compute-friendly approach can potentially run out of memory if Skv
+is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
+
+To mitigate this, we chunk the computation of attention with respect to the
+current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can use a
+fixed workspace size.
+
+The chunked prefill approach is as follows:
+
+MCC        Max chunk of context to process per iter, computed dynamically,
+           used to bound the memory usage
+
+q_c = h_t @ W_DQ
+q_nope = (q_c @ W_UQ).view(Sq, N, P)
+q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
+new_kv_c = h_t @ W_DKV
+new_k_pe = RoPE(h_t @ W_KR)
+new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
+new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)
+
+// MHA between queries and new KV
+// with QK headdim = P + R
+// V headdim = V
+// curr_o shape [Sq, N, V]
+// curr_lse shape [N, Sq], this is just order FA returns
+curr_o, curr_lse = scaled_dot_product_attention(
+    torch.cat([q_nope, q_pe], dim=-1),
+    torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
+    new_v,
+    causal=True,
+    return_softmax_lse=True
+)
+
+// Compute attention with the already existing context
+for chunk_idx in range(cdiv(C, MCC)):
+    chunk_start = chunk_idx * MCC
+    chunk_end = min(chunk_start + MCC, C)
+    Sc = chunk_end - chunk_start
+    cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end]
+    cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
+    cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
+    cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V)
+
+    chunk_o, chunk_lse = scaled_dot_product_attention(
+        torch.cat([q_nope, q_pe], dim=-1),
+        torch.cat([cache_k_nope_chunk,
+                   cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
+                  dim=-1),
+        cache_v_chunk,
+        causal=False,
+        return_softmax_lse=True
+    )
+
+    curr_o, curr_lse = merge_attn_states(
+        suffix_output=curr_o,
+        suffix_lse=curr_lse,
+        prefix_output=chunk_o,
+        prefix_lse=chunk_lse,
+    )
+
+return curr_o @ W_O
+"""
+
+import 
functools +from abc import abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import ClassVar, Generic, TypeVar + +import torch +from tqdm import tqdm + +from vllm import _custom_ops as ops +from vllm import envs +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionLayer, + MLAAttentionImpl, +) +from vllm.attention.backends.utils import get_mla_dims +from vllm.attention.ops.common import cp_lse_ag_out_rs + +# -------------------------------------------------------------- +# Note: use Maca's merge_attn_states to get cuda kernel invoked +# -------------------------------------------------------------- +from vllm_fl.attention.ops.merge_attn_states import merge_attn_states +from vllm_fl.attention.utils.fa_utils import get_flash_attn_version +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank +from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearBase, + UnquantizedLinearMethod, +) +from vllm.platforms import current_platform +from vllm.utils.math_utils import cdiv, round_down +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, + get_dcp_local_seq_lens, + get_per_layer_parameters, + infer_global_hyperparameters, + split_decodes_and_prefills, +) +from vllm.v1.kv_cache_interface import AttentionSpec + + +class QueryLenSupport(Enum): + """Defines the level of query length support for an attention backend's + decode pipeline. 
+ + - SINGLE_ONLY: Decode pipeline only supports single-token queries + (query_len=1) + - UNIFORM: Decode pipeline supports uniform multi-token queries + (all requests must have same query_len > 1) + - VARLEN: Decode pipeline supports variable-length queries + (mixed query lengths in same batch) + """ + + SINGLE_ONLY = "single_only" + UNIFORM = "uniform" + VARLEN = "varlen" + + +# /------------------------ Metax Modification -------------------------\ +from flash_attn import flash_attn_varlen_func + +is_vllm_fa = False +# \------------------------- Metax Modification -------------------------/ + +try: + from flashinfer import BatchPrefillWithRaggedKVCacheWrapper + + # /------------------------ Metax Modification -------------------------\ + from typing import Any + + cudnn_batch_prefill_with_kv_cache: Any = None # type: ignore + # \------------------------- Metax Modification -------------------------/ + flashinfer_available = True +except ImportError: + BatchPrefillWithRaggedKVCacheWrapper = object + + flashinfer_available = False + + +from vllm.v1.attention.backends.mla.common import logger + +CUDNN_WORKSPACE_SIZE = 12800 + + +class MLACommonBackend(AttentionBackend): + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "TRITON_MLA" + + @staticmethod + def get_builder_cls() -> type["MLACommonMetadataBuilder"]: + return MLACommonMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @staticmethod + def get_kv_cache_stride_order( + include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: + # `stride_order` indicates the permutation that gets + # us from `get_kv_cache_shape` to the actual memory layout we want. 
+ # (num_blocks, num_layers, block_size, head_size) + return (1, 0, 2, 3) if include_num_layers_dimension else (0, 1, 2) + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [576] + + @classmethod + def is_mla(cls) -> bool: + return True + + +@dataclass +class MLACommonPrefillMetadata: + """Prefill Specific Metadata""" + + @dataclass + class ChunkedContextMetadata: + # New for MLA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + seq_lens: torch.Tensor + workspace: torch.Tensor + token_to_seq: torch.Tensor + chunk_total_token: list[int] + + # for mla DCP + padded_local_chunk_seq_lens: list[list[int]] | None = None + local_context_lens_allranks: list[list[int]] | None = None + padded_local_cu_seq_lens: torch.Tensor | None = None + cu_seq_lens_lst: list[list[int]] | None = None + chunk_size: int | None = None + + block_table: torch.Tensor + query_start_loc: torch.Tensor + max_query_len: int + chunked_context: ChunkedContextMetadata | None = None + query_seq_lens: torch.Tensor | None = None + + +@dataclass +class FlashInferPrefillMetadata(MLACommonPrefillMetadata): + prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None + prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = field( + default_factory=list + ) + + +@dataclass +class CudnnPrefillMetadata(MLACommonPrefillMetadata): + class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata): + seq_lens: torch.Tensor + + cudnn_workspace: torch.Tensor | None = None + + +@dataclass +class MLACommonDecodeMetadata: + block_table: torch.Tensor + seq_lens: torch.Tensor + dcp_tot_seq_lens: torch.Tensor | None + + +D = TypeVar("D", bound=MLACommonDecodeMetadata) + + +@dataclass +class MLACommonMetadata(Generic[D]): + """Metadata for MLACommon. 
+ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_reqs: int + max_query_len: int + max_seq_len: int + + num_actual_tokens: int # Number of tokens excluding padding. + query_start_loc: torch.Tensor + slot_mapping: torch.Tensor + + # New for MLA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + + # The dimension of the attention heads + head_dim: int | None = None + + decode: D | None = None + prefill: ( + MLACommonPrefillMetadata + | FlashInferPrefillMetadata + | CudnnPrefillMetadata + | None + ) = None + + def __post_init__(self): + if self.head_dim is not None and not MLACommonBackend.supports_head_size( + self.head_dim + ): + raise ValueError(f"Head dimension {self.head_dim} is not supported by MLA.") + + +M = TypeVar("M", bound=MLACommonMetadata) +A = TypeVar("A") + + +def use_flashinfer_prefill() -> bool: + # For blackwell default to flashinfer prefill if it's available since + # it is faster than FA2. 
+ from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + return ( + not vllm_config.attention_config.disable_flashinfer_prefill + and flashinfer_available + and not vllm_config.attention_config.use_cudnn_prefill + and not vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill + and current_platform.is_device_capability_family(100) + ) + + +def use_cudnn_prefill() -> bool: + logger.info_once("cudnn prefill is not supported on Maca.") + return False + + +def use_trtllm_ragged_deepseek_prefill() -> bool: + """Check if TRT-LLM ragged DeepSeek prefill should be used.""" + logger.info_once("TRT-LLM ragged DeepSeek prefill is not supported on Maca.") + return False + + +class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + # Defines the level of query length support for this backend. + # - SINGLE_ONLY: Only single-token queries (no spec decode support) + # - UNIFORM: Supports uniform multi-token queries (spec decode with uniform lengths) + # - VARLEN: Supports variable-length queries (spec decode with mixed lengths) + # If set to UNIFORM or VARLEN, this will increase `reorder_batch_threshold` when + # speculative decoding is enabled. + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.SINGLE_ONLY + + # The threshold for reordering the batch into decode and prefill requests. + # If > 1, the batch will be reordered such that requests with + # query length <= threshold are classified as decode requests. + # Use `query_len_support` (above) to set this automatically + # when speculative decoding is enabled. 
+ reorder_batch_threshold: int = 1 + + @staticmethod + def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int: + scheduler_config = vllm_config.scheduler_config + cache_config = vllm_config.cache_config + model_config = vllm_config.model_config + + chunked_prefill_workspace_size = min( + # Try for 8 full length request or at least 4 pages per-request + max( + 8 * model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * cache_config.block_size, + ), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 64 * 1024, + ) + + # Enforce that we enough for at least 1 page per request + chunked_prefill_workspace_size = max( + chunked_prefill_workspace_size, + scheduler_config.max_num_seqs * cache_config.block_size, + ) + + return chunked_prefill_workspace_size + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: type[M] | None = None, + supports_dcp_with_varlen: bool = False, + ): + self.metadata_cls = ( + metadata_cls if metadata_cls is not None else MLACommonMetadata + ) + self.kv_cache_spec = kv_cache_spec + scheduler_config = vllm_config.scheduler_config + self.model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + self.compilation_config = vllm_config.compilation_config + self.vllm_config = vllm_config + self.device = device + + self.num_heads = self.model_config.get_num_attention_heads(parallel_config) + self.mla_dims = get_mla_dims(self.model_config) + self.aot_schedule = current_platform.is_cuda_alike() + try: + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = 
get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size + self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size + + # Don't try to access the runner on AMD + if self.aot_schedule: + self.page_size = self.kv_cache_spec.block_size + + self.chunked_prefill_workspace_size = ( + self.determine_chunked_prefill_workspace_size(vllm_config) + ) + + if self.dcp_world_size > 1: + # Note(hc): The local kvcache is incomplete when DCP is triggered, + # an additional kvcache allgather across the DCP group is therefore + # required, so the workspace has to be enlarged by 1/DCP relative + # to the original TP allocation. + assert self.chunked_prefill_workspace_size % self.dcp_world_size == 0 + self.chunked_prefill_workspace = torch.empty( + ( + self.chunked_prefill_workspace_size + + self.chunked_prefill_workspace_size // self.dcp_world_size, + self.model_config.get_head_size(), + ), + dtype=self.model_config.dtype, + device=device, + ) + else: + self.chunked_prefill_workspace = torch.empty( + ( + self.chunked_prefill_workspace_size, + self.model_config.get_head_size(), + ), + dtype=self.model_config.dtype, + device=device, + ) + + self._use_cudnn_prefill = use_cudnn_prefill() + self._use_fi_prefill = use_flashinfer_prefill() + self._use_trtllm_ragged_prefill = use_trtllm_ragged_deepseek_prefill() + self.prefill_metadata_cls = ( + FlashInferPrefillMetadata + if self._use_fi_prefill + else CudnnPrefillMetadata + if self._use_cudnn_prefill + else MLACommonPrefillMetadata + ) + + if self._use_fi_prefill: + self._workspace_buffer = torch.empty( + envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=device, + ) + + self._fi_prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None + self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = [] + + 
self._global_hyperparameters = infer_global_hyperparameters( + get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl) + ) + + if self._use_trtllm_ragged_prefill: + self._workspace_buffer = torch.empty( + envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=device, + ) + + if self._use_cudnn_prefill: + self.cudnn_workspace = torch.empty( + CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs, + dtype=torch.int8, + device=device, + ) + + supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY + self._init_reorder_batch_threshold( + self.reorder_batch_threshold, supports_spec_decode, supports_dcp_with_varlen + ) + + # Validate consistency between query_len_support and reorder_batch_threshold + if self.query_len_support == QueryLenSupport.SINGLE_ONLY: + assert self.reorder_batch_threshold == 1, ( + f"reorder_batch_threshold must be 1 when query_len_support is " + f"SINGLE_ONLY, got {self.reorder_batch_threshold}" + ) + + def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): + qo_indptr = prefill.query_start_loc + + has_context = False + if prefill.chunked_context is not None: + chunked_context = prefill.chunked_context + has_context = True + + if self._fi_prefill_main is None: + self._fi_prefill_main = BatchPrefillWithRaggedKVCacheWrapper( + self._workspace_buffer, "NHD", backend="cutlass" + ) + + if has_context: + num_chunks = chunked_context.cu_seq_lens.shape[0] + # Allocate more prefill chunk wrappers if needed + if len(self._fi_prefill_chunks) < num_chunks: + for _ in range(len(self._fi_prefill_chunks), num_chunks): + self._fi_prefill_chunks.append( + BatchPrefillWithRaggedKVCacheWrapper( + self._workspace_buffer, "NHD", backend="cutlass" + ) + ) + assert num_chunks <= len(self._fi_prefill_chunks) + + # In MLA, the non-latent num_qo_heads == num_kv_heads + num_qo_heads = self.num_heads + num_kv_heads = num_qo_heads + + # Sanity: Verify that num_kv_heads == 1 since it is latent space + assert 
self.kv_cache_spec.num_kv_heads == 1 + + # Get non-latent head_dim_qk and head_dim_vo + head_dim_qk = self.mla_dims.qk_nope_head_dim + self.mla_dims.qk_rope_head_dim + head_dim_vo = self.mla_dims.v_head_dim + + # For main run, qo_indptr == kv_indptr + kv_indptr = qo_indptr.clone() + + # Prepare main prefill + self._fi_prefill_main.plan( + qo_indptr=qo_indptr, + kv_indptr=kv_indptr, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, + causal=True, # This is main run + sm_scale=self._global_hyperparameters.sm_scale, + window_left=self._global_hyperparameters.window_left, + logits_soft_cap=self._global_hyperparameters.logits_soft_cap, + q_data_type=self.model_config.dtype, + ) + + # Prepare context prefills + if has_context: + for i in range(num_chunks): + kv_indptr_chunk = chunked_context.cu_seq_lens[i] + + self._fi_prefill_chunks[i].plan( + qo_indptr=qo_indptr, + kv_indptr=kv_indptr_chunk, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, + causal=False, # This is context run + sm_scale=self._global_hyperparameters.sm_scale, + window_left=self._global_hyperparameters.window_left, + logits_soft_cap=self._global_hyperparameters.logits_soft_cap, + q_data_type=self.model_config.dtype, + ) + + prefill.prefill_main = self._fi_prefill_main + prefill.prefill_chunks = self._fi_prefill_chunks + + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + dcp_tot_seq_lens_device: torch.Tensor | None, + ) -> MLACommonDecodeMetadata: + return MLACommonDecodeMetadata( + block_table=block_table_tensor, + seq_lens=seq_lens_device, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, + ) + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata + ) -> M: + """ + This 
method builds the metadata for full cudagraph capture. + Currently, only decode is supported for full cudagraphs with MLA. + """ + m = common_attn_metadata + assert m.num_reqs <= (m.num_actual_tokens * self.reorder_batch_threshold), ( + "MLA only supports decode-only full CUDAGraph capture. " + "Make sure all cudagraph capture sizes <= max_num_seq." + ) + + assert m.max_query_len <= self.reorder_batch_threshold # decode only + + return self.build(0, m) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> M: + num_reqs = common_attn_metadata.num_reqs + num_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + max_seq_len = common_attn_metadata.max_seq_len + + # Note(simon): be careful about the CPU <> GPU memory movement in this + # function. We should avoid GPU -> CPU sync as much as possible because + # it blocks on all previous kernels. + device = self.device + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + + query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + seq_lens = common_attn_metadata.seq_lens + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens + dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu + + query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + + num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold, + require_uniform=(self.query_len_support != QueryLenSupport.VARLEN), + ) + ) + + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_tokens + + 
prefill_metadata = None + if num_prefills > 0: + reqs_start = num_decodes # prefill_start + + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + max_context_len_cpu = context_lens_cpu.max().item() + num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + prefill_query_start_loc = ( + query_start_loc[reqs_start:] - query_start_loc[reqs_start] + ) + + chunked_context_metadata = None + if max_context_len_cpu > 0: + # NOTE: it is recommend you read the `Chunked Prefill` section + # in the comment at the top of the file before trying to + # understand the following code + + # currently we allocate an equal amount of workspace for each + # prefill in the batch, we could probably use a more advanced + # algorithm here and allocate more workspace to prefills with + # longer context lengths + max_context_chunk = ( + self.chunked_prefill_workspace_size // num_prefills_with_context_cpu + ) + + if self.aot_schedule: + # align max_context_chunk to page_size by rounding down, + # currently the `gather_and_maybe_dequant_cache` kernel + # cannot handle `context_chunk_starts` that are not aligned + # to page_size + max_context_chunk = round_down(max_context_chunk, self.page_size) + + assert max_context_chunk > 0 + num_chunks = cdiv(max_context_len_cpu, max_context_chunk) + + # if `max_context_chunk = 256`, `num_chunks = 3`, and + # `num_prefills_with_context = 4`, create a tensor that looks + # like + # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + # Note(simon): this is done in CPU because of downstream's + # of `to_list`. 
+ chunk_starts = ( + torch.arange(num_chunks, dtype=torch.int32) + .unsqueeze(1) + .expand(-1, num_prefills) + * max_context_chunk + ) + chunk_ends = torch.min( + context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk + ) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + + cu_seq_lens_cpu = torch.zeros( + num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True + ) + torch.cumsum( + chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32 + ) + chunk_total_token = cu_seq_lens_cpu[:, -1] + + max_token_num_over_chunk = chunk_total_token.max().item() + token_to_seq_tensor_cpu = torch.zeros( + [num_chunks, max_token_num_over_chunk], dtype=torch.int32 + ) + range_idx = torch.arange(num_prefills, dtype=torch.int32) + for i in range(num_chunks): + chunk_token_to_seq_tensor = torch.repeat_interleave( + range_idx, chunk_seq_lens[i] + ) + chunk_len = chunk_token_to_seq_tensor.shape[0] + token_to_seq_tensor_cpu[i, :chunk_len] = chunk_token_to_seq_tensor + + if self.dcp_world_size > 1: + local_context_lens_allranks = get_dcp_local_seq_lens( + context_lens_cpu, + self.dcp_world_size, + None, + self.dcp_local_block_size, + ) + # Note(qcs): The max local context lengths + # padded to `dcp_local_block_size`. + padded_local_context_lens_cpu = ( + cdiv( + context_lens_cpu, + self.dcp_virtual_block_size, + ) + * self.dcp_local_block_size + ) + # Note(hc): The above max_context_chunk already enforces + # block_size alignment, DCP just need the block_size can + # be divisible by dcp_world_size, because DCP use + # cp_gather_cache which not require `cp_chunk_starts` + # aligned to page_size. 
+ assert max_context_chunk % self.dcp_world_size == 0 + padded_local_max_context_chunk_across_ranks = ( + cdiv( + max_context_chunk, + self.dcp_virtual_block_size, + ) + * self.dcp_local_block_size + ) + local_chunk_starts = ( + torch.arange(num_chunks, dtype=torch.int32) + .unsqueeze(1) + .expand(-1, num_prefills) + * padded_local_max_context_chunk_across_ranks + ) + local_chunk_ends = torch.min( + padded_local_context_lens_cpu.unsqueeze(0), + local_chunk_starts + + padded_local_max_context_chunk_across_ranks, + ) + padded_local_chunk_seq_lens = ( + local_chunk_ends - local_chunk_starts + ).clamp(min=0) + + padded_local_cu_chunk_seq_lens_cpu = torch.zeros( + num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True + ) + torch.cumsum( + padded_local_chunk_seq_lens, + dim=1, + out=padded_local_cu_chunk_seq_lens_cpu[:, 1:], + dtype=torch.int32, + ) + + chunked_context_metadata_cls = ( + CudnnPrefillMetadata.ChunkedContextMetadata + if self._use_cudnn_prefill + else MLACommonPrefillMetadata.ChunkedContextMetadata + ) + if self.dcp_world_size > 1: + chunked_context_metadata = chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=local_chunk_starts.to(device, non_blocking=True), + seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + token_to_seq=token_to_seq_tensor_cpu.to( + device, non_blocking=True + ), + chunk_total_token=chunk_total_token.tolist(), + workspace=self.chunked_prefill_workspace, + padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(), + local_context_lens_allranks=local_context_lens_allranks.tolist(), + padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to( + device, non_blocking=True + ), + cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), + chunk_size=padded_local_max_context_chunk_across_ranks, + ) + else: + chunked_context_metadata = chunked_context_metadata_cls( + 
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + token_to_seq=token_to_seq_tensor_cpu.to( + device, non_blocking=True + ), + chunk_total_token=chunk_total_token, + workspace=self.chunked_prefill_workspace, + ) + + if self._use_cudnn_prefill: + chunked_context_metadata.seq_lens = chunk_seq_lens + + assert ( + max(chunked_context_metadata.max_seq_lens) + <= self.chunked_prefill_workspace_size + ) + + prefill_metadata = self.prefill_metadata_cls( + block_table=block_table_tensor[reqs_start:, ...], + query_start_loc=prefill_query_start_loc, + max_query_len=max_query_len, + chunked_context=chunked_context_metadata, + ) + + if self._use_cudnn_prefill: + assert isinstance(prefill_metadata, CudnnPrefillMetadata) + prefill_metadata.query_seq_lens = ( + prefill_query_start_loc[1:] - prefill_query_start_loc[:-1] + ) + prefill_metadata.cudnn_workspace = self.cudnn_workspace + + if self._use_trtllm_ragged_prefill: + prefill_metadata.query_seq_lens = ( + prefill_query_start_loc[1:] - prefill_query_start_loc[:-1] + ) + + decode_metadata = None + if num_decodes > 0: + dcp_tot_seq_lens_device = None + if self.dcp_world_size > 1: + dcp_tot_seq_lens_device = seq_lens[:num_decodes] + seq_lens_cpu = dcp_local_seq_lens_cpu + seq_lens = dcp_local_seq_lens + + decode_metadata = self._build_decode( + block_table_tensor=block_table_tensor[:num_decodes, ...], + seq_lens_cpu=seq_lens_cpu[:num_decodes], + seq_lens_device=seq_lens[:num_decodes], + query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1], + query_start_loc_device=query_start_loc[: num_decodes + 1], + num_decode_tokens=num_decode_tokens, + dcp_tot_seq_lens_device=dcp_tot_seq_lens_device, + ) + + attn_metadata = self.metadata_cls( + num_reqs=common_attn_metadata.num_reqs, + max_query_len=common_attn_metadata.max_query_len, + 
max_seq_len=max_seq_len, + num_actual_tokens=num_tokens, + query_start_loc=query_start_loc, + slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + # MLACommonMetadata Chunk prefill specific + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + prefill=prefill_metadata, + decode=decode_metadata, + ) + + if self._use_fi_prefill and num_prefills > 0: + assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata) + self._build_fi_prefill_wrappers(attn_metadata.prefill) + + return attn_metadata + + +def reorg_kvcache( + allgatered_kv_c_normed: torch.Tensor, + allgatered_k_pe: torch.Tensor, + padded_local_chunk_seq_lens_lst: list[int], + local_context_lens_allranks: list[list[int]], + sum_seq_len: int, + max_seq_len: int, + chunk_size: int, + chunk_idx: int, + toks: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + reorg and unpad kvcache after cp local gather to tp layout for attn kernel. + e.g. + allgatered_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ..., + T0_4, T0_5, pad, pad, T1_2, pad, ...] + -> reorganized_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T0_4, T0_5, + T1_0, T1_1, T1_2, ...] + Args: + padded_local_chunk_seq_lens_lst: local chunk context lengths + under current CP rank. + local_context_lens_allranks: local context lengths on each CP rank. + sum_seq_len: the sum of cp_chunk_seq_lens_lst. + max_seq_len: the max value of cp_chunk_seq_lens_lst. + chunk_size: the local padded max context chunk from + chunked_context_metadata building. + chunk_idx: chunk idx of chunked_prefill. + toks: the number of tokens for local gather cache. 
+ """ + kv_c_segments = [] + k_pe_segments = [] + src_token_idx = 0 + max_seq_len_check = 0 + for padded_local_chunk_seq_len, local_context_lens in zip( + padded_local_chunk_seq_lens_lst, local_context_lens_allranks + ): + cur_seq_len = 0 + for rank, local_context_len in enumerate(local_context_lens): + # Note(qcs): We split the context into multiple chunks, + # depending on the size of the workspace. + # local_context in dcp0: |-----------------| + # local_context in dcp1: |--------------| + # n*padded_local_chunk: |-----|-----|-----| + # local_chunk_len in dcp1: |-----|-----|--| + # so we need update the last chunk length in dcp1. + local_chunk_len = min( + max(0, local_context_len - chunk_idx * chunk_size), + padded_local_chunk_seq_len, + ) + if local_chunk_len != 0: + kv_c_segment = allgatered_kv_c_normed[ + rank * toks + src_token_idx : rank * toks + + src_token_idx + + local_chunk_len + ] + k_pe_segment = allgatered_k_pe[ + rank * toks + src_token_idx : rank * toks + + src_token_idx + + local_chunk_len + ] + kv_c_segments.append(kv_c_segment) + k_pe_segments.append(k_pe_segment) + cur_seq_len += local_chunk_len + max_seq_len_check = max(max_seq_len_check, cur_seq_len) + src_token_idx += padded_local_chunk_seq_len + reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0) + reorganized_k_pe = torch.cat(k_pe_segments, dim=0) + assert reorganized_kv_c_normed.shape[0] == sum_seq_len + assert reorganized_k_pe.shape[0] == sum_seq_len + assert max_seq_len_check == max_seq_len + return reorganized_kv_c_normed, reorganized_k_pe + + +# TODO(Lucas): rename MLACommonBaseImpl -> MLACommonImpl, +# and MLACommonImpl -> MLACommonDenseImpl or somthing like that +class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: list[float] | None, + sliding_window: 
int | None, + kv_cache_dtype: str, + logits_soft_cap: float | None, + attn_type: str, + kv_sharing_target_layer_name: str | None, + # MLA Specific Arguments + q_lora_rank: int | None, + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + kv_b_proj: ColumnParallelLinear, + indexer=None, + q_pad_num_heads: int | None = None, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported for MLA") + + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_head_dim + self.v_head_dim = v_head_dim + self.kv_b_proj = kv_b_proj + self.indexer = indexer + self.q_pad_num_heads = q_pad_num_heads + self.is_aiter_triton_fp8_bmm_enabled = False + + def process_weights_after_loading(self, act_dtype: torch.dtype): + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}." 
+ ) + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye( + layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device, + ) + dequant_weights = layer.quant_method.apply(layer, eye, bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + ), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}" + ) + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + + if current_platform.is_out_of_tree(): + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) + + def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + if current_platform.is_out_of_tree(): + # Convert from (B, N * V) to (N, B, V) + out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1) + + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot" + + # Convert from (N, B, V) to (B, N * V) + out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) 
+ + # Adjust output buffer shape back to the original (B, N * V) + N, B, V = out.shape + out.resize_((B, N * V)) + out.copy_(out_new) # Copy result + + +class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + if use_flashinfer_prefill(): + logger.debug_once("Using FlashInfer prefill for MLA") + self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi + self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi + self._pad_v = False + elif use_trtllm_ragged_deepseek_prefill(): + logger.debug_once("Using TRT-LLM ragged DeepSeek prefill for MLA") + self._run_prefill_context_chunk = ( + self._run_prefill_context_chunk_trtllm_ragged + ) + self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged + self._pad_v = False + elif use_cudnn_prefill(): + logger.debug_once("Using CUDNN prefill for MLA") + self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn + self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn + self._pad_v = False + else: # Use FlashAttention + logger.debug_once("Using FlashAttention prefill for MLA") + self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa + self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa + + # Handle the differences between the flash_attn_varlen from + # flash_attn and the one from vllm_flash_attn. 
The former is used on + # RoCM and the latter has an additional parameter to control + # FA2 vs FA3 + self.flash_attn_varlen_func = flash_attn_varlen_func + self.vllm_flash_attn_version = get_flash_attn_version() + if self.vllm_flash_attn_version is not None: + self.flash_attn_varlen_func = functools.partial( + flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version + ) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim for attention backends that do + # not support different headdims + # We don't need to pad V if we are on a hopper system with FA3 + self._pad_v = self.vllm_flash_attn_version is None or not ( + self.vllm_flash_attn_version == 3 + and current_platform.get_device_capability()[0] == 9 + ) + + self.dcp_world_size: int | None = None + + self.chunked_prefill_workspace_size = ( + MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( + get_current_vllm_config() + ) + ) + self.cp_kv_cache_interleave_size: int = ( + get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size + ) + + def _flash_attn_varlen_diff_headdims( + self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs + ): + maybe_padded_v = v + if self._pad_v: + maybe_padded_v = torch.nn.functional.pad( + v, [0, q.shape[-1] - v.shape[-1]], value=0 + ) + + if is_vllm_fa: + kwargs["return_softmax_lse"] = return_softmax_lse + else: + # ROCm leverages the upstream flash_attn, which takes a parameter + # called "return_attn_probs" instead of return_softmax_lse + kwargs["return_attn_probs"] = return_softmax_lse + if vllm_is_batch_invariant(): + kwargs["num_splits"] = 1 + + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + v=maybe_padded_v, + softmax_scale=softmax_scale, + **kwargs, + ) + + # Unpack the output if there is multiple results + lse = None + if isinstance(attn_out, tuple): + attn_out, lse = attn_out[0], attn_out[1] + + # Remain consistent with old `flash_attn_varlen_func` where there + 
# is only one output tensor if `return_softmax_lse` is False. + if return_softmax_lse: + return attn_out, lse + return attn_out + + def _run_prefill_new_tokens_fa( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): + return self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill.query_start_loc, + cu_seqlens_k=prefill.query_start_loc, + max_seqlen_q=prefill.max_query_len, + max_seqlen_k=prefill.max_query_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=return_softmax_lse, + ) + + def _run_prefill_new_tokens_fi( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): + assert isinstance(prefill, FlashInferPrefillMetadata) + assert prefill.prefill_main is not None + + ret = prefill.prefill_main.run( + q=q, + k=k, + v=v, + return_lse=return_softmax_lse, + ) + + if isinstance(ret, tuple): + return ret[0], ret[1].transpose(0, 1).contiguous() + return ret + + def _run_prefill_new_tokens_cudnn( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): + raise NotImplementedError( + "CUDNN prefill for new tokens is not supported on MACA. 
" + ) + + def _run_prefill_context_chunk_fa( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): + assert prefill.chunked_context is not None + return self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill.query_start_loc, + cu_seqlens_k=prefill.chunked_context.cu_seq_lens[chunk_idx], + max_seqlen_q=prefill.max_query_len, + max_seqlen_k=prefill.chunked_context.max_seq_lens[chunk_idx], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + ) + + def _run_prefill_context_chunk_fi( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): + assert isinstance(prefill, FlashInferPrefillMetadata) + + attn_out, lse = prefill.prefill_chunks[chunk_idx].run( + q=q, + k=k, + v=v, + return_lse=True, + ) + + # Convert from (q_len, num_heads) to (num_heads, q_len) + return attn_out, lse.transpose(0, 1).contiguous() + + def _run_prefill_context_chunk_cudnn( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): + raise NotImplementedError( + "CUDNN prefill for new tokens is not supported on MACA. 
" + ) + + def _run_prefill_new_tokens_trtllm_ragged( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): + """TRT-LLM ragged attention for new tokens (causal).""" + from flashinfer.prefill import trtllm_ragged_attention_deepseek + + assert prefill.query_seq_lens is not None + + ret = trtllm_ragged_attention_deepseek( + query=q, + key=k, + value=v, + workspace_buffer=self._workspace_buffer, + seq_lens=prefill.query_seq_lens, + max_q_len=prefill.max_query_len, + max_kv_len=prefill.max_query_len, + bmm1_scale=self.scale, + bmm2_scale=1.0, + o_sf_scale=1.0, + batch_size=prefill.query_seq_lens.shape[0], + window_left=-1, + cum_seq_lens_q=prefill.query_start_loc, + cum_seq_lens_kv=prefill.query_start_loc, + enable_pdl=False, + is_causal=True, + return_lse=return_softmax_lse, + ) + + if isinstance(ret, tuple): + # Convert from (q_len, num_heads) to (num_heads, q_len) + return ret[0], ret[1].transpose(0, 1).contiguous() + return ret + + def _run_prefill_context_chunk_trtllm_ragged( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): + """TRT-LLM ragged attention for context chunks (non-causal).""" + from flashinfer.prefill import trtllm_ragged_attention_deepseek + + assert prefill.chunked_context is not None + assert prefill.chunked_context.seq_lens[chunk_idx] is not None + + out = torch.zeros( + q.shape[0], + q.shape[1], + v.shape[2], + device=q.device, + dtype=q.dtype, + ) + self._workspace_buffer.fill_(0) + + attn_out, lse = trtllm_ragged_attention_deepseek( + query=q, + key=k, + value=v, + workspace_buffer=self._workspace_buffer, + seq_lens=prefill.chunked_context.seq_lens[chunk_idx], + max_q_len=prefill.max_query_len, + max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx], + bmm1_scale=self.scale, + bmm2_scale=1.0, + o_sf_scale=1.0, + batch_size=prefill.chunked_context.seq_lens[chunk_idx].shape[0], + window_left=-1, + cum_seq_lens_q=prefill.query_start_loc, + 
cum_seq_lens_kv=prefill.chunked_context.cu_seq_lens[chunk_idx], + enable_pdl=False, + is_causal=False, + return_lse=True, + out=out, + ) + + # Convert from (q_len, num_heads) to (num_heads, q_len) + return attn_out, lse.transpose(0, 1).contiguous() + + def process_weights_after_loading(self, act_dtype: torch.dtype): + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}." + ) + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye( + layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device, + ) + dequant_weights = layer.quant_method.apply(layer, eye, bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + ), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}" + ) + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + + if current_platform.is_out_of_tree(): + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + 
self.W_UK_T = W_UK.permute(1, 2, 0) + + def _concat_k_nope_k_pe( + self, k_nope: torch.Tensor, k_pe: torch.Tensor + ) -> torch.Tensor: + """ + Efficiently concatenate k_nope and k_pe tensors along the last dimension. + + This function avoids the performance penalty of torch.cat with expanded + non-contiguous tensors by pre-allocating the output and using direct copies. + + Args: + k_nope: Tensor of shape [..., nope_dim] + k_pe: Tensor to broadcast and concatenate, typically shape [..., 1, pe_dim] + or [..., pe_dim] + + Returns: + Tensor of shape [..., nope_dim + pe_dim] + """ + k = torch.empty( + (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]), + dtype=k_nope.dtype, + device=k_nope.device, + ) + # Direct copies with efficient broadcasting + k[..., : k_nope.shape[-1]] = k_nope + k[..., k_nope.shape[-1] :] = k_pe + return k + + def _compute_prefill_context( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, + ): + assert attn_metadata.prefill is not None + prefill_metadata = attn_metadata.prefill + assert prefill_metadata.chunked_context is not None + + output = None + iters = len(prefill_metadata.chunked_context.seq_tot) + workspace = prefill_metadata.chunked_context.workspace + for i in range(iters): + toks = prefill_metadata.chunked_context.seq_tot[i] + ops.gather_and_maybe_dequant_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_table, + cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i], + token_to_seq=prefill_metadata.chunked_context.token_to_seq[i], + num_tokens=prefill_metadata.chunked_context.chunk_total_token[i], + kv_cache_dtype=self.kv_cache_dtype, + scale=k_scale, + seq_starts=prefill_metadata.chunked_context.starts[i], + ) + + kv_c_normed = workspace[:toks][..., : self.kv_lora_rank] + k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1) + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, 
self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = self._concat_k_nope_k_pe(k_nope, k_pe) + + attn_output, attn_softmax_lse = self._run_prefill_context_chunk( + prefill=prefill_metadata, + chunk_idx=i, + q=q, + k=k, + v=v, + ) + + if output is None: + output = attn_output + output_lse = attn_softmax_lse + else: + output_tmp = torch.empty_like(output) + output_lse_tmp = torch.empty_like(output_lse) + merge_attn_states( + output=output_tmp, + output_lse=output_lse_tmp, + prefix_output=output, + prefix_lse=output_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, + ) + output = output_tmp + output_lse = output_lse_tmp + + return output, output_lse + + def _context_parallel_compute_prefill_context( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, + dcp_world_size: int, + ): + assert k_scale is None, "DCP not support scaled kvcache now." 
+ assert attn_metadata.prefill is not None + prefill_metadata = attn_metadata.prefill + assert prefill_metadata.chunked_context is not None + assert prefill_metadata.chunked_context.padded_local_chunk_seq_lens is not None + assert prefill_metadata.chunked_context.local_context_lens_allranks is not None + assert prefill_metadata.chunked_context.padded_local_cu_seq_lens is not None + assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None + assert prefill_metadata.chunked_context.chunk_size is not None + + output = None + iters = len(prefill_metadata.chunked_context.seq_tot) + workspace = prefill_metadata.chunked_context.workspace + + for i in range(iters): + toks = prefill_metadata.chunked_context.seq_tot[i] + ops.cp_gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_table, + cu_seq_lens=prefill_metadata.chunked_context.padded_local_cu_seq_lens[ + i + ], + batch_size=attn_metadata.num_prefills, + seq_starts=prefill_metadata.chunked_context.starts[i], + ) + # workspace + # |------- N tokens --------|--------- N*dcp_size tokens ----------| + # |<- use for loca_gather ->|<--------- use for allgather -------->| + allgather_offset = workspace.shape[0] // (dcp_world_size + 1) + assert allgather_offset * (dcp_world_size + 1) == workspace.shape[0] + assert toks <= allgather_offset + local_gathered_kvcache = workspace[:toks] + cur_allgather_workspace = workspace[ + allgather_offset : allgather_offset * (1 + dcp_world_size) + ] + assert toks * dcp_world_size <= cur_allgather_workspace.shape[0] + cur_allgather_kvcache = cur_allgather_workspace[: toks * dcp_world_size] + cur_allgather_kvcache.copy_( + get_dcp_group().all_gather(local_gathered_kvcache, dim=0) + ) + assert ( + cur_allgather_kvcache.shape[-1] + == self.kv_lora_rank + self.qk_rope_head_dim + ) + allgatered_kv_c_normed, allgatered_k_pe = cur_allgather_kvcache.unsqueeze( + 1 + ).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + kv_c_normed, 
k_pe = reorg_kvcache( + allgatered_kv_c_normed, + allgatered_k_pe, + padded_local_chunk_seq_lens_lst=prefill_metadata.chunked_context.padded_local_chunk_seq_lens[ + i + ], + local_context_lens_allranks=prefill_metadata.chunked_context.local_context_lens_allranks, + sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1], + max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i], + chunk_size=prefill_metadata.chunked_context.chunk_size, + chunk_idx=i, + toks=toks, + ) + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k = self._concat_k_nope_k_pe(k_nope, k_pe) + + attn_output, attn_softmax_lse = self._run_prefill_context_chunk( + prefill=prefill_metadata, + chunk_idx=i, + q=q, + k=k, + v=v, + ) + + if output is None: + output = attn_output + output_lse = attn_softmax_lse + else: + output_tmp = torch.empty_like(output) + output_lse_tmp = torch.empty_like(output_lse) + merge_attn_states( + output=output_tmp, + output_lse=output_lse_tmp, + prefix_output=output, + prefix_lse=output_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, + ) + output = output_tmp + output_lse = output_lse_tmp + + return output, output_lse + + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, + output: torch.Tensor, + ) -> None: + # TODO (zyongye): Prefill function here + assert attn_metadata.prefill is not None + assert self.dcp_world_size is not None + + has_context = attn_metadata.prefill.chunked_context is not None + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = self._concat_k_nope_k_pe(k_nope, k_pe) + + output_prefill = 
self._run_prefill_new_tokens( + prefill=attn_metadata.prefill, + q=q, + k=k, + v=v, + return_softmax_lse=has_context, + ) + + if has_context: + suffix_output, suffix_lse = output_prefill + if self.dcp_world_size > 1: + context_output, context_lse = ( + self._context_parallel_compute_prefill_context( + q, + kv_c_and_k_pe_cache, + attn_metadata, + k_scale=None, + dcp_world_size=self.dcp_world_size, + ) + ) + else: + context_output, context_lse = self._compute_prefill_context( + q, kv_c_and_k_pe_cache, attn_metadata, k_scale + ) + + # unpad if necessary + if self._pad_v: + context_output = context_output[..., : v.shape[-1]] + suffix_output = suffix_output[..., : v.shape[-1]] + + output = output.view(-1, self.num_heads, self.v_head_dim) + merge_attn_states( + output=output, + prefix_output=context_output, + prefix_lse=context_lse, + suffix_output=suffix_output, + suffix_lse=suffix_lse, + ) + else: + output_prefill = output_prefill[..., : v.shape[-1]].flatten(start_dim=-2) + output.copy_(output_prefill) + + @abstractmethod + def _forward_decode( + self, + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: M, + layer: AttentionLayer, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + raise NotImplementedError + + def forward( + self, + layer: AttentionLayer, + q: torch.Tensor, + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: M, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + assert output is not None, "Output tensor must be provided." 
+ + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported for MLACommonImpl" + ) + + if attn_metadata is None: + # During the profile run try to simulate to worse case output size + # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` + # since this can be large + _ = torch.empty( + ( + self.chunked_prefill_workspace_size, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ), + device=k_c_normed.device, + dtype=k_c_normed.dtype, + ) + + # The zero fill is required when used with DP + EP + # to ensure all ranks within a DP group compute the + # same expert outputs. + return output.fill_(0) + + if self.dcp_world_size is None: + self.dcp_world_size = get_dcp_group().world_size + + fp8_attention = self.kv_cache_dtype.startswith("fp8") + + num_actual_toks = attn_metadata.num_actual_tokens + + # Inputs and outputs may be padded for CUDA graphs + output_padded = output + output = output[:num_actual_toks, ...] + q = q[:num_actual_toks, ...] + k_c_normed = k_c_normed[:num_actual_toks, ...] + k_pe = k_pe[:num_actual_toks, ...] 
+ + assert ( + attn_metadata.num_decodes is not None + and attn_metadata.num_prefills is not None + and attn_metadata.num_decode_tokens is not None + ) + + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + decode_q = q[:num_decode_tokens] + + prefill_q = q[num_decode_tokens:] + prefill_k_pe = k_pe[num_decode_tokens:] + prefill_k_c_normed = k_c_normed[num_decode_tokens:] + + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + + if fp8_attention: + kv_cache = kv_cache.view(current_platform.fp8_dtype()) + + if has_prefill: + self._forward_prefill( + prefill_q, + prefill_k_c_normed, + prefill_k_pe, + kv_cache, + attn_metadata, + layer._k_scale, + output=output[num_decode_tokens:], + ) + + if has_decode: + assert attn_metadata.decode is not None + + decode_q_nope, decode_q_pe = decode_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Convert from (B, N, P) to (N, B, P) + decode_q_nope = decode_q_nope.transpose(0, 1) + + if self.q_pad_num_heads is not None: + B, N, L = decode_q_pe.shape + decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L)) + decode_pe_padded.resize_((B, N, L)) + decode_pe_padded.copy_(decode_q_pe) + decode_q_pe = decode_pe_padded + + if current_platform.is_out_of_tree(): + # Pads the head_dim if necessary (for the underlying kernel) + N, B, P = decode_q_nope.shape + _, _, L = self.W_UK_T.shape + + if self.q_pad_num_heads is not None: + decode_ql_nope = decode_q_nope.new_empty( + (self.q_pad_num_heads, B, L) + ) + decode_ql_nope.resize_((N, B, L)) + else: + decode_ql_nope = decode_q_nope.new_empty((N, B, L)) + + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope) + + # 
Convert from (N, B, L) to (B, N, L) + decode_ql_nope = decode_ql_nope.transpose(0, 1) + + if fp8_attention: + raise NotImplementedError( + "FP8 attention is not supported for MACA now." + ) + else: + decode_q = (decode_ql_nope, decode_q_pe) + if self.dcp_world_size > 1: + assert not fp8_attention, "DCP not support fp8 kvcache now." + # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P) + decode_q = torch.cat(decode_q, dim=-1) + # decode_q do allgather in head dim. + decode_q = get_dcp_group().all_gather(decode_q, dim=1) + + # call decode attn + attn_out, lse = self._forward_decode( + decode_q, kv_cache, attn_metadata, layer + ) + + # correct dcp attn_out with lse. + if self.dcp_world_size > 1: + attn_out = cp_lse_ag_out_rs( + attn_out, + lse, + get_dcp_group(), + is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False), + ) + + # v_up projection + self._v_up_proj(attn_out, out=output[:num_decode_tokens]) + return output_padded diff --git a/vllm_fl/v1/attention/backends/mla/flashmla.py b/vllm_fl/v1/attention/backends/mla/flashmla.py new file mode 100644 index 00000000..bdd6c862 --- /dev/null +++ b/vllm_fl/v1/attention/backends/mla/flashmla.py @@ -0,0 +1,318 @@ +# SPDX-License-Identifier: Apache-2.0 +# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. 
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import ClassVar + +import torch + +from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf +from vllm_fl.attention.ops.flashmla import ( + flash_mla_with_kvcache, + get_mla_metadata, + is_flashmla_dense_supported, +) +from vllm.config import VllmConfig +from vllm.config.cache import CacheDType +from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) +from vllm.platforms.interface import DeviceCapability +from vllm_fl.v1.attention.backends.mla.common import ( + MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, + QueryLenSupport, +) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + reshape_attn_output_for_spec_decode, + reshape_query_for_spec_decode, +) +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.attention.backends.registry import AttentionBackendEnum, register_backend + +logger = init_logger(__name__) + + +@register_backend(AttentionBackendEnum.FLASHMLA) +class MacaFlashMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + ] + + @staticmethod + def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: + return [64] + + @staticmethod + def get_name() -> str: + return "FLASHMLA" + + @staticmethod + def get_builder_cls() -> type["FlashMLAMetadataBuilder"]: + return FlashMLAMetadataBuilder + + @staticmethod + def get_impl_cls() -> type["FlashMLAImpl"]: + return FlashMLAImpl + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major in [9, 10] + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: 
CacheDType | None, + block_size: int, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: DeviceCapability, + ) -> str | None: + if use_sparse: + from vllm_fl.attention.ops.flashmla import is_flashmla_sparse_supported + + return is_flashmla_sparse_supported()[1] + else: + from vllm_fl.attention.ops.flashmla import is_flashmla_dense_supported + + return is_flashmla_dense_supported()[1] + + +@dataclass +class FlashMLADecodeMetadata(MLACommonDecodeMetadata): + tile_scheduler_metadata: torch.Tensor + num_splits: torch.Tensor + + +@dataclass +class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): + pass + + +class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): + _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM + reorder_batch_threshold: int = 128 # process small prefills with decode pathway + # ^ TODO(matt): tune this + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__( + kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata + ) + + self.num_q_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config + ) + + self.cg_buf_tile_scheduler_metadata = None + self.cg_buf_num_splits = None + self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8") + + device_properties = torch.cuda.get_device_properties(self.device) + num_sms = device_properties.multi_processor_count + + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + self.cg_buf_tile_scheduler_metadata = torch.zeros( + # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize) + # TileSchedulerMetaDataSize = 8 + (num_sms, 8), + device=self.device, + dtype=torch.int32, + ) + self.cg_buf_num_splits = torch.empty( + (vllm_config.scheduler_config.max_num_seqs + 1), + 
device=self.device, + dtype=torch.int32, + ) + + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + dcp_tot_seq_lens_device: torch.Tensor | None, + ) -> FlashMLADecodeMetadata: + query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + # we use the max but all should be the same due to uniform length requirement + max_query_len = query_lens_cpu.max().item() + num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1 + tile_scheduler_metadata, num_splits = get_mla_metadata( + seq_lens_device, + num_q_tokens_per_head_k, + 1, # MQA for the decode path + is_fp8_kvcache=self.is_fp8_kvcache, + ) + + # TODO: we can disambiguate between decode and mixed-prefill decode here + # so we can only use the persistent buffer if a cudagraph is actually + # being used. + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + assert self.cg_buf_tile_scheduler_metadata is not None + assert self.cg_buf_num_splits is not None + + sm_parts = tile_scheduler_metadata.size(0) + # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize) + assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0) + tile_scheduler_metadata_view = self.cg_buf_tile_scheduler_metadata[ + :sm_parts + ] + tile_scheduler_metadata_view.copy_(tile_scheduler_metadata) + tile_scheduler_metadata = tile_scheduler_metadata_view + + # Num splits is per-batch, varying size (batch_size,) + n = num_splits.size(0) + # make sure static buffer is large enough + assert n <= self.cg_buf_num_splits.size(0) + num_splits_view = self.cg_buf_num_splits[:n] + num_splits_view.copy_(num_splits) + # Num splits needs to monotonically increasing + # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise + # it needs to monotonically increasing by 1) + self.cg_buf_num_splits[n:].fill_(num_splits[-1]) + num_splits 
= num_splits_view + + return FlashMLADecodeMetadata( + block_table=block_table_tensor, + seq_lens=seq_lens_device, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, + ) + + +class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): + can_return_lse_for_decode: bool = True + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: list[float] | None, + sliding_window: int | None, + kv_cache_dtype: str, + logits_soft_cap: float | None, + attn_type: str, + kv_sharing_target_layer_name: str | None, + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) + + is_supported, reason = is_flashmla_dense_supported() + assert is_supported, reason + + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] + if any(unsupported_features): + raise NotImplementedError( + "FlashMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, logits_soft_cap" + ) + + if attn_type != AttentionType.DECODER: + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashMLAImpl" + ) + + def _forward_decode( + self, + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashMLAMetadata, + layer: AttentionLayer, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + # TODO: (zyongye) decode function for mla here + assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode is not None + + if type(q) is tuple: + q = torch.cat(q, dim=-1) + + # mypy assertion: q is now always a tensor + assert isinstance(q, torch.Tensor) + + num_decodes = attn_metadata.num_decodes + q = reshape_query_for_spec_decode(q, 
num_decodes) + + tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata + num_splits = attn_metadata.decode.num_splits + if vllm_is_batch_invariant(): + device = q.device + dtype = torch.int32 + + B = q.shape[0] + # block_table shape: [batch_size, max_num_blocks_per_seq] + # The number of blocks per sequence is in the second dimension + topk = attn_metadata.decode.block_table.shape[-1] + B_TOPK = 64 + assert topk % B_TOPK == 0, f"topk ({topk}) must be divisible by {B_TOPK}" + end_block_idx = topk // B_TOPK + + # Single partition => num_sm_parts = 1 + # TileSchedulerMetaDataSize = 8, layout: + # [begin_idx, begin_block_idx, end_idx, end_block_idx, + # begin_n_split_idx, _, _, _] + tile_scheduler_metadata = torch.zeros((1, 8), dtype=dtype, device=device) + tile_scheduler_metadata[0, 0] = 0 # begin_idx + tile_scheduler_metadata[0, 1] = 0 # sched_begin_block_idx + tile_scheduler_metadata[0, 2] = B - 1 # end_idx + tile_scheduler_metadata[0, 3] = end_block_idx + tile_scheduler_metadata[0, 4] = 0 # begin_n_split_idx + # fields [5..7] stay 0 + + # Non-split path ignores num_splits, but the API requires it: + # zeros of length B+1 + num_splits = torch.zeros((B + 1,), dtype=dtype, device=device) + + o, lse = flash_mla_with_kvcache( + q=q, + k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 + block_table=attn_metadata.decode.block_table, + cache_seqlens=attn_metadata.decode.seq_lens, + head_dim_v=self.kv_lora_rank, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + softmax_scale=self.scale, + causal=True, + descale_q=layer._q_scale.reshape(1), + descale_k=layer._k_scale.reshape(1), + ) + + o = reshape_attn_output_for_spec_decode(o) + + return o, lse diff --git a/vllm_fl/v1/attention/backends/mla/flashmla_sparse.py b/vllm_fl/v1/attention/backends/mla/flashmla_sparse.py new file mode 100644 index 00000000..fc96dbc7 --- /dev/null +++ b/vllm_fl/v1/attention/backends/mla/flashmla_sparse.py @@ -0,0 +1,1025 @@ +# 
SPDX-License-Identifier: Apache-2.0
+# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, ClassVar, Optional
+
+import numpy as np
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.attention.backends.abstract import (
+    AttentionBackend,
+    AttentionLayer,
+    MultipleOf,
+)
+from vllm.attention.backends.utils import get_mla_dims
+from vllm_fl.attention.ops.flashmla import (
+    flash_mla_sparse_prefill,
+    flash_mla_with_kvcache,
+    get_mla_metadata,
+)
+from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.config.cache import CacheDType
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.platforms.interface import DeviceCapability
+from vllm.triton_utils import tl, triton
+from vllm.utils.math_utils import cdiv
+from vllm_fl.v1.attention.backends.mla.common import MLACommonBaseImpl
+from vllm.v1.attention.backends.utils import (
+    AttentionCGSupport,
+    AttentionMetadataBuilder,
+    CommonAttentionMetadata,
+    reshape_attn_output_for_spec_decode,
+    reshape_query_for_spec_decode,
+    split_decodes_and_prefills,
+    split_prefill_chunks,
+)
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.worker.workspace import current_workspace_manager
+from vllm.attention.backends.registry import AttentionBackendEnum, register_backend
+
+if TYPE_CHECKING:
+    from vllm.model_executor.models.deepseek_v2 import Indexer
+
+logger = init_logger(__name__)
+
+# For FP8 sparse attention we have two implementations:
+# 1. Mixed batch mode: use the FP8 decode kernel for both prefill and decode; this is
+# done by treating all tokens as a single batch.
+# 2. Separate prefill and decode mode: use the BF16 prefill kernel for prefill
+# (upconverting the FP8 cache to BF16 then calling the prefill kernel) and using
+# the FP8 decode kernel for decode.
+# Currently we use #1 when the number of heads per rank is low (i.e. TP) since the BF16
+# prefill kernel requires padding the number of heads to 128 while the decode does not,
+# so when the per-rank head count is below MIN_HEADS_FOR_BF16_PREFILL we use the mixed
+# batch mode (#1).
+MIN_HEADS_FOR_BF16_PREFILL = 32
+
+"""
+NOTE: FlashMLA Sparse uses an fp8 cache with the following format
+
+In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
+structured as:
+- **First 512 bytes:** The "quantized NoPE" part, containing 512
+  `float8_e4m3` values.
+- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
+  The first `float32` is the scale for the first 128 `float8_e4m3` values,
+  the second for the next 128, and so on.
+- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
+  part is not quantized for accuracy.
+""" + + +@register_backend(AttentionBackendEnum.FLASHMLA_SPARSE) +class MacaFlashMLASparseBackend(AttentionBackend): + accept_output_buffer: bool = True + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"] + + @staticmethod + def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: + return [64] + + @staticmethod + def get_name() -> str: + return "FLASHMLA_SPARSE" + + @staticmethod + def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]: + return FlashMLASparseMetadataBuilder + + @staticmethod + def get_impl_cls() -> type["FlashMLASparseImpl"]: + return FlashMLASparseImpl + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [576] + + @classmethod + def is_mla(cls) -> bool: + return True + + @classmethod + def is_sparse(cls) -> bool: + return True + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major in [9, 10] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + if cache_dtype_str == "fp8_ds_mla": + # custom storage fromat is 656 bytes + # see FlashMLA readme.md for details + return (num_blocks, block_size, 656) + else: + return (num_blocks, block_size, head_size) + + +@dataclass +class FlashMLASparseMetadata: + num_reqs: int + max_query_len: int + max_seq_len: int + + num_actual_tokens: int # Number of tokens excluding padding. 
+ query_start_loc: torch.Tensor + slot_mapping: torch.Tensor + + block_table: torch.Tensor + req_id_per_token: torch.Tensor + block_size: int = 64 + topk_tokens: int = 2048 + + @dataclass + class FP8KernelMetadata: + scheduler_metadata: torch.Tensor | None + num_splits: torch.Tensor + dummy_block_table: torch.Tensor + cache_lens: torch.Tensor + + @dataclass + class FP8SeperatePrefillDecode: + @dataclass + class Decode: + kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata" + decode_query_len: int # needed for reshape in spec decode + + @dataclass + class Prefill: + # Sequence lengths (context + query) for prefill requests + # Shape: [num_prefill_reqs] + seq_lens: torch.Tensor + + # Request ID for each token: -1 for decode tokens, request index + # (0, 1, 2, ...) for prefill tokens. + # Shape: [num_actual_tokens] + request_ids: torch.Tensor + + # Workspace start offsets for all prefill requests + # Shape: [num_prefill_reqs], adjusted in-place per chunk to be + # 0-indexed within each chunk. Used to map prefill tokens to workspace + # offsets in convert_logical_index_to_physical_index + workspace_starts: torch.Tensor + + @dataclass + class Chunk: + """Metadata for a chunk of prefill requests. + + Prefill requests may be chunked to fit within the fixed workspace size. 
+ """ + + seq_lens: torch.Tensor + tokens_slice: slice + block_table: torch.Tensor + req_start_idx: int + workspace_starts: torch.Tensor + chunk_tot_seqlen: int + + chunks: list[Chunk] + + num_prefills: int = 0 + num_decodes: int = 0 + num_prefill_tokens: int = 0 + num_decode_tokens: int = 0 + + decode: Decode | None = None + prefill: Prefill | None = None + + fp8_extra_metadata: FP8SeperatePrefillDecode | FP8KernelMetadata | None = None + fp8_use_mixed_batch: bool = False + + +# Kernel with prefill workspace support +@triton.jit +def _convert_req_index_to_global_index_kernel( + req_id_ptr, # int32 [num_tokens] + block_table_ptr, # int32 [num_requests, max_num_blocks_per_req] + token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + prefill_request_id_ptr, # int32 [num_tokens], -1 for decode, >=0 for prefill + workspace_starts_ptr, # int32 [num_prefill_reqs+1] or nullptr + # shapes (compile-time where possible) + max_num_blocks_per_req: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, # tile width along columns + HAS_PREFILL: tl.constexpr, + # strides (in elements) + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, + out_stride0, + out_stride1, +): + # program_id(0) -> token_id (row) + # program_id(1) -> tile index along columns + token_id = tl.program_id(0) + tile_id = tl.program_id(1) + + # Each program covers BLOCK_N consecutive columns + indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N) + + # Load request id for this token (no mask: grid is exact) + req = tl.load(req_id_ptr + token_id) + + # Load token indices for this tile + ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1 + tok = tl.load(ti_ptr) # int32 + + # Only token == -1 should propagate as -1 + is_invalid_tok = tok < 0 + is_prefill = False + if HAS_PREFILL: + prefill_req_id = tl.load(prefill_request_id_ptr + token_id) + is_prefill = prefill_req_id >= 0 + # Compute block id and in-block offset + 
block_id = tok // BLOCK_SIZE + inblock_off = tok % BLOCK_SIZE + + # Guard block_table access + valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0) + bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 + is_invalid_tok |= ~valid_block + base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0) + out_val = base * BLOCK_SIZE + inblock_off + + # Override with prefill output if prefill is enabled + if HAS_PREFILL: + workspace_start = tl.load( + workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0 + ) + prefill_out = workspace_start + tok + out_val = tl.where(is_prefill, prefill_out, out_val) + out_val = tl.where(is_invalid_tok, -1, out_val) + + # Store results + out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 + tl.store(out_ptr_ij, out_val) + + +def triton_convert_req_index_to_global_index( + req_id: torch.Tensor, # int32 [num_tokens] + block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req] + token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] + BLOCK_SIZE: int = 64, + NUM_TOPK_TOKENS: int = 2048, + BLOCK_N: int = 128, # tile width along columns + HAS_PREFILL_WORKSPACE: bool = False, + prefill_workspace_request_ids: torch.Tensor | None = None, + prefill_workspace_starts: torch.Tensor | None = None, +): + """ + out[token_id, indice_id] = + block_table[req_id[token_id], + token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE + + token_indices[token_id, indice_id] % BLOCK_SIZE + + Only when token_indices[token_id, indice_id] == -1 do we output -1. + For safety, we also output -1 if the derived block_id would be + out-of-bounds. + + When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace offsets + instead of global cache slots. prefill_workspace_request_ids and + prefill_workspace_starts must be provided. 
+ + prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else + prefill request index (maps to prefill_workspace_starts) + prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace + starts for each prefill request + """ + assert req_id.dtype == torch.int32 + assert block_table.dtype == torch.int32 + assert token_indices.dtype == torch.int32 + assert token_indices.shape[1] == NUM_TOPK_TOKENS + assert NUM_TOPK_TOKENS % BLOCK_N == 0, ( + f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})" + ) + + if HAS_PREFILL_WORKSPACE: + assert prefill_workspace_request_ids is not None + assert prefill_workspace_starts is not None + assert prefill_workspace_request_ids.dtype == torch.int32 + assert prefill_workspace_starts.dtype == torch.int32 + + num_tokens = req_id.shape[0] + max_num_blocks_per_req = block_table.shape[1] + tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N + + # Ensure contiguous tensors on the same device + req_id_c = req_id.contiguous() + block_table_c = block_table.contiguous() + token_indices_c = token_indices.contiguous() + out = torch.empty_like(token_indices_c) + + # Strides in elements + bt_stride0, bt_stride1 = block_table_c.stride() + ti_stride0, ti_stride1 = token_indices_c.stride() + out_stride0, out_stride1 = out.stride() + + # Prepare prefill pointers + if HAS_PREFILL_WORKSPACE: + assert prefill_workspace_request_ids is not None # for mypy + assert prefill_workspace_starts is not None # for mypy + assert prefill_workspace_request_ids.is_contiguous() + assert prefill_workspace_starts.is_contiguous() + + # Exact 2D grid: tokens × column tiles + grid = (num_tokens, tiles_per_row) + + _convert_req_index_to_global_index_kernel[grid]( + req_id_c, + block_table_c, + token_indices_c, + out, + prefill_workspace_request_ids, + prefill_workspace_starts, + # shapes / constexprs + max_num_blocks_per_req, + BLOCK_SIZE, + BLOCK_N, + HAS_PREFILL_WORKSPACE, + # strides + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, 
+ out_stride0, + out_stride1, + ) + return out + + +def get_prefill_workspace_size(max_model_len: int): + # NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size. + # May be tuned later. + # Memory usage: 5 * max_model_len * 576 * 2 bytes + # Example: DeepSeek-V3.2 with max_model_len=163840 -> + # 5 * 163840 * 576 * 2 = ~900 MB + # This fits nicely below the typical MoE workspace size of >2GB so this is "free" + return max_model_len * 5 + + +class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]): + # /------------------------ Metax Modification -------------------------\ + _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER + # \------------------------ Metax Modification -------------------------/ + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ) -> None: + self.vllm_config = vllm_config + self.layer_names = layer_names + cache_config = vllm_config.cache_config + self.kv_cache_spec = kv_cache_spec + self.model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + self.device = device + + # Treat requests with query length <= 1 as decodes to match the + # DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2) + self._init_reorder_batch_threshold(1, supports_spec_as_decode=True) + + props = torch.cuda.get_device_properties(device) + sm_count = props.multi_processor_count + + self.num_heads = self.model_config.get_num_attention_heads(parallel_config) + self.mla_dims = get_mla_dims(self.model_config) + + self.topk_tokens = vllm_config.model_config.hf_config.index_topk + self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla" + max_num_seqs = vllm_config.scheduler_config.max_num_seqs + # Shape: [max_num_seqs], all elements = topk_tokens (constant for full-CG) + self.topk_tokens_tensor = torch.full( + (max_num_seqs,), self.topk_tokens, device=device, 
dtype=torch.int32 + ) + # Shape: [max_num_seqs], all elements = max_model_len + self.max_model_len_tensor = torch.full( + (max_num_seqs,), + self.model_config.max_model_len, + device=device, + dtype=torch.int32, + ) + # this is ignored by `flash_mla_with_kvcache` if indices not None + self.dummy_block_table = torch.empty( + (max_num_seqs, 1), dtype=torch.int32, device=self.device + ) + + # Equation taken from FlashMLA/csrc/pybind.cpp + h_q, h_k = self.num_heads, 1 + s_q = 1 # inversely proportional to s_q, so s_q = 1 is the largest + max_num_sm_parts = int( + max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1) + ) + if current_platform.is_device_capability_family(100): + max_num_sm_parts *= 2 + self.tile_scheduler_metadata_buffer = torch.empty( + # TileSchedulerMetaDataSize = 8 + # see: FlashMLA/csrc/params.h + (max_num_sm_parts, 8), + dtype=torch.int32, + device=device, + ) + # Sized for per-request batching (num_decodes + 1) + self.num_splits_buffer = torch.empty( + (max_num_seqs + 1,), + dtype=torch.int32, + device=device, + ) + self.req_id_per_token_buffer = torch.empty( + (vllm_config.scheduler_config.max_num_batched_tokens,), + dtype=torch.int32, + device=device, + ) + + def _build_fp8_mixed_decode_prefill( + self, + common_attn_metadata: CommonAttentionMetadata, + ) -> "FlashMLASparseMetadata.FP8KernelMetadata": + """Build FP8 metadata treating all tokens as one mixed batch. + + This matches main branch's approach and avoids the BF16 prefill kernel + which has head padding overhead when num_heads is small (high TP case). 
+ """ + num_tokens = common_attn_metadata.num_actual_tokens + + # Build metadata for all tokens as a single batch + tile_scheduler_metadata, num_splits = get_mla_metadata( + cache_seqlens=self.topk_tokens_tensor[:1], # Single batch + num_q_tokens_per_head_k=num_tokens * self.num_heads, + topk=self.topk_tokens, + num_heads_q=self.num_heads, + num_heads_k=1, + is_fp8_kvcache=True, + ) + + num_sm_parts = tile_scheduler_metadata.size(0) + tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[ + :num_sm_parts + ] + tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata) + num_splits_view = self.num_splits_buffer[:2] + num_splits_view.copy_(num_splits) + + fp8_metadata = FlashMLASparseMetadata.FP8KernelMetadata( + scheduler_metadata=tile_scheduler_metadata_buffer, + num_splits=num_splits_view, + cache_lens=self.max_model_len_tensor[:1], + dummy_block_table=self.dummy_block_table[:1], + ) + + return fp8_metadata + + def _build_fp8_separate_prefill_decode( + self, + common_attn_metadata: CommonAttentionMetadata, + ) -> "FlashMLASparseMetadata.FP8SeperatePrefillDecode": + num_tokens = common_attn_metadata.num_actual_tokens + + (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold or 1, + require_uniform=True, + ) + ) + + FP8Meta = FlashMLASparseMetadata.FP8SeperatePrefillDecode + fp8_metadata = FP8Meta( + num_decodes=num_decodes, + num_prefills=num_prefills, + num_decode_tokens=num_decode_tokens, + num_prefill_tokens=num_prefill_tokens, + ) + + # Extract prefill sequence lengths (context + query, not just query) + # Decode requests come first in the batch, prefill requests follow + prefill_seq_lens = None + prefill_request_id = None + prefill_workspace_starts = None + prefill_chunks = None + + # For pure decode batches, prefill_request_id will be None + # For mixed batches, it will have -1 for decode and request_id for prefill + if 
num_prefills > 0: + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + seq_lens = common_attn_metadata.seq_lens + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + + prefill_seq_lens_cpu = seq_lens_cpu[num_decodes:] + prefill_seq_lens = seq_lens[num_decodes:] + + # Build prefill_request_id: -1 for decode, request index for + # prefill. This enables a single + # convert_logical_index_to_physical_index call for all tokens + prefill_request_id = torch.full( + (num_tokens,), -1, dtype=torch.int32, device=self.device + ) + # Map prefill tokens to their request IDs (0, 1, 2, ...) + for req_idx in range(num_prefills): + # Get query token range for this prefill request + global_req_idx = num_decodes + req_idx + req_query_start = query_start_loc_cpu[global_req_idx] + req_query_end = query_start_loc_cpu[global_req_idx + 1] + prefill_request_id[req_query_start:req_query_end] = req_idx + + # will be adjusted by chunk loop + prefill_workspace_starts_cpu = torch.zeros( + num_prefills, dtype=torch.int32, pin_memory=True + ) + prefill_workspace_starts_cpu[1:] = torch.cumsum( + prefill_seq_lens_cpu[:-1], dim=0 + ) + # populated by non-blocking copy after prefill_workspace_starts_cpu is + # updated by each chunk + prefill_workspace_starts = torch.empty( + num_prefills, dtype=torch.int32, device=self.device + ) + + # Chunk prefill requests to fit within workspace size + max_prefill_buffer_size = get_prefill_workspace_size( + self.vllm_config.model_config.max_model_len + ) + chunk_bounds = split_prefill_chunks( + prefill_seq_lens_cpu, max_prefill_buffer_size + ) + + prefill_chunks = [] + for chunk_start, chunk_end in chunk_bounds: + # Adjust workspace_starts in-place per chunk to be + # 0-indexed within each chunk + # Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]] + # Initial: workspace_starts=[0,10,25,45] + # After: workspace_starts=[0,10,0,20] + # (chunk 0 starts at 0, chunk 1 starts at 0) + offset = prefill_workspace_starts_cpu[chunk_start].item() + 
prefill_workspace_starts_cpu[chunk_start:chunk_end] -= offset + + chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end] + chunk_tot_seqlen = prefill_seq_lens_cpu[chunk_start:chunk_end].sum() + token_start = query_start_loc_cpu[num_decodes + chunk_start].item() + token_end = query_start_loc_cpu[num_decodes + chunk_end].item() + tokens_slice = slice(token_start, token_end) + + # Create chunk view of gpu tensor + chunk_workspace_starts = prefill_workspace_starts[chunk_start:chunk_end] + chunk_block_table = common_attn_metadata.block_table_tensor[ + num_decodes + chunk_start : num_decodes + chunk_end + ] + + prefill_chunks.append( + FP8Meta.Prefill.Chunk( + seq_lens=chunk_seq_lens, + tokens_slice=tokens_slice, + block_table=chunk_block_table, + req_start_idx=chunk_start, + workspace_starts=chunk_workspace_starts, + chunk_tot_seqlen=chunk_tot_seqlen, + ) + ) + + prefill_workspace_starts.copy_( + prefill_workspace_starts_cpu, non_blocking=True + ) + + fp8_metadata.prefill = FP8Meta.Prefill( + seq_lens=prefill_seq_lens, + request_ids=prefill_request_id, + workspace_starts=prefill_workspace_starts, + chunks=prefill_chunks, + ) + + if num_decodes > 0: + # Compute decode_query_len for spec decode (uniform due to require_uniform) + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item() + + tile_scheduler_metadata, num_splits = get_mla_metadata( + cache_seqlens=self.topk_tokens_tensor[:num_decodes], + num_q_tokens_per_head_k=decode_query_len * self.num_heads, + topk=self.topk_tokens, + num_heads_q=self.num_heads, + num_heads_k=1, + is_fp8_kvcache=True, + ) + + num_sm_parts = tile_scheduler_metadata.size(0) + # Copy to persistent buffer for full-CG support + tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[ + :num_sm_parts + ] + tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata) + # num_splits has size [num_decodes + 1] + num_splits_view = 
self.num_splits_buffer[: num_decodes + 1] + num_splits_view.copy_(num_splits) + + kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata( + scheduler_metadata=tile_scheduler_metadata_buffer, + num_splits=num_splits_view, + dummy_block_table=self.dummy_block_table[:num_decodes], + cache_lens=self.max_model_len_tensor[:num_decodes], + ) + fp8_metadata.decode = FP8Meta.Decode( + kernel_metadata=kernel_meta, + decode_query_len=decode_query_len, + ) + + return fp8_metadata + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlashMLASparseMetadata: + cm = common_attn_metadata + num_tokens = cm.num_actual_tokens + starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32) + seg_lengths = np.diff(starts) + req_id_per_token = np.repeat( + np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths + ) + # Zero-fill for cudagraphs + self.req_id_per_token_buffer.fill_(0) + self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( + torch.from_numpy(req_id_per_token), non_blocking=True + ) + req_id_per_token = self.req_id_per_token_buffer[:num_tokens] + + fp8_extra_metadata: ( + FlashMLASparseMetadata.FP8SeperatePrefillDecode + | FlashMLASparseMetadata.FP8KernelMetadata + | None + ) = None + fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL + if self.use_fp8_kv_cache: + if fp8_use_mixed_batch: + fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm) + else: + fp8_extra_metadata = self._build_fp8_separate_prefill_decode(cm) + + metadata = FlashMLASparseMetadata( + num_reqs=cm.num_reqs, + max_query_len=cm.max_query_len, + max_seq_len=cm.max_seq_len, + num_actual_tokens=cm.num_actual_tokens, + query_start_loc=cm.query_start_loc, + slot_mapping=cm.slot_mapping, + block_table=cm.block_table_tensor, + req_id_per_token=req_id_per_token, + block_size=self.kv_cache_spec.block_size, + topk_tokens=self.topk_tokens, + fp8_extra_metadata=fp8_extra_metadata, + 
fp8_use_mixed_batch=fp8_use_mixed_batch, + ) + + return metadata + + +class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: list[float] | None, + sliding_window: int | None, + kv_cache_dtype: str, + logits_soft_cap: float | None, + attn_type: str, + kv_sharing_target_layer_name: str | None, + # MLA Specific Arguments + topk_indice_buffer: torch.Tensor | None = None, + indexer: Optional["Indexer"] = None, + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) + self.softmax_scale = scale + assert indexer is not None + self.topk_indices_buffer = indexer.topk_indices_buffer + self.padding = 128 if current_platform.is_device_capability_family(100) else 64 + + if kv_cache_dtype == "fp8_ds_mla": + # Reserve workspace during initialization + vllm_config = get_current_vllm_config() + assert vllm_config is not None and vllm_config.model_config is not None + prefill_workspace_size = get_prefill_workspace_size( + vllm_config.model_config.max_model_len + ) + self.prefill_workspace_shape = (prefill_workspace_size, head_size) + (self.prefill_bf16_workspace,) = ( + current_workspace_manager().get_simultaneous( + (self.prefill_workspace_shape, torch.bfloat16) + ) + ) + + def _forward_bf16_kv( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + ) -> torch.Tensor: + # Convert per-request indices to global slots (decode) or workspace + # offsets (prefill). 
+ topk_indices = triton_convert_req_index_to_global_index( + attn_metadata.req_id_per_token, + attn_metadata.block_table, + topk_indices, + BLOCK_SIZE=attn_metadata.block_size, + NUM_TOPK_TOKENS=topk_indices.shape[1], + ) + + return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices) + + def _forward_fp8_kv_separate_prefill_decode( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + ) -> torch.Tensor: + fp8_metadata = attn_metadata.fp8_extra_metadata + assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode) + num_decodes = fp8_metadata.num_decodes + + prefill_request_ids = None + prefill_workspace_starts = None + has_prefill_workspace = False + if fp8_metadata.prefill is not None: + prefill_request_ids = fp8_metadata.prefill.request_ids + prefill_workspace_starts = fp8_metadata.prefill.workspace_starts + has_prefill_workspace = True + + # Convert per-request indices to global slots (decode) or workspace + # offsets (prefill). 
+ # For FP8 cache: prefill uses workspace mapping (upconverted to BF16) + # For BF16 cache: always use global cache slots (no workspace) + # prefill_workspace_starts has been adjusted in-place per chunk so + # prefill indices automatically come out chunk-local + topk_indices = triton_convert_req_index_to_global_index( + attn_metadata.req_id_per_token, + attn_metadata.block_table, + topk_indices, + BLOCK_SIZE=attn_metadata.block_size, + NUM_TOPK_TOKENS=topk_indices.shape[1], + HAS_PREFILL_WORKSPACE=has_prefill_workspace, + prefill_workspace_request_ids=prefill_request_ids, + prefill_workspace_starts=prefill_workspace_starts, + ) + + fp8_metadata = attn_metadata.fp8_extra_metadata + assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode) + + def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor: + # Reshape q: (num_decode_tokens, num_heads, head_dim) + # -> (num_decodes, seq_len, num_heads, head_dim) + q = reshape_query_for_spec_decode(q, num_decodes) + seq_len = q.shape[1] + # Reshape topk_indices: (num_decode_tokens, topk) + # -> (num_decodes, seq_len, topk) + topk_indices = topk_indices.view(num_decodes, seq_len, -1) + assert fp8_metadata.decode is not None + attn_out, _ = self._fp8_flash_mla_kernel( + q=q, + kv_c_and_k_pe_cache=kv_c_and_k_pe_cache, + topk_indices=topk_indices, + kernel_metadata=fp8_metadata.decode.kernel_metadata, + ) + # Reshape output: (num_decodes, seq_len, num_heads, head_dim_v) + # -> (num_decode_tokens, num_heads, head_dim_v) + return reshape_attn_output_for_spec_decode(attn_out) + + num_decode_tokens = fp8_metadata.num_decode_tokens + num_prefill_tokens = fp8_metadata.num_prefill_tokens + + # Pure decode: direct call without allocation + if num_decode_tokens > 0 and num_prefill_tokens == 0: + assert fp8_metadata.decode is not None + attn_out = _fp8_decode(q, topk_indices) + else: + # Mixed or pure prefill: allocate output tensor + attn_out = q.new_empty( + (attn_metadata.num_actual_tokens, 
self.num_heads, self.kv_lora_rank), + dtype=q.dtype, + device=q.device, + ) + + if num_decode_tokens > 0: + attn_out[:num_decode_tokens] = _fp8_decode( + q[:num_decode_tokens], topk_indices[:num_decode_tokens] + ) + + assert fp8_metadata.prefill is not None + for chunk in fp8_metadata.prefill.chunks: + chunk_workspace = self.prefill_bf16_workspace[: chunk.chunk_tot_seqlen] + ops.cp_gather_and_upconvert_fp8_kv_cache( + kv_c_and_k_pe_cache, + chunk_workspace, + chunk.block_table, + chunk.seq_lens, + chunk.workspace_starts, + len(chunk.block_table), + ) + + chunk_q = q[chunk.tokens_slice] + chunk_topk_indices_workspace = topk_indices[chunk.tokens_slice] + + attn_out[chunk.tokens_slice] = self._bf16_flash_mla_kernel( + chunk_q, + chunk_workspace, + chunk_topk_indices_workspace, + ) + + return attn_out + + def _forward_fp8_kv_mixed_batch( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + ) -> torch.Tensor: + """Mixed batch FP8 forward path that treats all tokens as one batch. + + This is equivalent to main branch's approach and avoids the BF16 + prefill kernel which has head padding overhead when num_heads is small. + Used when use_mixed_batch is True. + """ + # Convert per-request indices to global slots (decode) or workspace + # offsets (prefill). 
+ topk_indices = triton_convert_req_index_to_global_index( + attn_metadata.req_id_per_token, + attn_metadata.block_table, + topk_indices, + BLOCK_SIZE=attn_metadata.block_size, + NUM_TOPK_TOKENS=topk_indices.shape[1], + ) + + assert attn_metadata.fp8_extra_metadata is not None + assert isinstance( + attn_metadata.fp8_extra_metadata, FlashMLASparseMetadata.FP8KernelMetadata + ) + fp8_metadata = attn_metadata.fp8_extra_metadata + + _attn_out, _ = self._fp8_flash_mla_kernel( + q=q.unsqueeze(0), # unsqueeze to add batch_dim: (T, H, D) -> (1, T, H, D) + kv_c_and_k_pe_cache=kv_c_and_k_pe_cache, + topk_indices=topk_indices.unsqueeze(0), # (T, topk) -> (1, T, topk) + kernel_metadata=fp8_metadata, + ) + + # Output is (1, T, H, D_v), squeeze back to (T, H, D_v) + return _attn_out.squeeze(0) + + def _fp8_flash_mla_kernel( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata, + ) -> torch.Tensor: + return flash_mla_with_kvcache( + q=q, + k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2), + block_table=kernel_metadata.dummy_block_table, + head_dim_v=512, + cache_seqlens=kernel_metadata.cache_lens, + tile_scheduler_metadata=kernel_metadata.scheduler_metadata, + num_splits=kernel_metadata.num_splits, + is_fp8_kvcache=True, + indices=topk_indices, + softmax_scale=self.softmax_scale, + ) + + def _bf16_flash_mla_kernel( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + ) -> torch.Tensor: + num_tokens = q.shape[0] + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( + -1, 1, kv_c_and_k_pe_cache.shape[-1] + ) + + # NOTE(Chen): kernel requires num_local_head to be a multiple of + # 64 on hopper and 128 on blackwell + if self.num_heads % self.padding != 0: + assert self.padding % self.num_heads == 0 + logger.warning_once( + f"padding num_heads to {self.padding} \ + due to sparse attn kernel requirement" + ) + q_padded = 
q.new_empty((q.shape[0], self.padding, q.shape[2])) + q_padded[:, : self.num_heads, :] = q + q = q_padded + + topk_indices = topk_indices.view(num_tokens, 1, -1) + output = flash_mla_sparse_prefill( + q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale + )[0] + output = output[:, : self.num_heads, :] + return output + + def forward( + self, + layer: AttentionLayer, + q: torch.Tensor, + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: FlashMLASparseMetadata | None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use + # MQA 576/512 approach for both prefill and decode + + assert output is not None, "Output tensor must be provided." + + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported for MLACommonImpl" + ) + + if attn_metadata is None: + # Dummy run - no need to allocate buffers + # The zero fill is required when used with DP + EP + # to ensure all ranks within a DP group compute the + # same expert outputs. + return output.fill_(0) + + num_actual_toks = attn_metadata.num_actual_tokens + + # Inputs and outputs may be padded for CUDA graphs + + q = q[:num_actual_toks, ...] + k_c_normed = k_c_normed[:num_actual_toks, ...] + k_pe = k_pe[:num_actual_toks, ...] 
+ topk_indices = self.topk_indices_buffer[:num_actual_toks] + + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + # Convert from (B, N, P) to (N, B, P) + q_nope = q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + ql_nope = torch.bmm(q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + ql_nope = ql_nope.transpose(0, 1) + + use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla" + + q = torch.cat([ql_nope, q_pe], dim=-1) + + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + + if not use_fp8_cache: + attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices, attn_metadata) + elif attn_metadata.fp8_use_mixed_batch: + attn_out = self._forward_fp8_kv_mixed_batch( + q, kv_cache, topk_indices, attn_metadata + ) + else: + attn_out = self._forward_fp8_kv_separate_prefill_decode( + q, kv_cache, topk_indices, attn_metadata + ) + + self._v_up_proj(attn_out, out=output[:num_actual_toks]) + return output diff --git a/vllm_fl/v1/attention/backends/mla/indexer.py b/vllm_fl/v1/attention/backends/mla/indexer.py new file mode 100644 index 00000000..8786e00b --- /dev/null +++ b/vllm_fl/v1/attention/backends/mla/indexer.py @@ -0,0 +1,347 @@ +# SPDX-License-Identifier: Apache-2.0 +# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. 
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar

import torch

from vllm.attention.backends.abstract import (
    AttentionBackend,
    MultipleOf,
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata
from vllm_fl.mx_utils.deep_gemm import is_deep_gemm_supported
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
    split_decodes_and_prefills,
    split_prefill_chunks,
)

logger = init_logger(__name__)


class MacaDeepseekV32IndexerBackend(AttentionBackend):
    """Backend for the DeepSeek-V3.2 sparse-attention indexer (MetaX port)."""

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        # ROCm path works per-token; other platforms use 64-token kernel blocks.
        return [1 if current_platform.is_rocm() else 64]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 128]

    @staticmethod
    def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]:
        return DeepseekV32IndexerMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        # The indexer cache is MQA-style: exactly one KV head, so the head
        # dimension is folded out of the cache shape.
        assert num_kv_heads == 1
        return (num_blocks, block_size, head_size)

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        # Identity stride order; only the rank changes with the layers dim.
        if include_num_layers_dimension:
            return (0, 1, 2, 3)
        return (0, 1, 2)


@dataclass
class DeepseekV32IndexerPrefillChunkMetadata:
    """Metadata for one chunk of prefill requests processed together."""

    block_table: torch.Tensor
    cu_seqlen_ks: torch.Tensor   # per-token KV span start (inclusive)
    cu_seqlen_ke: torch.Tensor   # per-token KV span end (exclusive)
    cu_seq_lens: torch.Tensor    # cumulative seq lens of the chunk, on device
    total_seq_lens: int
    token_start: int             # first token index of the chunk in the batch
    token_end: int               # one past the last token index of the chunk
    num_reqs: int


@dataclass
class DeepseekV32IndexerPrefillMetadata:
    chunks: list[DeepseekV32IndexerPrefillChunkMetadata]


@dataclass
class DeepSeekV32IndexerDecodeMetadata:
    block_table: torch.Tensor
    seq_lens: torch.Tensor
    decode_lens: torch.Tensor
    requires_padding: bool
    schedule_metadata: torch.Tensor


@dataclass
class DeepseekV32IndexerMetadata:
    # FIXME (zyongye)
    # hacky way to access the data now, need to be in chunked meta
    seq_lens: torch.Tensor

    num_reqs: int
    max_query_len: int
    max_seq_len: int

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor
    # The dimension of the attention heads
    head_dim: int

    # New for MLA (compared to FlashAttention)
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    decode: DeepSeekV32IndexerDecodeMetadata | None = None
    prefill: DeepseekV32IndexerPrefillMetadata | None = None


# TODO (zyongye) optimize this, this is now vibe coded
def kv_spans_from_batches(
    start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        start_seq_loc: 1D long tensor [B+1], cumulative counts of
            selected tokens per batch.
            Example: [0, 2, 4, 7] ->
                batch sizes (selected) [2, 2, 3], N=7 tokens total.
        seq_len_per_batch: 1D long tensor [B],
            full sequence length (KV length) of each batch.
            Example: [5, 9, 4].

    Returns:
        start_tensor: 1D long tensor [N], start offset in the
            concatenated KV cache for each token's batch.
        end_location: 1D long tensor [N],
            **exclusive** end = start + token's local position.
            (So the attended KV slice is kv[start:end].)

    Assumes each batch contributes its full `seq_len_per_batch[i]`
    keys to the KV cache, and the selected tokens within a batch
    are the **last** `counts[i]` positions of that sequence.
    """
    q = start_seq_loc.to(dtype=torch.long)
    L = seq_len_per_batch.to(dtype=torch.long)
    assert q.dim() == 1 and L.dim() == 1
    assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"

    # Selected tokens per batch and totals
    counts = q[1:] - q[:-1]  # [B]
    N = int(q[-1].item())  # total selected tokens
    B = L.numel()

    if N == 0:
        # NOTE(review): this early return yields int64 tensors while the
        # non-empty path returns int32 (via .int() below) — confirm callers
        # tolerate the dtype difference.
        return (
            torch.empty(0, dtype=torch.long, device=device),
            torch.empty(0, dtype=torch.long, device=device),
        )

    # KV start offsets per batch in the concatenated KV cache
    kv_starts_per_batch = torch.cumsum(L, dim=0) - L  # [B]

    # For each selected token, which batch does it belong to?
    batch_id = torch.repeat_interleave(torch.arange(B), counts)  # [N]

    # Map batch KV start to each token
    start_tensor = kv_starts_per_batch[batch_id]  # [N]

    # End-align local positions inside each batch:
    # local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b
    L_expand = torch.repeat_interleave(L, counts)  # [N]
    m_expand = torch.repeat_interleave(counts, counts)  # [N]
    # position within the selected block: 1..counts[b]
    pos_within = (
        torch.arange(N, dtype=torch.long) - torch.repeat_interleave(q[:-1], counts) + 1
    )

    local_pos = L_expand - m_expand + pos_within  # [N], 1-based
    end_location = start_tensor + local_pos  # exclusive end

    return start_tensor.int().to(device), end_location.int().to(device)


def get_max_prefill_buffer_size(vllm_config: VllmConfig):
    """Return the token capacity of the flattened-KV prefill buffer."""
    max_model_len = vllm_config.model_config.max_model_len
    # NOTE(Chen): 40 is a magic number for controlling the prefill buffer size.
    # Each entry is 128 fp8 bytes and 4 scale bytes for a total of 132 bytes.
    # The flashmla_sparse backend uses a workspace size of 5 * max_model_len.
    # The memory usage of the workspace there is 576 * 2 bytes; so we size this as
    # (576 * 2 // 132) * 5 = 40 to maximize this workspace size while still fitting
    # within the flashmla_sparse workspace.
    # For DeepSeek-V3.2, the max_model_len is 163840.
    # 40 * 163840 * 132 = 865075200 bytes = 825 MB
    return max_model_len * 40


class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
    """Builds per-step DeepseekV32IndexerMetadata from common attention metadata."""

    _cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
    )

    reorder_batch_threshold: int = 1

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        scheduler_config = self.vllm_config.scheduler_config
        # NOTE(Chen): an estimated max size of flattened_kv. Need to double check.
        self.max_prefill_buffer_size = get_max_prefill_buffer_size(self.vllm_config)
        self.num_speculative_tokens = (
            self.vllm_config.speculative_config.num_speculative_tokens
            if self.vllm_config.speculative_config
            else 0
        )
        # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2
        self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)

        props = torch.cuda.get_device_properties(self.device)
        sm_count = props.multi_processor_count
        self.num_sms = sm_count

        self.decode_lens_buffer = torch.empty(
            (scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device
        )

        # See: DeepGEMM/csrc/apis/attention.hpp
        self.scheduler_metadata_buffer = torch.empty(
            (self.num_sms + 1, 2), dtype=torch.int32, device=self.device
        )

    def build_one_prefill_chunk(
        self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table
    ):
        """Build metadata for the prefill requests in [reqs_start, reqs_end)."""
        prefill_query_start_loc = (
            query_start_loc_cpu[reqs_start : reqs_end + 1]
            - query_start_loc_cpu[reqs_start]
        )
        cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
            prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device
        )
        token_start = query_start_loc_cpu[reqs_start].item()
        token_end = query_start_loc_cpu[reqs_end].item()
        # NOTE(review): .sum() yields a 0-d tensor although the dataclass field
        # is annotated int; the <= comparison below still works — confirm
        # downstream consumers don't require a plain int.
        total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
        assert total_seq_lens <= self.max_prefill_buffer_size
        cu_seq_lens = (
            torch.cat(
                [
                    torch.zeros(1, dtype=torch.int32),
                    seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0),
                ]
            )
            .to(torch.int32)
            .to(self.device)
        )
        return DeepseekV32IndexerPrefillChunkMetadata(
            cu_seqlen_ks=cu_seqlen_ks,
            cu_seqlen_ke=cu_seqlen_ke,
            cu_seq_lens=cu_seq_lens,
            total_seq_lens=total_seq_lens,
            block_table=block_table[reqs_start:reqs_end],
            token_start=token_start,
            token_end=token_end,
            num_reqs=reqs_end - reqs_start,
        )

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> DeepseekV32IndexerMetadata:
        """Split the batch into decodes/prefills and build indexer metadata."""
        num_reqs = common_attn_metadata.num_reqs
        num_tokens = common_attn_metadata.num_actual_tokens

        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
            )
        )

        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_tokens

        prefill_metadata = None
        if num_prefills > 0:
            # Chunk prefills so each chunk's flattened KV fits in the buffer.
            chunk_seq_ids = split_prefill_chunks(
                common_attn_metadata.seq_lens_cpu[num_decodes:],
                self.max_prefill_buffer_size,
                request_offset=num_decodes,
            )
            chunks = [
                self.build_one_prefill_chunk(
                    reqs_start,
                    reqs_end,
                    query_start_loc_cpu,
                    common_attn_metadata.seq_lens_cpu,
                    common_attn_metadata.block_table_tensor,
                )
                for reqs_start, reqs_end in chunk_seq_ids
            ]
            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
                chunks=chunks,
            )

        decode_metadata = None
        if num_decodes > 0:
            torch.diff(
                common_attn_metadata.query_start_loc[: num_decodes + 1],
                out=self.decode_lens_buffer[:num_decodes],
            )
            decode_lens = self.decode_lens_buffer[:num_decodes]
            decode_lens_cpu = torch.diff(
                common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
            )

            # Use CPU to avoid GPU sync; breaking async scheduling
            requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()

            seq_lens = common_attn_metadata.seq_lens[:num_decodes]
            # NOTE(review): when deep_gemm is unsupported, the scheduler buffer
            # keeps stale/uninitialized contents but is still passed below —
            # presumably the consumer ignores it in that case; confirm.
            if is_deep_gemm_supported():
                self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
                    seq_lens, self.kv_cache_spec.block_size, self.num_sms
                )
            decode_metadata = DeepSeekV32IndexerDecodeMetadata(
                block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...],
                seq_lens=common_attn_metadata.seq_lens[:num_decodes],
                decode_lens=decode_lens,
                requires_padding=requires_padding,
                schedule_metadata=self.scheduler_metadata_buffer,
            )

        attn_metadata = DeepseekV32IndexerMetadata(
            seq_lens=common_attn_metadata.seq_lens,
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=common_attn_metadata.max_seq_len,
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            query_start_loc=common_attn_metadata.query_start_loc,
            slot_mapping=common_attn_metadata.slot_mapping,
            head_dim=128,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )

        # if get_tensor_model_parallel_rank() == 0:
        #     logger.info(f"attn_metadata: {attn_metadata}")
        return attn_metadata
diff --git a/vllm_fl/v1/attention/backends/mla/triton_mla.py b/vllm_fl/v1/attention/backends/mla/triton_mla.py
new file mode 100644
index 00000000..dbb977f4
--- /dev/null
+++ b/vllm_fl/v1/attention/backends/mla/triton_mla.py
@@ -0,0 +1,179 @@
# SPDX-License-Identifier: Apache-2.0
# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import ClassVar

import torch

from vllm.attention.backends.abstract import (
    AttentionLayer,
    AttentionType,
    is_quantized_kv_cache,
)

# ---------------------------------------------------------------
# Note: We need to use Maca's decode_attention due to
# the limit of shmem for triton kernel
# ---------------------------------------------------------------
from vllm_fl.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm_fl.v1.attention.backends.mla.common import (
    MLACommonBackend,
    MLACommonImpl,
    MLACommonMetadata,
)
from vllm.attention.backends.registry import AttentionBackendEnum, register_backend

logger = init_logger(__name__)


@register_backend(AttentionBackendEnum.TRITON_MLA)
class MacaTritonMLABackend(MLACommonBackend):
    """TRITON_MLA backend registration that routes to the MetaX implementation."""

    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]

    @staticmethod
    def get_name() -> str:
        return "TRITON_MLA"

    @staticmethod
    def get_impl_cls() -> type["TritonMLAImpl"]:
        return TritonMLAImpl

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        # Triton kernels do not gate on a specific compute capability here.
        return True


class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
    """MLA implementation whose decode path runs the Maca Triton MQA kernel."""

    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        **mla_args,
    ) -> None:
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )

        # Reject configurations this implementation cannot honor rather than
        # silently ignoring them.
        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "TritonMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap"
            )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "TritonMLAImpl"
            )

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "TritonMLA V1 with FP8 KV cache not yet supported"
            )

    def _flash_attn_varlen_diff_headdims(
        self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
    ):
        # Pure pass-through to the common implementation; kept as an explicit
        # override point for the Maca port.
        return super()._flash_attn_varlen_diff_headdims(
            q,
            k,
            v,
            return_softmax_lse=return_softmax_lse,
            softmax_scale=softmax_scale,
            **kwargs,
        )

    def _forward_decode(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Decode-phase MQA over the latent KV cache; returns (output, lse)."""
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 Triton MLA not yet supported")

        # q may arrive as (q_nope, q_pe); fuse into one tensor for the kernel.
        if type(q) is tuple:
            q = torch.cat(q, dim=-1)

        assert isinstance(q, torch.Tensor)
        B = q.shape[0]
        q_num_heads = q.shape[1]
        o = torch.zeros(
            B, q_num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device
        )
        lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)

        # For batch invariance, use only 1 split to ensure deterministic reduction
        num_kv_splits = 1 if vllm_is_batch_invariant() else 4

        # TODO(lucas) Allocate ahead of time
        attn_logits = torch.empty(
            (
                B,
                q_num_heads,
                num_kv_splits,
                # NOTE(lucas) idk why the +1 is here but sglang has it so we
                # just mirror that
                self.kv_lora_rank + 1,
            ),
            dtype=torch.float32,
            device=q.device,
        )

        # Add a head dim of 1
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
        kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
        PAGE_SIZE = kv_c_and_k_pe_cache.size(1)

        # Run MQA
        decode_attention_fwd(
            q,
            kv_c_and_k_pe_cache,
            kv_c_cache,
            o,
            lse,
            attn_metadata.decode.block_table,
            attn_metadata.decode.seq_lens,
            attn_logits,
            num_kv_splits,
            self.scale,
            PAGE_SIZE,
        )

        return o, lse
diff --git a/vllm_fl/v1/attention/backends/tree_attn.py b/vllm_fl/v1/attention/backends/tree_attn.py
new file mode 100644
index 00000000..0f6d2964
--- /dev/null
+++ b/vllm_fl/v1/attention/backends/tree_attn.py
@@ -0,0 +1,431 @@
# SPDX-License-Identifier: Apache-2.0
# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with TreeAttention."""

import ast
from dataclasses import dataclass
from typing import ClassVar, Optional

import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionImpl,
    AttentionType,
    MultipleOf,
)
from vllm_fl.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
    split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.attention.backends.registry import AttentionBackendEnum, register_backend

logger = init_logger(__name__)


@register_backend(AttentionBackendEnum.TREE_ATTN)
class MacaTreeAttentionBackend(AttentionBackend):
    """Tree-attention backend (speculative-decoding token trees) for MetaX."""

    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [MultipleOf(16)]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        return "TREE_ATTN"

    @staticmethod
    def get_impl_cls() -> type["TreeAttentionImpl"]:
        return TreeAttentionImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        # Leading dim of 2 stores the key and value planes of the cache.
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]:
        return TreeAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False


@dataclass
class TreeAttentionMetadata:
    """Per-step metadata; prefill/decode views are sliced lazily and cached."""

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0

    tree_attn_bias: torch.Tensor | None = None

    # Cached Prefill/decode metadata.
    _cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None
    _cached_decode_metadata: Optional["TreeAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]:
        """View of this metadata restricted to the prefill requests."""
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            # Recover cached prefill-phase attention
            # metadata structure
            return self._cached_prefill_metadata

        # Decodes are ordered first in the batch; prefills follow.
        q_start_loc = self.query_start_loc[self.num_decodes :]
        q_seqlens = torch.diff(q_start_loc)
        kv_seqlens = self.seq_lens[self.num_decodes :]
        # Construct & cache prefill-phase attention metadata structure
        self._cached_prefill_metadata = TreeAttentionMetadata(
            num_actual_tokens=self.num_prefill_tokens,
            max_query_len=int(q_seqlens.max().item()),
            query_start_loc=q_start_loc - q_start_loc[0],
            max_seq_len=int(kv_seqlens.max().item()),
            seq_lens=kv_seqlens,
            block_table=self.block_table[self.num_decodes :],
            slot_mapping=self.slot_mapping[self.num_decode_tokens :],
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["TreeAttentionMetadata"]:
        """View of this metadata restricted to the decode requests."""
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            # Recover cached decode-phase attention
            # metadata structure
            return self._cached_decode_metadata

        q_start_loc = self.query_start_loc[: self.num_decodes + 1]
        q_seqlens = torch.diff(q_start_loc)
        kv_seqlens = self.seq_lens[: self.num_decodes]
        # Construct & cache decode-phase attention metadata structure
        self._cached_decode_metadata = TreeAttentionMetadata(
            num_actual_tokens=self.num_decode_tokens,
            max_query_len=int(q_seqlens.max().item()),
            query_start_loc=q_start_loc,
            max_seq_len=int(kv_seqlens.max().item()),
            seq_lens=kv_seqlens,
            block_table=self.block_table[: self.num_decodes],
            slot_mapping=self.slot_mapping[: self.num_decode_tokens],
            tree_attn_bias=self.tree_attn_bias,
        )
        return self._cached_decode_metadata


class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadata]):
    """Builds TreeAttentionMetadata and owns the precomputed tree bias mask."""

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.block_size = kv_cache_spec.block_size

        spec_config = vllm_config.speculative_config
        spec_token_tree = (spec := spec_config) and spec.speculative_token_tree
        # Default to a single-node tree when no speculative tree is configured.
        tree_choices: list[tuple[int, ...]] = (
            ast.literal_eval(spec_token_tree) if spec_token_tree is not None else [(0,)]
        )
        # Construct the tree attention bias.
        depth_counts = _get_depth_counts(tree_choices)
        self.tree_attn_bias = _prepare_tree_attn_bias(
            tree_choices,
            depth_counts,
            dtype=torch.float32,
            device=device,
        )

        self.reorder_batch_threshold = self.tree_attn_bias.shape[0]

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> TreeAttentionMetadata:
        """Split the batch into decodes/prefills and assemble metadata."""
        decode_threshold = self.tree_attn_bias.shape[0]
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata, decode_threshold=decode_threshold
            )
        )

        num_actual_tokens = common_attn_metadata.num_actual_tokens
        q_start_loc = common_attn_metadata.query_start_loc
        max_query_len = common_attn_metadata.max_query_len
        kv_seqlens = common_attn_metadata.seq_lens
        max_seq_len = common_attn_metadata.max_seq_len
        block_table = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        return TreeAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            max_query_len=max_query_len,
            query_start_loc=q_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=kv_seqlens,
            block_table=block_table,
            slot_mapping=slot_mapping,
            tree_attn_bias=self.tree_attn_bias,
        )

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> TreeAttentionMetadata:
        """Build metadata for a drafting pass, temporarily adjusting the bias."""
        # Cache the original tree attention bias.
        orig_tree_attn_bias = self.tree_attn_bias

        if draft_index == 0:
            # Use prefill for drafting at the root level.
            self.tree_attn_bias = torch.empty(0)
        else:
            # Slice the tree attention bias for drafting. Exclude
            # the root level.
            start, end = 1, 1 + common_attn_metadata.max_query_len
            self.tree_attn_bias = self.tree_attn_bias[start:end, start:end].contiguous()

        # Build attention bias.
        attn_metadata = self.build(0, common_attn_metadata, fast_build=True)

        # Reset the tree attention bias to the original value.
        self.tree_attn_bias = orig_tree_attn_bias
        return attn_metadata


def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]:
    # Count the number of choices at each depth of the tree.
    depth_counts = []
    prev_depth = 0
    for path in sorted_tree_choices:
        depth = len(path)
        if depth != prev_depth:
            depth_counts.append(0)
        depth_counts[depth - 1] += 1
        prev_depth = depth
    return depth_counts


def _prepare_tree_attn_bias(
    sorted_tree_choices: list[tuple[int, ...]],
    depth_counts: list[int],
    dtype: torch.dtype | None,
    device: torch.device | None,
) -> torch.Tensor:
    """Build a (tree_len, tree_len) additive mask: 0 = visible, -inf = masked."""
    # +1 comes from the additional root node.
    tree_len = len(sorted_tree_choices) + 1
    tree_attn_mask = torch.full(
        (tree_len, tree_len), -torch.inf, device=device, dtype=dtype
    )

    # Set diagonal to all zeros. Each token should
    # attend to itself.
    mask_val = 0
    for i in range(tree_len):
        tree_attn_mask[i, i] = mask_val

    # Set root to all zeros. All tokens attend to it.
    tree_attn_mask[:, 0] = mask_val

    # Set all ancestors to zeros.
    start = 0
    for i in range(len(depth_counts)):
        for j in range(depth_counts[i]):
            cur_tree_choice = sorted_tree_choices[start + j]
            # Retrieve ancestor position.
            if len(cur_tree_choice) == 1:
                continue
            ancestor_idx = []
            for c in range(len(cur_tree_choice) - 1):
                ancestor_idx.append(
                    sorted_tree_choices.index(cur_tree_choice[: c + 1]) + 1
                )
            tree_attn_mask[j + start + 1, ancestor_idx] = mask_val
        start += depth_counts[i]
    return tree_attn_mask


class TreeAttentionImpl(AttentionImpl):
    """Decoder-only attention that applies a tree bias on the decode path."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if logits_soft_cap is None:
            # Setting logits_soft_cap to 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "TreeAttentionImpl."
            )

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TreeAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with TreeAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for TreeAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        # Cache the input KVs.
        key_cache, value_cache = kv_cache.unbind(0)
        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        num_actual_tokens = attn_metadata.num_actual_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, key.shape[1])
        if prefill_meta := attn_metadata.prefill_metadata:
            # Prefill tokens: standard causal attention, no tree bias.
            unified_attention(
                q=query[num_decode_tokens:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[num_decode_tokens:num_actual_tokens],
                cu_seqlens_q=prefill_meta.query_start_loc,
                max_seqlen_q=prefill_meta.max_query_len,
                seqused_k=prefill_meta.seq_lens,
                max_seqlen_k=prefill_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=prefill_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )

        if decode_meta := attn_metadata.decode_metadata:
            # Decode tokens: additionally apply the precomputed tree bias
            # between the speculative query tokens (qq_bias).
            unified_attention(
                q=query[:num_decode_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_decode_tokens],
                cu_seqlens_q=decode_meta.query_start_loc,
                max_seqlen_q=decode_meta.max_query_len,
                seqused_k=decode_meta.seq_lens,
                max_seqlen_k=decode_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                qq_bias=decode_meta.tree_attn_bias,
                window_size=self.sliding_window,
                block_table=decode_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )
        return output
diff --git a/vllm_fl/v1/attention/backends/triton_attn.py b/vllm_fl/v1/attention/backends/triton_attn.py
new file mode 100644
index 00000000..1c5aca52
--- /dev/null
+++ b/vllm_fl/v1/attention/backends/triton_attn.py
@@ -0,0 +1,500 @@
# SPDX-License-Identifier: Apache-2.0
# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""High-Performance Triton-only Attention layer."""

from dataclasses import dataclass
from typing import ClassVar

import torch

from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionImpl,
    AttentionType,
    MultipleOf,
)
from vllm.attention.ops.triton_reshape_and_cache_flash import (
    triton_reshape_and_cache_flash,
)
from vllm_fl.attention.ops.triton_unified_attention import unified_attention
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey,
    kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import next_power_of_2
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.attention.backends.registry import AttentionBackendEnum, register_backend

logger = init_logger(__name__)


# constants
MIN_LAUNCH_GRID_SIZE_2D = 128  # Minimum launch grid size of 2D kernel
NUM_PAR_SOFTMAX_SEGMENTS = 16  # Number of parallel tiled softmax segments


@dataclass
class TritonAttentionMetadata:
    """Per-step metadata consumed by the Triton unified attention kernels."""

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # 2D-vs-3D kernel selection threshold and tiled-softmax scratch buffers
    # (owned by the builder, shared across steps).
    seq_threshold_3D: int
    num_par_softmax_segments: int
    softmax_segm_output: torch.Tensor
    softmax_segm_max: torch.Tensor
    softmax_segm_expsum: torch.Tensor

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: torch.Tensor | None
    prefix_kv_lens: torch.Tensor | None
    suffix_kv_lens: torch.Tensor | None

    # Optional aot scheduling
    scheduler_metadata: torch.Tensor | None = None
    prefix_scheduler_metadata: torch.Tensor | None = None
    mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None

    @property
    def mm_prefix_range_tensor(self) -> torch.Tensor | None:
        """Convert mm_prefix_range dict to padded tensor for Triton kernel.

        Returns shape: (num_seqs, max_ranges, 2) with 0-padding for empty ranges.
        Empty ranges have start==end==0, which kernel skips via is_valid check.
        """
        # TODO(Isotr0py): Move to model runner's attention metadata
        # preparation to avoid duplicate computation.
        if self.mm_prefix_range is None:
            return None

        num_seqs = self.seq_lens.shape[0]
        device = self.seq_lens.device

        # Collect ranges, using [(0,0)] for empty sequences to ensure uniform dims
        range_lists = [
            self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)
        ]

        # Return None if all ranges are trivial (only (0,0) placeholders)
        if all(r == [(0, 0)] for r in range_lists):
            return None

        # Create 2D tensors with shape (num_ranges, 2) for each sequence
        range_tensors = [
            torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2)
            for r in range_lists
        ]

        return torch.nested.nested_tensor(range_tensors).to_padded_tensor(0)


class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
    """Builds TritonAttentionMetadata and owns the tiled-softmax scratch buffers."""

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.block_size = kv_cache_spec.block_size

        model_config = vllm_config.model_config
        self.num_heads_q = model_config.get_num_attention_heads(
            vllm_config.parallel_config
        )
        self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
        self.headdim = model_config.get_head_size()

        # Check if CUDA Graphs are enabled for decode
        self.decode_cudagraph_enabled = (
            self.vllm_config.compilation_config.cudagraph_mode
            in (
                CUDAGraphMode.FULL_AND_PIECEWISE,
                CUDAGraphMode.FULL_DECODE_ONLY,
                CUDAGraphMode.FULL,
            )
        )

        # The launch grid for the 2D kernel is defined as (num_q_blocks, num_heads_kv).
        # A lower bound for num_q_blocks is the number of sequences.
        # To ensure the minimum launch grid size is achieved, the number of sequences
        # must be at least equal to the threshold below.
        # If this threshold is not reached (i.e., the batch size is not large enough),
        # the 3D kernel will be selected instead.
        self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv

        # Modify the threshold if needed.
        if self.decode_cudagraph_enabled:
            capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
            assert capture_sizes, "CUDA Graphs enabled but no capture sizes specified."

            # Select the CUDA Graph capture size closest to self.seq_threshold_3D
            # as threshold. This ensures that each captured graph covers the
            # correct execution path.
            self.seq_threshold_3D = min(
                capture_sizes,
                key=lambda x: abs(x - self.seq_threshold_3D),
            )

        self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
        headdim_padded = next_power_of_2(self.headdim)
        self.softmax_segm_output = torch.empty(
            (
                self.seq_threshold_3D,
                self.num_heads_q,
                self.num_par_softmax_segments,
                headdim_padded,
            ),
            dtype=torch.float32,
            device=device,
        )
        self.softmax_segm_max = torch.empty(
            (self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
            dtype=torch.float32,
            device=device,
        )
        self.softmax_segm_expsum = torch.empty(
            (self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
            dtype=torch.float32,
            device=device,
        )

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> TritonAttentionMetadata:
        """Build metadata suitable for CUDA-graph capture (dummy seq lens)."""
        attn_metadata = self.build(0, common_attn_metadata)
        # When doing full graph capture, setting seq_lens to
        # max_model_len will cause graph capture to be extremely
        # slow, so here we set it to 1.
        attn_metadata.seq_lens.fill_(1)
        return attn_metadata

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> TritonAttentionMetadata:
        """Assemble per-step metadata; cascade fields are set when
        common_prefix_len > 0."""
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        use_cascade = common_prefix_len > 0

        # BUGFIX: previously prefix_scheduler_metadata was only assigned in the
        # non-cascade branch, so the constructor call below raised
        # UnboundLocalError whenever use_cascade was True. Initialize it
        # unconditionally (no AOT scheduling is performed here either way).
        prefix_scheduler_metadata = None

        if use_cascade:
            cu_prefix_query_lens = torch.tensor(
                [0, num_actual_tokens], dtype=torch.int32, device=self.device
            )
            prefix_kv_lens = torch.tensor(
                [common_prefix_len], dtype=torch.int32, device=self.device
            )
            suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len
            suffix_kv_lens = suffix_kv_lens.to(self.device)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None

        attn_metadata = TritonAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            seq_threshold_3D=self.seq_threshold_3D,
            num_par_softmax_segments=self.num_par_softmax_segments,
            softmax_segm_output=self.softmax_segm_output,
            softmax_segm_max=self.softmax_segm_max,
            softmax_segm_expsum=self.softmax_segm_expsum,
        )
        return attn_metadata


@register_backend(AttentionBackendEnum.TRITON_ATTN)
class MacaTritonAttentionBackend(AttentionBackend):
    accept_output_buffer: bool
= True + supported_dtypes: ClassVar[list[torch.dtype]] = [ + torch.float16, + torch.bfloat16, + torch.float32, + ] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + "fp8_e5m2", + ] + + @staticmethod + def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: + return [MultipleOf(16)] + + @staticmethod + def get_name() -> str: + return "TRITON_ATTN" + + @staticmethod + def get_impl_cls() -> type["TritonAttentionImpl"]: + return TritonAttentionImpl + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (num_blocks, 2, block_size, num_kv_heads, head_size) + + @staticmethod + def use_cascade_attention(*args, **kwargs) -> bool: + return False + + @staticmethod + def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]: + return TritonAttentionMetadataBuilder + + @classmethod + def supports_head_size(cls, head_size: int) -> bool: + return head_size >= 32 + + @classmethod + def supports_mm_prefix(cls) -> bool: + return True + + @classmethod + def supports_sink(cls) -> bool: + return True + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return True + + +class TritonAttentionImpl(AttentionImpl): + def fused_output_quant_supported(self, quant_key: QuantKey): + return quant_key == kFp8StaticTensorSym + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: list[float] | None, + sliding_window: int | None, + kv_cache_dtype: str, + logits_soft_cap: float | None = None, + attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: int | None = None, + sinks: torch.Tensor | None = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = 
float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) + self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]: + raise NotImplementedError( + "Encoder self-attention is not implemented for TritonAttentionImpl" + ) + self.attn_type = attn_type + self.fp8_dtype = current_platform.fp8_dtype() + + self.sinks = sinks + if sinks is not None: + assert sinks.shape[0] == num_heads, ( + "Sinks must have the same number of heads as the number of " + f"heads in the layer. Sinks shape: {sinks.shape}, " + f"num_heads: {num_heads}." + ) + + self.supports_quant_query_input = False + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: TritonAttentionMetadata, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward pass with Paged Attention impl. in Triton. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache: shape = + [num_blocks, 2, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." 
+ + if output_block_scale is not None: + raise NotImplementedError( + "fused block_scale output quantization is not yet supported" + " for TritonAttentionImpl" + ) + + if attn_metadata is None: + # Profiling run. + return output.fill_(0) + + assert attn_metadata.use_cascade is False + + # IMPORTANT! + # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + num_actual_tokens = attn_metadata.num_actual_tokens + key_cache, value_cache = kv_cache.unbind(1) + + if ( + self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + if self.kv_cache_dtype.startswith("fp8"): + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + # triton kernel does not support uint8 kv_cache + # (because some explicit casts (e.g. float8_e4m3fnuz) + # are not supported) + triton_reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if self.kv_cache_dtype.startswith("fp8"): + if key_cache.dtype != self.fp8_dtype: + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + assert layer._q_scale_float == 1.0, ( + "A non 1.0 q_scale is not currently supported." 
+ ) + + cu_seqlens_q = attn_metadata.query_start_loc + seqused_k = attn_metadata.seq_lens + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table + + seq_threshold_3D = attn_metadata.seq_threshold_3D + num_par_softmax_segments = attn_metadata.num_par_softmax_segments + softmax_segm_output = attn_metadata.softmax_segm_output + softmax_segm_max = attn_metadata.softmax_segm_max + softmax_segm_expsum = attn_metadata.softmax_segm_expsum + + descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2]) + mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor + + unified_attention( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + q_descale=None, # Not supported + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + seq_threshold_3D=seq_threshold_3D, + num_par_softmax_segments=num_par_softmax_segments, + softmax_segm_output=softmax_segm_output, + softmax_segm_max=softmax_segm_max, + softmax_segm_expsum=softmax_segm_expsum, + sinks=self.sinks, + output_scale=output_scale, + mm_prefix_range=mm_prefix_range_tensor, + ) + + return output