155 changes: 142 additions & 13 deletions vllm_fl/__init__.py
@@ -11,6 +11,106 @@
logger = logging.getLogger(__name__)


def _override_flashmla_sparse_backend():
"""强制覆盖 FLASHMLA_SPARSE backend 指向我们的自定义实现"""
try:
from vllm.attention.backends.registry import register_backend, AttentionBackendEnum

# Re-register FLASHMLA_SPARSE, overriding vLLM's default native implementation
register_backend(
AttentionBackendEnum.FLASHMLA_SPARSE,
class_path="vllm_fl.v1.attention.backends.mla.flashmla_sparse.MacaFlashMLASparseBackend",
)

# Also register FLASHMLA
register_backend(
AttentionBackendEnum.FLASHMLA,
class_path="vllm_fl.v1.attention.backends.mla.flashmla.MacaFlashMLABackend",
)

print("[vllm_fl] Successfully overridden FLASHMLA_SPARSE backend to use vllm_fl implementation")
except Exception as e:
print(f"[vllm_fl] Warning: Failed to override backend: {e}")

# Run immediately at module import time
_override_flashmla_sparse_backend()

########### platform plugin ###########
def register():
"""Register the FL platform."""
_patch_transformers_compat()

# Model-specific platform patches
from vllm_fl.patches.glm_moe_dsa import apply_platform_patches as glm5_platform
glm5_platform()

multiproc_method = os.environ.get("VLLM_WORKER_MULTIPROC_METHOD")
if multiproc_method is None:
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
_get_op_config()

# Check if we're running on MetaX (Maca) platform
# If so, also register MetaX-specific components
if _is_metax_platform():
logger.info("MetaX platform detected, registering MetaX-specific components")
_register_metax_components()

return "vllm_fl.device_platform.PlatformFL"


def _is_metax_platform() -> bool:
"""Detect if running on MetaX (Maca) platform."""
try:
# Check via vendor name from device info
from vllm_fl.utils import DeviceInfo
device_info = DeviceInfo()
return device_info.vendor_name == "metax"
except Exception:
# Fallback: check environment variable or device properties
try:
import torch
device_name = torch.cuda.get_device_name(0).lower()
return "metax" in device_name or "maca" in device_name
except Exception:
return False


def _register_metax_components():
"""
Register MetaX-specific components from vllm_fl.
This ensures compatibility with vllm_fl functionality.
"""
try:
# Import and call vllm_fl register functions
import vllm_fl

# Register MetaX ops (includes patches)
try:
vllm_fl.register_ops()
logger.info("Registered MetaX ops")
except Exception as e:
logger.warning(f"Failed to register MetaX ops: {e}")

# Register MetaX models
try:
vllm_fl.register_model()
logger.info("Registered MetaX models")
except Exception as e:
logger.warning(f"Failed to register MetaX models: {e}")

# Register MetaX quantization configs
try:
vllm_fl.register_quant_configs()
logger.info("Registered MetaX quantization configs")
except Exception as e:
logger.warning(f"Failed to register MetaX quant configs: {e}")

except ImportError:
logger.debug("vllm_fl not available, skipping MetaX component registration")
except Exception as e:
logger.warning(f"Error registering MetaX components: {e}")


def __getattr__(name):
if name == "distributed":
import importlib
@@ -29,19 +29,17 @@ def _patch_transformers_compat():
)


def register():
"""Register the FL platform."""
_patch_transformers_compat()

# Model-specific platform patches
from vllm_fl.patches.glm_moe_dsa import apply_platform_patches as glm5_platform
glm5_platform()

multiproc_method = os.environ.get("VLLM_WORKER_MULTIPROC_METHOD")
if multiproc_method is None:
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
_get_op_config()
return "vllm_fl.platform.PlatformFL"
def register_ops():
"""Register FL ops."""
import vllm_fl.ops # noqa: F401

# Also register MetaX ops if on MetaX platform
if _is_metax_platform():
try:
import vllm_fl
vllm_fl.register_ops()
except Exception as e:
logger.debug(f"Could not register MetaX ops: {e}")


def register_model():
@@ -116,3 +214,34 @@ def register_model():
)
except Exception as e:
logger.error(f"Register GlmMoeDsa model error: {str(e)}")

# Also register MetaX models if on MetaX platform
if _is_metax_platform():
try:
import vllm_fl
vllm_fl.register_model()
except Exception as e:
logger.debug(f"Could not register MetaX models: {e}")


def register_quant_configs():
"""Register quantization configs."""
# FL-specific quant configs (if any) can be added here

# Also register MetaX quant configs if on MetaX platform
if _is_metax_platform():
try:
import vllm_fl
vllm_fl.register_quant_configs()
except Exception as e:
logger.debug(f"Could not register MetaX quant configs: {e}")


# Backward compatibility: collect_env function
def collect_env() -> None:
"""Collect environment information."""
try:
from vllm_fl.collect_env import main as collect_env_main
collect_env_main()
except ImportError:
logger.debug("vllm_fl.collect_env not available")
174 changes: 174 additions & 0 deletions vllm_fl/_custom_ops.py
@@ -0,0 +1,174 @@
# SPDX-License-Identifier: Apache-2.0
# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# ---------------------------------------------------
# Note:
#
# Here we only maintain the custom ops that are:
#
# - modified
# - newly added
#
# in vllm_metax compared to vllm.
#
# When *adding* new custom ops, make sure to check the
# latest vllm/_custom_ops.py first to avoid adding duplicates.
# ---------------------------------------------------
import torch
import vllm.envs as envs


def awq_gemm(
input: torch.Tensor,
qweight: torch.Tensor,
qzeros: torch.Tensor,
scales: torch.Tensor,
split_k_iters: int,
temp_space: torch.Tensor,
dtype_bf16: bool,
) -> torch.Tensor:
if envs.VLLM_USE_TRITON_AWQ:
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton

return awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters)
return torch.ops._C.awq_gemm(
input, qweight, scales, qzeros, split_k_iters, temp_space, dtype_bf16
)


# awq to gptq 4bit conversion
def awq_to_gptq_4bit(qweight: torch.Tensor) -> torch.Tensor:
if envs.VLLM_USE_TRITON_AWQ:
return qweight
return torch.ops._C.awq_to_gptq_4bit(qweight)


# gptq
def gptq_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor,
b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor,
use_exllama: bool,
bit: int,
group_size: int,
perm_space: torch.Tensor,
temp_space: torch.Tensor,
dtype_bf16: bool,
) -> torch.Tensor:
return torch.ops._C.gptq_gemm(
a,
b_q_weight,
b_gptq_qzeros,
b_gptq_scales,
b_g_idx,
use_exllama,
bit,
group_size,
perm_space,
temp_space,
dtype_bf16,
)


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)


def fused_moe_kernel(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
tileConfig: int,
) -> None:
torch.ops._moe_C.fused_moe_kernel(
A,
B,
C,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
mul_routed_weight,
top_k,
tileConfig,
)


# dsv32
def indexer_k_quant_and_cache(
k: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
quant_block_size: int,
kv_cache_dtype: str,
) -> None:
if k.dtype in (torch.bfloat16, torch.float16):
torch.ops._C_cache_ops.indexer_k_cache(k, kv_cache, slot_mapping)
else:
torch.ops._C_cache_ops.indexer_k_quant_and_cache(
k, kv_cache, slot_mapping, quant_block_size, kv_cache_dtype
)


def cp_gather_indexer_k_quant_cache(
kv_cache: torch.Tensor,
dst_k: torch.Tensor,
dst_scale: torch.Tensor,
block_table: torch.Tensor,
cu_seq_lens: torch.Tensor,
) -> None:
if dst_k.dtype in (torch.bfloat16, torch.float16) or dst_scale is None:
torch.ops._C_cache_ops.cp_gather_indexer_k_cache(
kv_cache, dst_k, block_table, cu_seq_lens
)
else:
torch.ops._C_cache_ops.cp_gather_indexer_k_quant_cache(
kv_cache, dst_k, dst_scale, block_table, cu_seq_lens
)


def top_k_per_row(
logits: torch.Tensor,
row_starts: torch.Tensor,
row_ends: torch.Tensor,
topk_indices: torch.Tensor,
num_rows: int,
) -> None:
torch.ops._C.top_k_per_row(
logits,
row_starts,
row_ends,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
)


def top_k_per_row_decode(
logits: torch.Tensor,
next_n: int,
seq_lens: torch.Tensor,
topk_indices: torch.Tensor,
num_rows: int,
) -> None:
torch.ops._C.top_k_per_row_decode(
logits,
next_n,
seq_lens,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
)
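
Per the note at the top of this file, these functions replace or extend their counterparts in vllm/_custom_ops.py. A minimal sketch of how such overrides are commonly applied; whether vllm_fl wires them exactly this way is an assumption and is not shown in this diff:

import vllm._custom_ops as vllm_ops

import vllm_fl._custom_ops as fl_ops

# Rebind the upstream symbols to the MetaX implementations so that callers
# importing from vllm._custom_ops transparently pick up the overrides; names
# that are new in vllm_fl are simply attached alongside the existing ones.
for _name in (
    "awq_gemm",
    "gptq_gemm",
    "gptq_shuffle",
    "indexer_k_quant_and_cache",
    "cp_gather_indexer_k_quant_cache",
    "top_k_per_row",
):
    setattr(vllm_ops, _name, getattr(fl_ops, _name))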
37 changes: 37 additions & 0 deletions vllm_fl/attention/fl_utils.py
@@ -0,0 +1,37 @@
from vllm.logger import init_logger

logger = init_logger(__name__)


def patch_mm_encoder_attention():
"""
Patch vllm.attention.layers.mm_encoder_attention.maybe_get_vit_flash_attn_backend
to support OOT platforms.

The original implementation imports flash_attn_varlen_func from fa_utils,
which may not have it defined for OOT platforms. This patch changes the
FLASH_ATTN branch to import directly from vllm.vllm_flash_attn with a
fallback to flash_attn.
"""
import vllm.attention.layers.mm_encoder_attention as mm_mod
from vllm.attention.backends.registry import AttentionBackendEnum

def _patched_maybe_get_vit_flash_attn_backend(attn_backend):
if attn_backend == AttentionBackendEnum.FLASH_ATTN:
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func

logger.info_once("Using vllm.vllm_flash_attn for vit attention")
except (ImportError, ModuleNotFoundError):
from flash_attn import flash_attn_varlen_func

logger.info_once("Using flash_attn for vit attention")
return flash_attn_varlen_func
elif attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func

return flash_attn_varlen_func
else:
return None

mm_mod.maybe_get_vit_flash_attn_backend = _patched_maybe_get_vit_flash_attn_backend
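
A minimal usage sketch, assuming the patch is applied during plugin registration (the exact call site is not shown in this diff); it has to run before any multimodal encoder layer resolves its ViT flash-attention function:

from vllm_fl.attention.fl_utils import patch_mm_encoder_attention

# Apply early, e.g. from register_ops(), so that later lookups through
# vllm.attention.layers.mm_encoder_attention see the patched resolver.
patch_mm_encoder_attention()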