From f4756409744bf817bb4702b7fafe5ff08cb48019 Mon Sep 17 00:00:00 2001 From: mslv Date: Wed, 31 Dec 2025 20:36:38 +0800 Subject: [PATCH 01/34] update 0.13.0 --- vllm_fl/__init__.py | 4 - vllm_fl/attention/attention.py | 347 +- vllm_fl/attention/custom_attention.py | 11 + vllm_fl/compilation/graph.py | 31 +- .../device_communicators/flagcx.py | 2 +- vllm_fl/ops/_fl_ops.py | 11 +- vllm_fl/ops/fused_moe/fused_moe.py | 10 +- vllm_fl/ops/fused_moe/layer.py | 208 +- vllm_fl/ops/fused_moe/moe_align_block_size.py | 18 +- vllm_fl/platform.py | 184 +- vllm_fl/worker/model_runner.py | 4603 +++++++++++------ vllm_fl/worker/worker.py | 514 +- 12 files changed, 3861 insertions(+), 2082 deletions(-) create mode 100644 vllm_fl/attention/custom_attention.py diff --git a/vllm_fl/__init__.py b/vllm_fl/__init__.py index fed86075..b185b552 100644 --- a/vllm_fl/__init__.py +++ b/vllm_fl/__init__.py @@ -5,7 +5,3 @@ def register(): return "vllm_fl.platform.PlatformFL" - -# def register_connector(): -# from vllm_ascend.distributed import register_connector -# register_connector() diff --git a/vllm_fl/attention/attention.py b/vllm_fl/attention/attention.py index 4693ad48..e00a51b8 100644 --- a/vllm_fl/attention/attention.py +++ b/vllm_fl/attention/attention.py @@ -5,7 +5,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional +from typing import ClassVar import numpy as np import torch @@ -15,48 +15,52 @@ AttentionMetadata, AttentionType, is_quantized_kv_cache) from vllm.attention.layer import Attention +from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states +from vllm.attention.backends.registry import ( + AttentionBackendEnum, + register_backend, +) from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config.cache import CacheDType +from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import 
init_logger -from vllm.utils import cdiv +from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant +from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.utils import (AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + get_dcp_local_seq_lens, get_kv_cache_layout) from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.platforms.interface import DeviceCapability from flag_gems import flash_attn_varlen_func, reshape_and_cache_flash logger = init_logger(__name__) + class AttentionFLBackend(AttentionBackend): accept_output_buffer: bool = True - supports_quant_query_input: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. 
" - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + def supports_head_size(cls, head_size: int) -> list[int]: + return head_size % 8 == 0 and head_size <= 256 @staticmethod def get_name() -> str: return "FL" + @classmethod + def supports_attn_type(cls, attn_type: str) -> bool: + return attn_type in ( + AttentionType.DECODER, + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, + AttentionType.ENCODER_DECODER, + ) @staticmethod def get_impl_cls() -> type["AttentionFLImpl"]: return AttentionFLImpl @@ -68,7 +72,18 @@ def get_metadata_cls() -> type["AttentionMetadata"]: @staticmethod def get_builder_cls() -> type["AttentionFLMetadataBuilder"]: return AttentionFLMetadataBuilder + + @classmethod + def supports_sink(cls) -> bool: + return False + ### TODO(lms): support int8/int4 kv cache + @classmethod + def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: + if kv_cache_dtype is None: + return True + return kv_cache_dtype in ["auto"] + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -82,19 +97,41 @@ def get_kv_cache_shape( return (2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod - def get_kv_cache_stride_order() -> tuple[int, ...]: + def get_kv_cache_stride_order( + include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: # `stride_order` indicates the permutation that gets # us from `get_kv_cache_shape` to the actual memory layout we want. 
cache_layout = get_kv_cache_layout() - if cache_layout == "NHD": + if cache_layout == "NHD" and include_num_layers_dimension: + # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size) + return (2, 0, 1, 3, 4, 5) + elif cache_layout == "NHD": stride_order = (0, 1, 2, 3, 4) + elif cache_layout == "HND" and include_num_layers_dimension: + # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size) + return (2, 4, 0, 1, 3, 5) elif cache_layout == "HND": stride_order = (0, 1, 3, 2, 4) else: raise ValueError(f"Unknown cache layout format {cache_layout}.") return stride_order - ### TODO(lms): support int8/int4 kv cache + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: CacheDType | None, + block_size: int, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: DeviceCapability, + ) -> str | None: + if has_sink: + return "not support sink" + return None @dataclass class AttentionFLMetadata: @@ -121,6 +158,10 @@ class AttentionFLMetadata: prefix_kv_lens: torch.Tensor | None suffix_kv_lens: torch.Tensor | None + # For GQA DCP + max_dcp_context_kv_len: int | None = None + dcp_context_kv_lens: torch.Tensor | None = None + # Optional aot scheduling scheduler_metadata: torch.Tensor | None = None prefix_scheduler_metadata: torch.Tensor | None = None @@ -161,7 +202,7 @@ class AttentionFLMetadataBuilder(AttentionMetadataBuilder[AttentionFLMetadata]): # to FULL_AND_PIECEWISE. # TODO(luka, lucas): audit FA2 as part of: # https://github.com/vllm-project/vllm/issues/22945 - cudagraph_support = AttentionCGSupport.UNIFORM_BATCH + _cudagraph_support = AttentionCGSupport.UNIFORM_BATCH def __init__( self, @@ -187,19 +228,26 @@ def __init__( self.max_num_splits = 0 # No upper bound on the number of splits. 
self.aot_schedule = False #get_flash_attn_version() == 3 + try: + from vllm.distributed.parallel_state import get_dcp_group + + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + + self.cp_kv_cache_interleave_size = ( + self.parallel_config.cp_kv_cache_interleave_size + ) + self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() ) - self.max_cudagraph_size = self.compilation_config.max_capture_size + self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size if self.use_full_cuda_graph and self.aot_schedule: - if self.max_cudagraph_size > 992: - # This condition derives from FA3's internal heuristic. - # TODO(woosuk): Support larger cudagraph sizes. - raise ValueError( - "Capture size larger than 992 is not supported for full cuda graph." - ) - self.scheduler_metadata = torch.zeros( vllm_config.scheduler_config.max_num_seqs + 1, dtype=torch.int32, @@ -208,7 +256,10 @@ def __init__( # When using cuda graph, we need to set the upper bound of the # number of splits so that large enough intermediate buffers are # pre-allocated during capture. - self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + self.max_num_splits = ( + self.attention_config.flash_attn_max_num_splits_for_cuda_graph + ) + assert self.max_num_splits == 0, "FlagOS only support num_splits is 0 now" # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. 
@@ -230,7 +281,6 @@ def build( max_seq_len = common_attn_metadata.max_seq_len query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens - seq_lens_cpu = common_attn_metadata.seq_lens_cpu block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping causal = common_attn_metadata.causal @@ -263,44 +313,51 @@ def build( max_num_splits = self.max_num_splits use_cascade = common_prefix_len > 0 - - if use_cascade: - cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device=self.device) - prefix_kv_lens = torch.tensor([common_prefix_len], - dtype=torch.int32, - device=self.device) - suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to( - self.device, non_blocking=True) + use_cascade = common_prefix_len > 0 + max_dcp_context_kv_len = 0 + dcp_context_kv_lens = None + + cu_prefix_query_lens = None + prefix_kv_lens = None + suffix_kv_lens = None + prefix_scheduler_metadata = None + + if self.dcp_world_size > 1: + query_kv_lens = query_start_loc[1:] - query_start_loc[:-1] + dcp_context_kv_lens = seq_lens - query_kv_lens + + dcp_context_kv_lens = get_dcp_local_seq_lens( + dcp_context_kv_lens, + self.dcp_world_size, + self.dcp_rank, + self.cp_kv_cache_interleave_size, + ) + # After DCP distribution, the maximum number of tokens for any rank is + # ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size, + # and I is cp_kv_cache_interleave_size. + # This eliminates GPU->CPU sync while minimizing workspace over-allocation. 
+ num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size + max_dcp_context_kv_len = ( + (max_seq_len + num_partitions - 1) // num_partitions + ) * self.cp_kv_cache_interleave_size + scheduler_metadata = None + elif use_cascade: + cu_prefix_query_lens = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device=self.device + ) + prefix_kv_lens = torch.tensor( + [common_prefix_len], dtype=torch.int32, device=self.device + ) + # Use GPU tensor directly - no CPU sync needed + suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len prefix_scheduler_metadata = None scheduler_metadata = None - # prefix_scheduler_metadata = schedule( - # batch_size=1, - # cu_query_lens=cu_prefix_query_lens, - # max_query_len=num_actual_tokens, - # seqlens=prefix_kv_lens, - # max_seq_len=common_prefix_len, - # causal=False) - # scheduler_metadata = schedule(batch_size=num_reqs, - # cu_query_lens=query_start_loc, - # max_query_len=max_query_len, - # seqlens=suffix_kv_lens, - # max_seq_len=max_seq_len - - # common_prefix_len, - # causal=True) else: cu_prefix_query_lens = None prefix_kv_lens = None suffix_kv_lens = None prefix_scheduler_metadata = None scheduler_metadata = None - # scheduler_metadata = schedule(batch_size=num_reqs, - # cu_query_lens=query_start_loc, - # max_query_len=max_query_len, - # seqlens=seq_lens, - # max_seq_len=max_seq_len, - # causal=causal) # For FA3 + full cudagraph if self.use_full_cuda_graph and scheduler_metadata is not None: @@ -321,6 +378,8 @@ def build( seq_lens=seq_lens, block_table=block_table_tensor, slot_mapping=slot_mapping, + max_dcp_context_kv_len=max_dcp_context_kv_len, + dcp_context_kv_lens=dcp_context_kv_lens, use_cascade=use_cascade, common_prefix_len=common_prefix_len, scheduler_metadata=scheduler_metadata, @@ -352,7 +411,6 @@ def __init__( logits_soft_cap: float | None = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: str | None = None, - sinks: torch.Tensor | None = None, ) -> None: 
self.num_heads = num_heads self.head_size = head_size @@ -376,22 +434,17 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads - AttentionFLBackend.validate_head_size(head_size) - self.attn_type = attn_type self.vllm_flash_attn_version = 2 #get_flash_attn_version() # Cache the batch invariant result for use in forward passes + self.batch_invariant_enabled = vllm_is_batch_invariant() if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( "AttentionFL does not support quantization kv-cache on this device." ) - - self.sinks = None - - ### TODO(lms): support int8/int4 attention compute - def supports_quant_query_input(self) -> bool: - return False + ### TODO(lms): support quant to int8/int4 each query input and low precision compute + self.supports_quant_query_input = False def forward( self, @@ -496,29 +549,45 @@ def forward( descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) - flash_attn_varlen_func( - q=query[:num_actual_tokens], - k=key_cache, - v=value_cache, - out=output[:num_actual_tokens], - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - seqused_k=seqused_k, - max_seqlen_k=max_seqlen_k, - softmax_scale=self.scale, - causal=attn_metadata.causal, - alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - block_table=block_table, - softcap=self.logits_soft_cap, - scheduler_metadata=scheduler_metadata, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - num_splits=attn_metadata.max_num_splits, - ) - return output + if self.dcp_world_size > 1: + self._forward_with_dcp( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + key_cache, + value_cache, + output[:num_actual_tokens], + attn_metadata, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + 
v_descale=layer._v_scale.expand(descale_shape), + ) + return output + else: + flash_attn_varlen_func( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=attn_metadata.causal, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + scheduler_metadata=scheduler_metadata, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + num_splits=attn_metadata.max_num_splits, + s_aux=None, ### self.sinks is support in FA3 + ) + return output # Cascade attention (rare case). cascade_attention( @@ -544,10 +613,90 @@ def forward( q_descale=layer._q_scale, k_descale=layer._k_scale, v_descale=layer._v_scale, - s_aux=self.sinks, + s_aux=None, ## sink is None ) return output + def _forward_with_dcp( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + output: torch.Tensor, + attn_metadata: AttentionFLMetadata, + q_descale: torch.Tensor | None = None, + k_descale: torch.Tensor | None = None, + v_descale: torch.Tensor | None = None, + ) -> torch.Tensor: + cu_seqlens_q = attn_metadata.query_start_loc + max_seqlen_q = attn_metadata.max_query_len + block_table = attn_metadata.block_table + + query = query.contiguous() + query_across_dcp = get_dcp_group().all_gather(query, dim=1) + context_attn_out, context_lse = flash_attn_varlen_func( + q=query_across_dcp, + k=key_cache, + v=value_cache, + out=None, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=attn_metadata.dcp_context_kv_lens, + max_seqlen_k=attn_metadata.max_dcp_context_kv_len, + softmax_scale=self.scale, + causal=False, + 
alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + return_softmax_lse=True, + scheduler_metadata=attn_metadata.scheduler_metadata, + fa_version=self.vllm_flash_attn_version, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ] + context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs( + context_attn_out, + context_lse.transpose(0, 1), + get_dcp_group(), + return_lse=True, + ) + context_lse_cor = context_lse_cor.transpose(0, 1).contiguous() + + query_attn_out, query_lse = flash_attn_varlen_func( + q=query, + k=key, + v=value, + out=None, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_k=max_seqlen_q, + softmax_scale=self.scale, + causal=attn_metadata.causal, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + softcap=self.logits_soft_cap, + return_softmax_lse=True, + fa_version=self.vllm_flash_attn_version, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + assert context_attn_out_cor.shape == query_attn_out.shape + assert context_lse_cor.shape == query_lse.shape + merge_attn_states( + output, + context_attn_out_cor, + context_lse_cor, + query_attn_out, + query_lse, + ) + def _forward_encoder_attention( self, query: torch.Tensor, @@ -614,6 +763,7 @@ def use_cascade_attention( use_sliding_window: bool, use_local_attention: bool, num_sms: int, + dcp_world_size: int, ) -> bool: """Decide whether to use cascade attention. @@ -635,6 +785,9 @@ def use_cascade_attention( num_reqs = len(query_lens) if num_reqs < 8: return False + # disable cascade attention for DCP + if dcp_world_size > 1: + return False # Heuristics to decide whether using cascade attention is beneficial. # 1. 
When FlashDecoding is not used for normal attention, cascade attention diff --git a/vllm_fl/attention/custom_attention.py b/vllm_fl/attention/custom_attention.py new file mode 100644 index 00000000..060871a5 --- /dev/null +++ b/vllm_fl/attention/custom_attention.py @@ -0,0 +1,11 @@ +from vllm.attention.backends.registry import ( + AttentionBackendEnum, + register_backend, +) + +def register_attention(): + register_backend( + backend=AttentionBackendEnum.FLASH_ATTN, + class_path="vllm_fl.attention.attention.AttentionFLBackend", + is_mamba=False, + ) \ No newline at end of file diff --git a/vllm_fl/compilation/graph.py b/vllm_fl/compilation/graph.py index c4a35e45..909ec2d7 100644 --- a/vllm_fl/compilation/graph.py +++ b/vllm_fl/compilation/graph.py @@ -5,19 +5,19 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses +from collections import Counter +from collections.abc import Callable from contextlib import ExitStack -from typing import Any, Callable, Optional +from typing import Any, Optional from unittest.mock import patch import torch import vllm.envs as envs from vllm.compilation.counter import compilation_counter -from vllm.compilation.cuda_graph import CUDAGraphOptions from vllm.compilation.monitor import validate_cudagraph_capturing_enabled from vllm.config import CUDAGraphMode, VllmConfig -from vllm.distributed.device_communicators.pynccl_allocator import ( - set_graph_pool_id) +from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import init_logger from vllm.platforms import current_platform @@ -27,7 +27,7 @@ def weak_ref_tensors(tensor: Any) -> Any: if current_platform.device_type == "cuda": - from vllm.utils import weak_ref_tensors + from vllm.utils.torch_utils import weak_ref_tensors return weak_ref_tensors(tensor) else: ### TODO: add csrc npu custom op @@ -41,8 +41,7 @@ class Graph: graph = 
torch.npu.NPUGraph else: raise NotImplementedError("not support graph") - - + @dataclasses.dataclass class GraphEntry: batch_descriptor: BatchDescriptor @@ -53,14 +52,18 @@ class GraphEntry: # during capture, and check if they are the same during replay input_addresses: Optional[list[int]] = None +@dataclasses.dataclass +class GraphOptions: + debug_log_enable: bool = True + gc_disable: bool = False + weak_ref_output: bool = True class GraphWrapper: def __init__(self, runnable: Callable, vllm_config: VllmConfig, runtime_mode: CUDAGraphMode, - graph_pool: Any = None, - cudagraph_options: Optional[CUDAGraphOptions] = None): + cudagraph_options: Optional[GraphOptions] = None): self.runnable = runnable self.vllm_config = vllm_config self.runtime_mode = runtime_mode @@ -78,7 +81,7 @@ def __init__(self, self.graph_pool = current_platform.get_global_graph_pool() if cudagraph_options is None: - cudagraph_options = CUDAGraphOptions() + cudagraph_options = GraphOptions() self.graph_options = cudagraph_options # the entries for different batch descriptors that we need to capture # cudagraphs for. @@ -153,6 +156,14 @@ def __call__(self, *args, **kwargs): with current_platform.torch_device_fn.graph(graph, pool=self.graph_pool): # `output` is managed by pytorch's cudagraph pool output = self.runnable(*args, **kwargs) + if self.graph_options.weak_ref_output: + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. It is only safe to do this for + # the last graph in piecewise cuadgraph mode, because + # the output of the last graph will not be used by + # any other cuda graph. 
+ output = weak_ref_tensors(output) entry.output = weak_ref_tensors(output) entry.graph = graph diff --git a/vllm_fl/distributed/device_communicators/flagcx.py b/vllm_fl/distributed/device_communicators/flagcx.py index 74a750d6..1407757f 100644 --- a/vllm_fl/distributed/device_communicators/flagcx.py +++ b/vllm_fl/distributed/device_communicators/flagcx.py @@ -13,7 +13,7 @@ from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger -from vllm.utils import current_stream +from vllm.utils.torch_utils import current_stream import os import sys diff --git a/vllm_fl/ops/_fl_ops.py b/vllm_fl/ops/_fl_ops.py index 5535b2c2..bd90ce62 100644 --- a/vllm_fl/ops/_fl_ops.py +++ b/vllm_fl/ops/_fl_ops.py @@ -17,7 +17,7 @@ def gelu_and_mul(x, approximate="none"): d = x.shape[-1] // 2 x1, x2 = x[..., :d], x[..., d:] return flag_gems.fused.gelu_and_mul(x1, x2, approximate) - + ### moe @staticmethod def topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output, renormalize=False): @@ -31,5 +31,12 @@ def topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_indices + @staticmethod def moe_sum(input, output): - flag_gems.moe_sum(input, output) \ No newline at end of file + flag_gems.moe_sum(input, output) + + @staticmethod + def moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, num_tokens_post_pad,): + flag_gems.moe_align_block_size_triton(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, num_tokens_post_pad,) \ No newline at end of file diff --git a/vllm_fl/ops/fused_moe/fused_moe.py b/vllm_fl/ops/fused_moe/fused_moe.py index 8b01d842..cb4e328b 100644 --- a/vllm_fl/ops/fused_moe/fused_moe.py +++ b/vllm_fl/ops/fused_moe/fused_moe.py @@ -119,14 +119,14 @@ def fused_experts_impl( config_dtype = _get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, 
use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=False, ## dont support mxfp4 + ocp_mx_scheme=None, ## dont support mxfp4 dtype=hidden_states.dtype) # Note: for use_int8_w8a16 or use_int4_w4a16, the activations are # quantized prior to calling fused_experts. quant_dtype = _get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, - use_mxfp4_w4a4=False) + ocp_mx_scheme=None) get_config_func = functools.partial( try_get_optimal_moe_config, @@ -199,7 +199,7 @@ def fused_experts_impl( sorted_token_ids, expert_ids, num_tokens_post_padded = ( moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], - global_num_experts, expert_map)) + global_num_experts, expert_map, ignore_invalid_experts=True,)) invoke_fused_moe_kernel(qcurr_hidden_states, w1, @@ -280,12 +280,10 @@ def fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", + quant_config: Optional[FusedMoEQuantConfig] = None, apply_router_weight_on_input: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, - quant_config: Optional[FusedMoEQuantConfig] = None, - allow_deep_gemm: bool = False, - allow_cutlass_block_scaled_grouped_gemm: bool = False, ) -> torch.Tensor: if quant_config is None: quant_config = FUSED_MOE_UNQUANTIZED_CONFIG diff --git a/vllm_fl/ops/fused_moe/layer.py b/vllm_fl/ops/fused_moe/layer.py index e3d3ba50..ab120ff1 100644 --- a/vllm_fl/ops/fused_moe/layer.py +++ b/vllm_fl/ops/fused_moe/layer.py @@ -36,66 +36,30 @@ def _eplb_map_to_physical_and_record( class UnquantizedFusedMoEMethodFL(UnquantizedFusedMoEMethod): - def forward_oot(self, - layer: torch.nn.Module, + def forward_oot( + self, + layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, 
- custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - zero_expert_num = getattr(layer, 'zero_expert_num', 0) - zero_expert_type = getattr(layer, 'zero_expert_type', None) - topk_weights, topk_ids, zero_expert_result = FusedMoEFL.select_experts( + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + topk_weights, topk_ids, zero_expert_result = layer.select_experts( hidden_states=x, router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype, - enable_eplb=enable_eplb, - expert_map=expert_map, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - global_num_experts=global_num_experts, - zero_expert_num=zero_expert_num, - zero_expert_type=zero_expert_type) - + ) result = fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - activation=activation, + activation=layer.activation, quant_config=self.moe_quant_config, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + 
global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, ) - if zero_expert_num != 0 and zero_expert_type is not None: + if layer.zero_expert_num != 0 and layer.zero_expert_type is not None: assert not isinstance(result, tuple), \ "Shared + zero experts are mutually exclusive not yet supported" return result, zero_expert_result @@ -126,35 +90,17 @@ def forward_oot(self, return (shared_output[..., :og_hidden_states], fused_output[..., :og_hidden_states]) - @staticmethod def select_experts( + self, hidden_states: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, - indices_type: Optional[torch.dtype] = None, - enable_eplb: bool = False, - expert_map: Optional[torch.Tensor] = None, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - global_num_experts: Optional[int] = None, - zero_expert_num: Optional[int] = None, - zero_expert_type: Optional[str] = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: """ Route the input hidden states to the top-k experts based on the router logits. Returns: - (topk_weights, topk_ids, zero_expert_result) + (topk_weights, topk_ids, zero_expert_result) (tuple[torch.Tensor, torch.Tensor, torch.Tensor]): The weights, expert ids, and zero expert computation result. 
@@ -164,6 +110,36 @@ def select_experts( """ from vllm_fl.ops.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk_bias + + if self.enable_eplb: + if self.quant_method.supports_eplb: + if self.expert_load_view is None: + raise ValueError( + "enable_eplb=True requiere expert_load_view != None" + ) + if self.logical_to_physical_map is None: + raise ValueError( + "enable_eplb=True requiere logical_to_physical_map != None" + ) + if self.logical_replica_count is None: + raise ValueError( + "enable_eplb=True requiere logical_replica_count != None" + ) + else: + raise NotImplementedError( + f"EPLB is not supported for {self.quant_method.method_name}." + ) + + def valid_grouping() -> bool: + # Check if num_experts is greater than num_expert_group + # and is divisible by num_expert_group + num_experts = router_logits.shape[-1] + if num_experts <= self.num_expert_group: + return False + return num_experts % self.num_expert_group == 0 + + indices_type = self.quant_method.topk_indices_dtype + # Check if we should use a routing simulation strategy routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY if routing_strategy != "": @@ -171,76 +147,86 @@ def select_experts( hidden_states=hidden_states, router_logits=router_logits, strategy_name=routing_strategy, - top_k=top_k, - indices_type=indices_type) + top_k=self.top_k, + indices_type=indices_type, + ) # DeepSeekv2 uses grouped_top_k - if use_grouped_topk: - assert topk_group is not None - assert num_expert_group is not None - topk_weights, topk_ids = grouped_topk( + elif self.use_grouped_topk and valid_grouping(): + assert self.topk_group is not None + assert self.num_expert_group is not None + if rocm_aiter_ops.is_fused_moe_enabled(): + if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled(): + assert self.num_fused_shared_experts == 0 + grouped_topk_impl = partial( + rocm_aiter_grouped_topk, + num_fused_shared_experts=self.num_fused_shared_experts, + ) + else: 
+ grouped_topk_impl = grouped_topk + + topk_weights, topk_ids = grouped_topk_impl( hidden_states=hidden_states, gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - num_expert_group=num_expert_group, - topk_group=topk_group, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias) - if indices_type is not None: - topk_ids = topk_ids.to(dtype=indices_type) - elif e_score_correction_bias is not None: + topk=self.top_k, + renormalize=self.renormalize, + num_expert_group=self.num_expert_group, + topk_group=self.topk_group, + scoring_func=self.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, + e_score_correction_bias=self.e_score_correction_bias, + ) + elif self.e_score_correction_bias is not None: topk_weights, topk_ids = fused_topk_bias( hidden_states=hidden_states, gating_output=router_logits, - e_score_correction_bias=e_score_correction_bias.data, - topk=top_k, - renormalize=renormalize, + e_score_correction_bias=self.e_score_correction_bias.data, + topk=self.top_k, + renormalize=self.renormalize, ) - if routed_scaling_factor is not None: - topk_weights *= routed_scaling_factor - elif custom_routing_function is None: + if self.routed_scaling_factor != 1.0: + topk_weights *= self.routed_scaling_factor + elif self.custom_routing_function is None: topk_weights, topk_ids, token_expert_indices = fused_topk( hidden_states=hidden_states, gating_output=router_logits, - topk=top_k, - renormalize=renormalize, + topk=self.top_k, + renormalize=self.renormalize, indices_type=indices_type, ) else: - topk_weights, topk_ids = custom_routing_function( + topk_weights, topk_ids = self.custom_routing_function( hidden_states=hidden_states, gating_output=router_logits, - topk=top_k, - renormalize=renormalize) - if indices_type is not None: - topk_ids = topk_ids.to(dtype=indices_type) - - if enable_eplb: - assert expert_load_view is not None - assert logical_to_physical_map is not 
None - assert logical_replica_count is not None + topk=self.top_k, + renormalize=self.renormalize, + ) + if self.enable_eplb: topk_ids = eplb_map_to_physical_and_record( topk_ids=topk_ids, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - indices_type=indices_type, + expert_load_view=self.expert_load_view, + logical_to_physical_map=self.logical_to_physical_map, + logical_replica_count=self.logical_replica_count, ) + if (indices_type is not None) and topk_ids.dtype != indices_type: + topk_ids = topk_ids.to(dtype=indices_type) + assert topk_ids.dtype == indices_type or indices_type is None # Compute zero expert result if needed - if (zero_expert_num is not None and zero_expert_num > 0 - and zero_expert_type is not None - and global_num_experts is not None): + if ( + self.zero_expert_num is not None + and self.zero_expert_num > 0 + and self.zero_expert_type is not None + and self.global_num_experts is not None + ): zero_expert_result = zero_experts_compute_triton( expert_indices=topk_ids, expert_scales=topk_weights, - num_experts=global_num_experts, - zero_expert_type=zero_expert_type, + num_experts=self.global_num_experts, + zero_expert_type=self.zero_expert_type, hidden_states=hidden_states, ) else: diff --git a/vllm_fl/ops/fused_moe/moe_align_block_size.py b/vllm_fl/ops/fused_moe/moe_align_block_size.py index 52d131a6..d287f0e2 100644 --- a/vllm_fl/ops/fused_moe/moe_align_block_size.py +++ b/vllm_fl/ops/fused_moe/moe_align_block_size.py @@ -7,12 +7,9 @@ from typing import Optional import torch - -from vllm import _custom_ops as ops from vllm.triton_utils import triton -from vllm.utils import round_up - -import flag_gems +from vllm.utils.math_utils import round_up +from vllm_fl.ops._fl_ops import FLOps as fl_ops def moe_align_block_size( @@ -20,7 +17,8 @@ def moe_align_block_size( block_size: int, num_experts: int, expert_map: Optional[torch.Tensor] = None, - pad_sorted_ids: bool 
= False + pad_sorted_ids: bool = False, + ignore_invalid_experts: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block @@ -84,9 +82,11 @@ def moe_align_block_size( num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - - flag_gems.moe_align_block_size_triton(topk_ids, num_experts, block_size, sorted_ids, - expert_ids, num_tokens_post_pad) + # TODO(lms): ignore_invalid_experts not effective now + # moe_align_block_size has optimize version to filtered out + # all invalid experts directly when counting the number of experts + fl_ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, num_tokens_post_pad,) # ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, # expert_ids, num_tokens_post_pad) if expert_map is not None: diff --git a/vllm_fl/platform.py b/vllm_fl/platform.py index 421ce2c5..7c5c3a70 100644 --- a/vllm_fl/platform.py +++ b/vllm_fl/platform.py @@ -8,24 +8,43 @@ from datetime import timedelta from functools import cache, wraps from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union +from typing_extensions import ParamSpec import torch -import vllm.envs as envs +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.logger import init_logger from vllm.platforms import Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig + from vllm.config.cache import CacheDType else: - _Backend = None + VllmConfig = None + CacheDType = None from vllm_fl.utils import DeviceInfo logger = init_logger(__name__) +_P = ParamSpec("_P") +_R = TypeVar("_R") + +@cache +def _get_backend( + use_mla: bool, + device_info: Optional[DeviceInfo] = None, +) -> list[str]: + """Get backend priorities with lazy import to avoid circular dependency.""" + 
if use_mla: + raise NotImplementedError("NOT support mla now!") + # return "vllm_fl.attention.backends.mla.MLAFLBackend" + else: + return AttentionBackendEnum.FLASH_ATTN #"vllm_fl.attention.attention.AttentionFLBackend" + + class PlatformFL(Platform): _enum = PlatformEnum.OOT device_info = DeviceInfo() @@ -34,7 +53,7 @@ class PlatformFL(Platform): dispatch_key = device_info.dispatch_key torch_device_fn = device_info.torch_device_fn ray_device_key: str = "flagos" - dist_backend: str = "flagcx" + dist_backend: str = "flagcx" if "FLAGCX_PATH" in os.environ else "nccl" ### TODO(lms): dispatch device_control_env_var # device_control_env_var: str = "CUDA_VISIBLE_DEVICES" @@ -79,16 +98,6 @@ def get_device_capability(cls, device_id: int = 0): @classmethod def get_device_name(cls, device_id: int = 0) -> str: return cls.device_name - - @classmethod - def verify_quantization(cls, quant: str) -> None: - """ - Verify whether the quantization is supported by the current platform. - """ - if cls.supported_quantization and quant not in cls.supported_quantization: - raise ValueError( - f"{quant} quantization is currently not supported in {cls.device_name}." 
- ) ### TODO(lms): change pin_memory depend device @classmethod @@ -149,34 +158,51 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: @classmethod def get_attn_backend_cls( cls, - selected_backend, - head_size, - dtype, - kv_cache_dtype, - block_size, - use_v1, - use_mla, - has_sink, - use_sparse, - ) -> str: - - ### TODO(lms): support int8 kv cache - # use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith( - # "fp8" - # ) - - if use_mla: - ### TODO(lms): support mla - raise NotImplementedError - # logger.info_once("Using FL MLA Attention backend.") - # return ( - # "vllm_fl.attention.backends.mla.MLAFLBackend" - # ) - else: - logger.info_once("Using FL Attention backend.") - return ( - "vllm_fl.attention.attention.AttentionFLBackend" + selected_backend: "AttentionBackendEnum", + attn_selector_config: "AttentionSelectorConfig", + ) -> list[str]: + from vllm_fl.attention.custom_attention import register_attention + register_attention() + backend = _get_backend( + use_mla=False, + device_info=cls.device_info, + ) + backend_class = backend.get_class() + invalid_reasons = backend_class.validate_configuration( + device_capability=None, + **attn_selector_config._asdict(), ) + reasons_str = ( + "{" + + ", ".join( + f"{backend.name}: [{', '.join(invalid_reasons)}]" + ) + + "}" + ) + config_str = attn_selector_config.__repr__() + logger.debug_once( + f"Some attention backends are not valid for {cls.device_name} with " + f"{config_str}. Reasons: {reasons_str}." 
+ ) + + logger.info_once( + "Using %s attention backend out of potential backends: %s", + backend.name, + tuple(backend.name), + scope="local", + ) + return backend.get_path() + + @classmethod + def get_vit_attn_backend( + cls, + head_size: int, + dtype: torch.dtype, + backend: Optional["AttentionBackendEnum"] = None, + ) -> list[str]: + return [ + "vllm_fl.attention.attention.AttentionFLBackend" + ] @classmethod def get_punica_wrapper(cls) -> str: @@ -185,9 +211,16 @@ def get_punica_wrapper(cls) -> str: @classmethod def get_device_communicator_cls(cls) -> str: - return ( - "vllm_fl.distributed.communicator.CommunicatorFL" # noqa - ) + if cls.dist_backend == "flagcx": + logger.info("Using CommunicatorFL for communication.") + return ( + "vllm_fl.distributed.communicator.CommunicatorFL" # noqa + ) + else: + logger.info("Using CudaCommunicator for communication.") + return ( + "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + ) @classmethod @@ -196,8 +229,34 @@ def get_static_graph_wrapper_cls(cls) -> str: @classmethod def support_static_graph_mode(cls) -> bool: - return True + if cls.device_name in ["cuda", "npu"]: + return True + return False + @classmethod + def insert_blocks_to_device( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + """Copy blocks from src_cache to dst_cache device .""" + _src_cache = src_cache[:, src_block_indices] + dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device) + + @classmethod + def swap_out_blocks_to_host( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + """Copy blocks from device to host (CPU).""" + _src_cache = src_cache[:, src_block_indices] + dst_cache[:, dst_block_indices] = _src_cache.cpu() + @classmethod def support_hybrid_kv_cache(cls) -> bool: return True @@ -207,3 +266,38 @@ def 
support_hybrid_kv_cache(cls) -> bool: def opaque_attention_op(cls) -> bool: return True + @classmethod + def use_custom_allreduce(cls) -> bool: + if cls.dist_backend == "flagcx": + return False + return True + + @classmethod + def is_fully_connected(cls, physical_device_ids: list[int]) -> bool: + try: + import pynvml + """ + query if the set of gpus are fully connected by nvlink (1 hop) + """ + handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + p2p_status = pynvml.nvmlDeviceGetP2PStatus( + handle, + peer_handle, + pynvml.NVML_P2P_CAPS_INDEX_NVLINK, + ) + if p2p_status != pynvml.NVML_P2P_STATUS_OK: + return False + except pynvml.NVMLError: + logger.exception( + "NVLink detection failed. This is normal if" + " your machine has no NVLink equipped." + ) + return False + return True + except: + return False + diff --git a/vllm_fl/worker/model_runner.py b/vllm_fl/worker/model_runner.py index 7e3bd21e..46b654f0 100644 --- a/vllm_fl/worker/model_runner.py +++ b/vllm_fl/worker/model_runner.py @@ -3,95 +3,178 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools import gc import itertools import time from collections import defaultdict -from collections.abc import Iterator +from collections.abc import Iterator, Sequence from contextlib import contextmanager, nullcontext -from copy import deepcopy -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union, cast +from copy import copy, deepcopy +from functools import reduce +from itertools import product +from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast import numpy as np import torch import torch.distributed import torch.nn as nn from tqdm import tqdm -from typing_extensions import TypeAlias import vllm.envs as envs -from vllm.attention import Attention, AttentionType -from 
vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention +from vllm.attention.backends.abstract import( + AttentionBackend, + AttentionType, + MultipleOf) +from vllm.attention.layer import Attention, MLAAttention from vllm.compilation.counter import compilation_counter -# from vllm.compilation.cuda_graph import CUDAGraphWrapper +from vllm.compilation.cuda_graph import CUDAGraphStat #CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled -from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, - get_layers_from_vllm_config, update_config) +from vllm.config import ( + CompilationMode, + CUDAGraphMode, + VllmConfig, + get_layers_from_vllm_config, + update_config, +) +from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer from vllm.distributed.eplb.eplb_state import EplbState -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group) +from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.distributed.parallel_state import ( - get_pp_group, get_tp_group, is_global_first_rank, GraphCaptureContext, - prepare_communication_buffer_for_model) -from vllm.forward_context import (BatchDescriptor, DPMetadata, - set_forward_context) + get_dcp_group, + get_pp_group, + get_tp_group, + graph_capture, + is_global_first_rank, + prepare_communication_buffer_for_model, + GraphCaptureContext +) +from vllm.forward_context import ( + BatchDescriptor, + set_forward_context, +) from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.layers.mamba.abstract import MambaBase -from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.layers.rotary_embedding import ( + MRotaryEmbedding, 
+ XDRotaryEmbedding, +) from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader -from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache -# yapf conflicts with isort for this block -# yapf: disable -from vllm.model_executor.models.interfaces import (SupportsMultiModal, - is_mixture_of_experts, - supports_eagle3, - supports_mrope, - supports_multimodal_pruning, - supports_transcription) -# yapf: enable +from vllm.model_executor.models.interfaces import ( + SupportsMRoPE, + SupportsMultiModal, + SupportsXDRoPE, + is_mixture_of_experts, + supports_eagle3, + supports_mrope, + supports_multimodal_pruning, + supports_transcription, + supports_xdrope, +) from vllm.model_executor.models.interfaces_base import ( - VllmModelForPooling, is_pooling_model, is_text_generation_model) + VllmModelForPooling, + is_pooling_model, + is_text_generation_model, +) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem, - PlaceholderRange) +from vllm.multimodal.inputs import ( + BatchedTensorInputs, + MultiModalKwargsItem, + PlaceholderRange, +) from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, cdiv, check_use_alibi, get_dtype_size, - is_pin_memory_available, - length_from_prompt_token_ids_or_embeds, round_up, - supports_dynamo) +from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.utils.jsontree import json_map_leaves +from vllm.utils.math_utils import cdiv, round_up +from vllm.utils.mem_constants import GiB_bytes +from vllm.utils.mem_utils import DeviceMemoryProfiler +from vllm.utils.nvtx_pytorch_hooks import PytHooks +from vllm.utils.platform_utils import 
is_pin_memory_available + from vllm.platforms import current_platform +if current_platform.dist_backend == "flagcx": + @contextmanager + def graph_capture(device: torch.device): + """ + `graph_capture` is a context manager which should surround the code that + is capturing the NPU graph. Its main purpose is to ensure that the + some operations will be run after the graph is captured, before the graph + is replayed. It returns a `GraphCaptureContext` object which contains the + necessary data for the graph capture. Currently, it only contains the + stream that the graph capture is running on. This stream is set to the + current NPU stream when the context manager is entered and reset to the + default stream when the context manager is exited. This is to ensure that + the graph capture is running on a separate stream from the default stream, + in order to explicitly distinguish the kernels to capture + from other kernels possibly launched on background in the default stream. + """ + graph_capture_context = GraphCaptureContext( + current_platform.torch_device_fn.Stream(device=device)) + stream = graph_capture_context.stream + + # we use nullcontext now + maybe_ca_context = nullcontext() + + # ensure all initialization operations complete before attempting to + # capture the graph on another stream + curr_stream = current_platform.torch_device_fn.current_stream() + if curr_stream != stream: + stream.wait_stream(curr_stream) + + with current_platform.torch_device_fn.stream(stream), maybe_ca_context: + yield graph_capture_context +else: + from vllm.distributed.parallel_state import graph_capture +from vllm.utils.torch_utils import ( + get_dtype_size, + kv_cache_dtype_str_to_dtype, + supports_dynamo, +) from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, 
create_fast_prefill_custom_backend, - reorder_batch_to_split_decodes_and_prefills, split_attn_metadata) + get_dcp_local_seq_lens, + reorder_batch_to_split_decodes_and_prefills, + split_attn_metadata, +) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher -# yapf conflicts with isort for this block -# yapf: disable -from vllm.v1.kv_cache_interface import (AttentionSpec, - ChunkedLocalAttentionSpec, - CrossAttentionSpec, - EncoderOnlyAttentionSpec, - FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheSpec, - MambaSpec, MLAAttentionSpec, - SlidingWindowSpec, - UniformTypeKVCacheSpecs) -# yapf: enable -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, - DraftTokenIds, LogprobsLists, LogprobsTensors, - ModelRunnerOutput, PoolerOutput, SamplerOutput) -from vllm.v1.pool.metadata import PoolingMetadata +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + ChunkedLocalAttentionSpec, + CrossAttentionSpec, + EncoderOnlyAttentionSpec, + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheSpec, + MambaSpec, + SlidingWindowSpec, + UniformTypeKVCacheSpecs, +) +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + AsyncModelRunnerOutput, + DraftTokenIds, + ECConnectorOutput, + KVConnectorOutput, + LogprobsLists, + LogprobsTensors, + ModelRunnerOutput, + PoolerOutput, + SamplerOutput, + make_empty_encoder_model_runner_output, +) +from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs +from vllm.v1.sample.logits_processor.interface import LogitsProcessor from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler @@ -99,27 +182,35 @@ from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from 
vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext +from vllm.v1.worker.cp_utils import check_attention_cp_compatibility +from vllm.v1.worker.dp_utils import coordinate_batch_across_dp +from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper -from vllm.v1.worker.kv_connector_model_runner_mixin import ( - KVConnectorModelRunnerMixin) +from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from vllm.v1.worker.ubatch_splitting import (check_ubatch_thresholds, - ubatch_split) -from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices +from vllm.v1.worker.ubatch_utils import ( + UBatchSlices, + check_ubatch_thresholds, + maybe_create_ubatch_slices, +) from vllm.v1.worker.utils import is_residual_scattered_for_sp +from vllm.v1.worker.workspace import lock_workspace -from vllm.v1.worker.utils import (AttentionGroup, MultiModalBudget, - add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, - gather_mm_placeholders, sanity_check_mm_encoder_outputs, - scatter_mm_placeholders) - +from vllm.v1.worker.utils import ( + AttentionGroup, + MultiModalBudget, + add_kv_sharing_layers_to_kv_cache_groups, + bind_kv_cache, + sanity_check_mm_encoder_outputs, +) if TYPE_CHECKING: from vllm.model_executor.model_loader.tensorizer import TensorizerConfig - from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm_fl.compilation.graph import GraphWrapper from vllm_fl.attention.attention import AttentionMetadata @@ -129,39 +220,7 @@ AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata] # 
list when ubatching is enabled -PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict], - AttnMetadataDict] - -@contextmanager -def graph_capture(device: torch.device): - """ - `graph_capture` is a context manager which should surround the code that - is capturing the NPU graph. Its main purpose is to ensure that the - some operations will be run after the graph is captured, before the graph - is replayed. It returns a `GraphCaptureContext` object which contains the - necessary data for the graph capture. Currently, it only contains the - stream that the graph capture is running on. This stream is set to the - current NPU stream when the context manager is entered and reset to the - default stream when the context manager is exited. This is to ensure that - the graph capture is running on a separate stream from the default stream, - in order to explicitly distinguish the kernels to capture - from other kernels possibly launched on background in the default stream. - """ - graph_capture_context = GraphCaptureContext( - current_platform.torch_device_fn.Stream(device=device)) - stream = graph_capture_context.stream - - # we use nullcontext now - maybe_ca_context = nullcontext() - - # ensure all initialization operations complete before attempting to - # capture the graph on another stream - curr_stream = current_platform.torch_device_fn.current_stream() - if curr_stream != stream: - stream.wait_stream(curr_stream) - - with current_platform.torch_device_fn.stream(stream), maybe_ca_context: - yield graph_capture_context +PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict # Wrapper for ModelRunnerOutput to support overlapped execution. 
@@ -171,48 +230,87 @@ def __init__( self, model_runner_output: ModelRunnerOutput, sampled_token_ids: torch.Tensor, + logprobs_tensors: LogprobsTensors | None, invalid_req_indices: list[int], - async_output_copy_stream, + async_output_copy_stream: current_platform.torch_device_fn.Stream | None, + vocab_size: int, ): self._model_runner_output = model_runner_output self._invalid_req_indices = invalid_req_indices # Event on the copy stream so we can synchronize the non-blocking copy. - self._async_copy_ready_event = current_platform.torch_device_fn.Event() + self.async_copy_ready_event = torch.Event() # Keep a reference to the device tensor to avoid it being # deallocated until we finish copying it to the host. self._sampled_token_ids = sampled_token_ids + self.vocab_size = vocab_size + self._logprobs_tensors = logprobs_tensors # Initiate the copy on a separate stream, but do not synchronize it. default_stream = current_platform.torch_device_fn.current_stream() with current_platform.torch_device_fn.stream(async_output_copy_stream): async_output_copy_stream.wait_stream(default_stream) - self._sampled_token_ids_cpu = self._sampled_token_ids.to( - 'cpu', non_blocking=True) - self._async_copy_ready_event.record() + self.sampled_token_ids_cpu = self._sampled_token_ids.to( + 'cpu', non_blocking=True + ) + self._logprobs_tensors_cpu = ( + self._logprobs_tensors.to_cpu_nonblocking() + if self._logprobs_tensors + else None + ) + self.async_copy_ready_event.record() def get_output(self) -> ModelRunnerOutput: """Copy the device tensors to the host and return a ModelRunnerOutput. This function blocks until the copy is finished. """ - self._async_copy_ready_event.synchronize() + max_gen_len = self.sampled_token_ids_cpu.shape[-1] + self.async_copy_ready_event.synchronize() - # Release the device tensor once the copy has completed + # Release the device tensors once the copy has completed. 
+ del self._logprobs_tensors del self._sampled_token_ids + if max_gen_len == 1: + valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist() + for i in self._invalid_req_indices: + valid_sampled_token_ids[i].clear() + cu_num_tokens = None + else: + valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output( + self.sampled_token_ids_cpu, + self.vocab_size, + self._invalid_req_indices, + return_cu_num_tokens=self._logprobs_tensors_cpu is not None, + ) - valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist() - for i in self._invalid_req_indices: - valid_sampled_token_ids[i].clear() output = self._model_runner_output output.sampled_token_ids = valid_sampled_token_ids + if self._logprobs_tensors_cpu: + output.logprobs = self._logprobs_tensors_cpu.tolists(cu_num_tokens) return output -class ModelRunnerFL(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): +class ExecuteModelState(NamedTuple): + """Ephemeral cached state transferred between execute_model() and + sample_tokens(), after execute_model() returns None.""" + scheduler_output: "SchedulerOutput" + logits: torch.Tensor + spec_decode_metadata: SpecDecodeMetadata | None + spec_decode_common_attn_metadata: CommonAttentionMetadata | None + hidden_states: torch.Tensor + sample_hidden_states: torch.Tensor + aux_hidden_states: list[torch.Tensor] | None + ec_connector_output: ECConnectorOutput | None + cudagraph_stats: CUDAGraphStat | None + + +class ModelRunnerFL( + LoRAModelRunnerMixin, KVConnectorModelRunnerMixin, ECConnectorModelRunnerMixin +): def __init__( self, vllm_config: VllmConfig, @@ -230,8 +328,8 @@ def __init__( self.observability_config = vllm_config.observability_config from vllm.model_executor.models.utils import set_cpu_offload_max_bytes - set_cpu_offload_max_bytes( - int(self.cache_config.cpu_offload_gb * 1024**3)) + + set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3)) model_config = self.model_config cache_config = self.cache_config @@ -240,20 +338,24 @@ 
def __init__( self.device = device self.pin_memory = is_pin_memory_available() self.dtype = self.model_config.dtype - if cache_config.cache_dtype == "auto": - self.kv_cache_dtype = self.dtype - else: - self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ - cache_config.cache_dtype] - self.is_pooling_model = (model_config.runner_type == 'pooling') + self.kv_cache_dtype = kv_cache_dtype_str_to_dtype( + cache_config.cache_dtype, self.model_config + ) + + self.is_pooling_model = model_config.runner_type == "pooling" self.enable_prompt_embeds = model_config.enable_prompt_embeds self.is_multimodal_raw_input_only_model = ( - model_config.is_multimodal_raw_input_only_model) + model_config.is_multimodal_raw_input_only_model + ) # This will be overridden in load_model() self.is_multimodal_pruning_enabled = False self.max_model_len = model_config.max_model_len + + # Always set to false after the first forward pass + self.calculate_kv_scales = self.cache_config.calculate_kv_scales self.dcp_world_size = self.parallel_config.decode_context_parallel_size + self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -262,37 +364,39 @@ def __init__( # TODO: Support overlapping mirco-batches # https://github.com/vllm-project/vllm/issues/18019 self.broadcast_pp_output = ( - self.parallel_config.distributed_executor_backend - == "external_launcher" and len(get_pp_group().ranks) > 0) + self.parallel_config.distributed_executor_backend == "external_launcher" + and len(get_pp_group().ranks) > 0 + ) # Model-related. 
- self.num_query_heads = model_config.get_num_attention_heads( - parallel_config) - self.hidden_size = model_config.get_hidden_size() + self.num_query_heads = model_config.get_num_attention_heads(parallel_config) + self.inputs_embeds_size = model_config.get_inputs_embeds_size() self.attention_chunk_size = model_config.attention_chunk_size # Only relevant for models using ALiBi (e.g, MPT) - self.use_alibi = check_use_alibi(model_config) + self.use_alibi = model_config.uses_alibi self.cascade_attn_enabled = not self.model_config.disable_cascade_attn + self.is_mm_prefix_lm = self.model_config.is_mm_prefix_lm # Multi-modal data support self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope + self.uses_xdrope_dim = model_config.uses_xdrope_dim self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( - model_config) + model_config + ) if self.model_config.is_encoder_decoder: # Maximum length of the encoder input, only for encoder-decoder # models. - self.max_encoder_len = scheduler_config.\ - max_num_encoder_input_tokens + self.max_encoder_len = scheduler_config.max_num_encoder_input_tokens else: self.max_encoder_len = 0 # Sampler self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) - self.eplb_state: Optional[EplbState] = None + self.eplb_state: EplbState | None = None """ State of the expert parallelism load balancer. @@ -303,6 +407,9 @@ def __init__( # self.model: nn.Module # Set after load_model # Initialize in initialize_kv_cache self.kv_caches: list[torch.Tensor] = [] + # Initialize in initialize_kv_cache_tensors + self.cross_layers_kv_cache: torch.Tensor | None = None + self.cross_layers_attn_backend: type[AttentionBackend] | None = None # indexes: [kv_cache_group_id][attn_group] self.attn_groups: list[list[AttentionGroup]] = [] # self.kv_cache_config: KVCacheConfig @@ -316,24 +423,38 @@ def __init__( # the last PP rank. This is not ideal if there are many # layers in the draft model. 
if self.speculative_config and get_pp_group().is_last_rank: + self.drafter: ( + NgramProposer | SuffixDecodingProposer | EagleProposer | MedusaProposer + ) if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) + elif self.speculative_config.method == "suffix": + self.drafter = SuffixDecodingProposer(self.vllm_config) elif self.speculative_config.use_eagle(): - self.drafter = EagleProposer(self.vllm_config, self.device, - self) # type: ignore + self.drafter = EagleProposer(self.vllm_config, self.device, self) if self.speculative_config.method == "eagle3": - self.use_aux_hidden_state_outputs = True + self.use_aux_hidden_state_outputs = ( + self.drafter.eagle3_use_aux_hidden_state + ) elif self.speculative_config.method == "medusa": self.drafter = MedusaProposer( - vllm_config=self.vllm_config, - device=self.device) # type: ignore + vllm_config=self.vllm_config, device=self.device + ) else: - raise ValueError("Unknown speculative decoding method: " - f"{self.speculative_config.method}") - self.rejection_sampler = RejectionSampler() + raise ValueError( + "Unknown speculative decoding method: " + f"{self.speculative_config.method}" + ) + self.rejection_sampler = RejectionSampler(self.sampler) + self.num_spec_tokens = 0 + if self.speculative_config: + self.num_spec_tokens = self.speculative_config.num_speculative_tokens # Request states. self.requests: dict[str, CachedRequestState] = {} + # NOTE(rob): num_prompt_logprobs only includes reqs + # that are currently in the prefill phase. + self.num_prompt_logprobs: dict[str, int] = {} self.comm_stream = current_platform.torch_device_fn.Stream() # Input Batch @@ -345,6 +466,10 @@ def __init__( # solution, we initialize the input batch here, and re-initialize it # in `initialize_kv_cache` if the block_sizes here is different from # the block_sizes in the kv cache config. 
+ logits_processors = model_config.logits_processors + custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = ( + tuple(logits_processors) if logits_processors is not None else () + ) self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, # We need to use the encoder length for encoder-decoer @@ -355,52 +480,70 @@ def __init__( pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.cache_config.block_size], + kernel_block_sizes=[self.cache_config.block_size], is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=build_logitsprocs( - self.vllm_config, self.device, self.pin_memory, + self.vllm_config, + self.device, + self.pin_memory, self.is_pooling_model, - self.vllm_config.model_config.logits_processors), + custom_logitsprocs, + ), + # We currently don't know whether a particular custom logits processor + # uses output token ids so we set this conservatively. + logitsprocs_need_output_token_ids=bool(custom_logitsprocs), is_pooling_model=self.is_pooling_model, + cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, ) self.use_async_scheduling = self.scheduler_config.async_scheduling self.async_output_copy_stream = current_platform.torch_device_fn.Stream() if \ self.use_async_scheduling else None + # cuda event to synchronize use of reused CPU tensors between steps + # when async scheduling is enabled. + self.prepare_inputs_event: torch.Event | None = None + if self.use_async_scheduling: + self.prepare_inputs_event = torch.Event() - # TODO(woosuk): Provide an option to tune the max cudagraph batch size. - # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. - # The batch sizes in the config are in descending order. 
- if self.compilation_config.cudagraph_capture_sizes and \ - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: - self.cudagraph_batch_sizes = list( - reversed(self.compilation_config.cudagraph_capture_sizes)) - + if ( + self.compilation_config.cudagraph_capture_sizes + and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ): + self.cudagraph_batch_sizes = sorted( + self.compilation_config.cudagraph_capture_sizes + ) # Persistent buffers for CUDA graphs. - self.input_ids = self._make_buffer(self.max_num_tokens, - dtype=torch.int32) - self.positions = self._make_buffer(self.max_num_tokens, - dtype=torch.int64) - self.query_start_loc = self._make_buffer(self.max_num_reqs + 1, - dtype=torch.int32) + self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) + self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64) + self.query_start_loc = self._make_buffer( + self.max_num_reqs + 1, dtype=torch.int32 + ) self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) + self.encoder_seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) + if self.dcp_world_size > 1: + self.dcp_local_seq_lens = self._make_buffer( + self.max_num_reqs, dtype=torch.int32 + ) # Because inputs_embeds may be bfloat16 and we don't need a numpy # version of this tensor, avoid a RuntimeError by not creating a # numpy buffer. 
- self.inputs_embeds = self._make_buffer(self.max_num_tokens, - self.hidden_size, - dtype=self.dtype, - numpy=False) - self.is_token_ids = self._make_buffer(self.max_num_tokens, - dtype=torch.bool) - self.discard_request_indices = self._make_buffer(self.max_num_reqs, - dtype=torch.int64) - self.num_discarded_requests = 0 - - self.num_decode_draft_tokens = self._make_buffer(self.max_num_reqs, - dtype=torch.int32) - self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, - dtype=torch.int64) + self.inputs_embeds = self._make_buffer( + self.max_num_tokens, self.inputs_embeds_size, dtype=self.dtype, numpy=False + ) + self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool) + self.discard_request_mask = self._make_buffer( + self.max_num_reqs, dtype=torch.bool + ) + self.num_decode_draft_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int32 + ) + self.num_accepted_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int64 + ) + # Only relevant for multimodal models + if self.supports_mm_inputs: + self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -415,27 +558,25 @@ def __init__( # 1D-RoPE. # See page 5 of https://arxiv.org/abs/2409.12191 self.mrope_positions = self._make_buffer( - (3, self.max_num_tokens + 1), dtype=torch.int64) + (3, self.max_num_tokens + 1), dtype=torch.int64 + ) - ### TODO(lms): support prepare input event - self.prepare_inputs_event = None - # # CUDA event to synchronize use of reused CPU tensors between steps - # # when async scheduling is enabled. - self.prepare_inputs_event: Optional[current_platform.torch_device_fn.Event] = None - if self.use_async_scheduling: - self.prepare_inputs_event = current_platform.torch_device_fn.Event() - # Start in a completed state. 
- self.prepare_inputs_event.record(current_platform.torch_device_fn.default_stream()) + # Only relevant for models using XD-RoPE (e.g, HunYuan-VL) + if self.uses_xdrope_dim > 0: + # Similar to mrope but use assigned dimension number for RoPE, 4 as default. + self.xdrope_positions = self._make_buffer( + (self.uses_xdrope_dim, self.max_num_tokens + 1), dtype=torch.int64 + ) # None in the first PP rank. The rest are set after load_model. - self.intermediate_tensors: Optional[IntermediateTensors] = None + self.intermediate_tensors: IntermediateTensors | None = None # OPTIMIZATION: Cache the tensors rather than creating them every step. # Keep in int64 to avoid overflow with long context - self.arange_np = np.arange(max(self.max_num_reqs + 1, - self.max_model_len, - self.max_num_tokens), - dtype=np.int64) + self.arange_np = np.arange( + max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens), + dtype=np.int64, + ) # Layer pairings for cross-layer KV sharing. # If an Attention layer `layer_name` is in the keys of this dict, it @@ -447,21 +588,25 @@ def __init__( self.kv_sharing_fast_prefill_logits_indices = None if self.cache_config.kv_sharing_fast_prefill: self.kv_sharing_fast_prefill_logits_indices = torch.zeros( - self.max_num_tokens, dtype=torch.int32, device=self.device) + self.max_num_tokens, dtype=torch.int32, device=self.device + ) - self.uniform_decode_query_len = 1 if not self.speculative_config else \ - 1 + self.speculative_config.num_speculative_tokens + self.uniform_decode_query_len = 1 + self.num_spec_tokens # Cudagraph dispatcher for runtime cudagraph dispatching. 
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) - self.mm_budget = MultiModalBudget( - self.model_config, - self.scheduler_config, - self.mm_registry, - ) if self.supports_mm_inputs else None + self.mm_budget = ( + MultiModalBudget( + self.model_config, + self.scheduler_config, + self.mm_registry, + ) + if self.supports_mm_inputs + else None + ) - self.reorder_batch_threshold: Optional[int] = None + self.reorder_batch_threshold: int | None = None # Attention layers that are only in the KVCacheConfig of the runner # (e.g., KV sharing, encoder-only attention), but not in the @@ -469,24 +614,105 @@ def __init__( self.runner_only_attn_layers: set[str] = set() # Cached outputs. - self._draft_token_ids: Optional[Union[list[list[int]], - torch.Tensor]] = None - self.transfer_event = current_platform.torch_device_fn.Event() + self._draft_token_ids: list[list[int]] | torch.Tensor | None = None + self.transfer_event = torch.Event() self.sampled_token_ids_pinned_cpu = torch.empty( - (self.max_model_len, 1), + (self.max_num_reqs, 1), dtype=torch.int64, device="cpu", - pin_memory=self.pin_memory) - - def _make_buffer(self, - *size: Union[int, torch.SymInt], - dtype: torch.dtype, - numpy: bool = True) -> CpuGpuBuffer: - return CpuGpuBuffer(*size, - dtype=dtype, - device=self.device, - pin_memory=self.pin_memory, - with_numpy=numpy) + pin_memory=self.pin_memory, + ) + # Pre-allocated tensor for copying valid sampled token counts to CPU, + # with dedicated stream for overlapping and event for coordination. 
+ self.valid_sampled_token_count_event: torch.Event | None = None + self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None + if self.use_async_scheduling and self.num_spec_tokens: + self.valid_sampled_token_count_event = torch.Event() + self.valid_sampled_token_count_copy_stream = torch.cuda.Stream() + self.valid_sampled_token_count_cpu = torch.empty( + self.max_num_reqs, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) + + # Ephemeral state transferred between execute_model() and sample_tokens(). + self.execute_model_state: ExecuteModelState | None = None + self.kv_connector_output: KVConnectorOutput | None = None + self.layerwise_nvtx_hooks_registered = False + + def reset_mm_cache(self) -> None: + if self.mm_budget: + self.mm_budget.reset_cache() + + @torch.inference_mode() + def init_fp8_kv_scales(self) -> None: + """ + Re-initialize the KV cache and FP8 scales after waking from sleep. + 1. Zero out the KV cache tensors to remove garbage data from re-allocation. + 2. Reset Attention layer scaling factors (_k_scale, _v_scale) to 1.0. + If these are left at 0.0 (default after wake_up), all KV cache values + become effectively zero, causing gibberish output. + """ + if not self.cache_config.cache_dtype.startswith("fp8"): + return + + kv_caches = getattr(self, "kv_caches", []) + for cache_tensor in kv_caches: + if cache_tensor is not None: + cache_tensor.zero_() + + k_attr_names = ("_k_scale", "k_scale") + v_attr_names = ("_v_scale", "v_scale") + + attn_layers = self.compilation_config.static_forward_context + for name, module in attn_layers.items(): + if isinstance(module, (Attention, MLAAttention)): + # TODO: Generally, scale is 1.0 if user uses on-the-fly fp8 + # kvcache quant. However, to get better accuracy, compression + # frameworks like llm-compressors allow users to tune the + # scale. We may need to restore the specific calibrated scales + # here in the future. 
+ k_scale_val, v_scale_val = 1.0, 1.0 + + # Processing K Scale + for attr in k_attr_names: + if hasattr(module, attr): + param = getattr(module, attr) + if isinstance(param, torch.Tensor): + param.fill_(k_scale_val) + + # Processing V Scale + for attr in v_attr_names: + if hasattr(module, attr): + param = getattr(module, attr) + if isinstance(param, torch.Tensor): + param.fill_(v_scale_val) + + def _get_positions(self, num_tokens: Any): + if isinstance(num_tokens, int): + if self.uses_mrope: + return self.mrope_positions.gpu[:, :num_tokens] + if self.uses_xdrope_dim > 0: + return self.xdrope_positions.gpu[:, :num_tokens] + return self.positions.gpu[:num_tokens] + else: + if self.uses_mrope: + return self.mrope_positions.gpu[:, num_tokens] + if self.uses_xdrope_dim > 0: + return self.xdrope_positions.gpu[:, num_tokens] + return self.positions.gpu[num_tokens] + + def _make_buffer( + self, *size: int | torch.SymInt, dtype: torch.dtype, numpy: bool = True + ) -> CpuGpuBuffer: + return CpuGpuBuffer( + *size, + dtype=dtype, + device=self.device, + pin_memory=self.pin_memory, + with_numpy=numpy, + ) def _init_model_kwargs(self, num_tokens: int): model_kwargs = dict[str, Any]() @@ -499,9 +725,11 @@ def _init_model_kwargs(self, num_tokens: int): token_type_id_requests = dict[int, Any]() for i, param in enumerate(pooling_params): - if param.extra_kwargs is not None and \ - (token_types := param.extra_kwargs.get( - "compressed_token_type_ids")) is not None: + if ( + param.extra_kwargs is not None + and (token_types := param.extra_kwargs.get("compressed_token_type_ids")) + is not None + ): token_type_id_requests[i] = token_types if len(token_type_id_requests) == 0: @@ -516,7 +744,8 @@ def _init_model_kwargs(self, num_tokens: int): token_type_ids.append(ids) model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to( - device=self.device) + device=self.device + ) return model_kwargs def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: @@ -538,16 
+767,11 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: return if self.reorder_batch_threshold is not None: - # NOTE(lucas): currently no backend supports the custom masking - # required for DCP with q_len > 1, so we assert here. Remove this - # assert once the custom mask is support is added to FA3. - if self.dcp_world_size > 1: - assert self.reorder_batch_threshold == 1, \ - "DCP not support reorder_batch_threshold > 1 now." reorder_batch_to_split_decodes_and_prefills( self.input_batch, scheduler_output, - decode_threshold=self.reorder_batch_threshold) + decode_threshold=self.reorder_batch_threshold, + ) # Note: used for model runner override. def _sync_device(self) -> None: @@ -566,6 +790,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) + self.num_prompt_logprobs.pop(req_id, None) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. This happens when a request is aborted and @@ -586,7 +811,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # they will be scheduled again sometime in the future. scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() cached_req_ids = self.input_batch.req_id_to_index.keys() - unscheduled_req_ids = cached_req_ids - scheduled_req_ids + resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids + # NOTE(zhuohan): cached_req_ids and resumed_req_ids are usually disjoint, + # so `(scheduled_req_ids - resumed_req_ids) == scheduled_req_ids` holds + # apart from the forced-preemption case in reset_prefix_cache. And in + # that case we include the resumed_req_ids in the unscheduled set so + # that they get cleared from the persistent batch before being re-scheduled + # in the normal resumed request path. 
+ unscheduled_req_ids = cached_req_ids - (scheduled_req_ids - resumed_req_ids) # NOTE(woosuk): The persistent batch optimization assumes that # consecutive batches contain mostly the same requests. If batches # have low request overlap (e.g., alternating between two distinct @@ -601,8 +833,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: sampling_params = new_req_data.sampling_params pooling_params = new_req_data.pooling_params - if sampling_params and \ - sampling_params.sampling_type == SamplingType.RANDOM_SEED: + if ( + sampling_params + and sampling_params.sampling_type == SamplingType.RANDOM_SEED + ): generator = torch.Generator(device=self.device) generator.manual_seed(sampling_params.seed) else: @@ -632,20 +866,62 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: ) self.requests[req_id] = req_state + if sampling_params and sampling_params.prompt_logprobs is not None: + self.num_prompt_logprobs[req_id] = ( + self.input_batch.vocab_size + if sampling_params.prompt_logprobs == -1 + else sampling_params.prompt_logprobs + ) + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: self._init_mrope_positions(req_state) + # Only relevant for models using XD-RoPE (e.g, HunYuan-VL) + if self.uses_xdrope_dim > 0: + self._init_xdrope_positions(req_state) + reqs_to_add.append(req_state) # Update the states of the running/resumed requests. is_last_rank = get_pp_group().is_last_rank req_data = scheduler_output.scheduled_cached_reqs + + # Wait until valid_sampled_tokens_count is copied to cpu, + # then use it to update actual num_computed_tokens of each request. 
+ valid_sampled_token_count = self._get_valid_sampled_token_count() + for i, req_id in enumerate(req_data.req_ids): req_state = self.requests[req_id] num_computed_tokens = req_data.num_computed_tokens[i] new_block_ids = req_data.new_block_ids[i] - resumed_from_preemption = req_data.resumed_from_preemption[i] + resumed_from_preemption = req_id in req_data.resumed_req_ids + num_output_tokens = req_data.num_output_tokens[i] + req_index = self.input_batch.req_id_to_index.get(req_id) + + # prev_num_draft_len is used in async scheduling mode with + # spec decode. it indicates if need to update num_computed_tokens + # of the request. for example: + # fist step: num_computed_tokens = 0, spec_tokens = [], + # prev_num_draft_len = 0. + # second step: num_computed_tokens = 100(prompt lenth), + # spec_tokens = [a,b], prev_num_draft_len = 0. + # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d], + # prev_num_draft_len = 2. + # num_computed_tokens in first step and second step does't contain + # the spec tokens length, but in third step it contains the + # spec tokens length. we only need to update num_computed_tokens + # when prev_num_draft_len > 0. + if req_state.prev_num_draft_len: + if req_index is None: + req_state.prev_num_draft_len = 0 + else: + assert self.input_batch.prev_req_id_to_index is not None + prev_req_index = self.input_batch.prev_req_id_to_index[req_id] + num_accepted = valid_sampled_token_count[prev_req_index] - 1 + num_rejected = req_state.prev_num_draft_len - num_accepted + num_computed_tokens -= num_rejected + req_state.output_token_ids.extend([-1] * num_accepted) # Update the cached states. req_state.num_computed_tokens = num_computed_tokens @@ -657,42 +933,57 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: new_token_ids = req_data.new_token_ids[i] # Add the sampled token(s) from the previous step (if any). # This doesn't include "unverified" tokens like spec tokens. 
- num_new_tokens = (num_computed_tokens + len(new_token_ids) - - req_state.num_tokens) + num_new_tokens = ( + num_computed_tokens + len(new_token_ids) - req_state.num_tokens + ) if num_new_tokens == 1: # Avoid slicing list in most common case. req_state.output_token_ids.append(new_token_ids[-1]) elif num_new_tokens > 0: - req_state.output_token_ids.extend( - new_token_ids[-num_new_tokens:]) + req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:]) + elif num_output_tokens < len(req_state.output_token_ids): + # Some output tokens were discarded due to a sync-KV-load + # failure. Align the cached state. + del req_state.output_token_ids[num_output_tokens:] + if req_index is not None: + end_idx = ( + self.input_batch.num_prompt_tokens[req_index] + + num_output_tokens + ) + self.input_batch.num_tokens[req_index] = end_idx + self.input_batch.num_tokens_no_spec[req_index] = end_idx # Update the block IDs. if not resumed_from_preemption: if new_block_ids is not None: # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, - new_block_ids): + for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): block_ids.extend(new_ids) else: + assert req_index is None assert new_block_ids is not None # The request is resumed from preemption. # Replace the existing block IDs with the new ones. req_state.block_ids = new_block_ids - req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is None: # The request is not in the persistent batch. # The request was either preempted and resumed later, or was not # scheduled in the previous step and needs to be added again. + + if self.use_async_scheduling and num_output_tokens > 0: + # We must recover the output token ids for resumed requests in the + # async scheduling case, so that correct input_ids are obtained. 
+ resumed_token_ids = req_data.all_token_ids[req_id] + req_state.output_token_ids = resumed_token_ids[-num_output_tokens:] + reqs_to_add.append(req_state) continue # Update the persistent batch. - self.input_batch.num_computed_tokens_cpu[req_index] = ( - num_computed_tokens) + self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens if new_block_ids is not None: - self.input_batch.block_table.append_row( - new_block_ids, req_index) + self.input_batch.block_table.append_row(new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu # because the sampled tokens are already cached. @@ -701,24 +992,45 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: start_token_index = num_computed_tokens end_token_index = num_computed_tokens + len(new_token_ids) self.input_batch.token_ids_cpu[ - req_index, - start_token_index:end_token_index] = new_token_ids - self.input_batch.num_tokens_no_spec[ - req_index] = end_token_index + req_index, start_token_index:end_token_index + ] = new_token_ids + self.input_batch.num_tokens_no_spec[req_index] = end_token_index self.input_batch.num_tokens[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. - spec_token_ids = ( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) - if spec_token_ids: - num_spec_tokens = len(spec_token_ids) + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, [] + ) + num_spec_tokens = len(spec_token_ids) + # For async scheduling, token_ids_cpu assigned from + # spec_token_ids are placeholders and will be overwritten in + # _prepare_input_ids. + if num_spec_tokens: start_index = self.input_batch.num_tokens_no_spec[req_index] end_token_index = start_index + num_spec_tokens self.input_batch.token_ids_cpu[ - req_index, start_index:end_token_index] = spec_token_ids + req_index, start_index:end_token_index + ] = spec_token_ids # NOTE(woosuk): `num_tokens` here may include spec tokens. 
self.input_batch.num_tokens[req_index] += num_spec_tokens + # When speculative decoding is used with structured output, + # the scheduler can drop draft tokens that do not + # conform to the schema. This can result in + # scheduler_output.scheduled_spec_decode_tokens being empty, + # even when speculative decoding is enabled. + self.input_batch.spec_token_ids[req_index].clear() + self.input_batch.spec_token_ids[req_index].extend(spec_token_ids) + + # there are no draft tokens with async scheduling, + # we clear the spec_decoding info in scheduler_output and + # use normal sampling but rejection_sampling. + if self.use_async_scheduling: + req_state.prev_num_draft_len = num_spec_tokens + if num_spec_tokens and self._draft_token_ids is None: + scheduler_output.total_num_scheduled_tokens -= num_spec_tokens + scheduler_output.num_scheduled_tokens[req_id] -= num_spec_tokens + scheduler_output.scheduled_spec_decode_tokens.pop(req_id, None) # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. for request in reqs_to_add: @@ -732,7 +1044,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.refresh_metadata() def _update_states_after_model_execute( - self, output_token_ids: torch.Tensor) -> None: + self, output_token_ids: torch.Tensor + ) -> None: """Update the cached states after model execution. This is used for MTP/EAGLE for hybrid models, as in linear attention, @@ -745,61 +1058,56 @@ def _update_states_after_model_execute( return # Find the number of accepted tokens for each sequence. 
- num_accepted_tokens = (torch.cat( - [ - output_token_ids, - torch.full((output_token_ids.size(0), 1), - -1, - device=output_token_ids.device), - ], - dim=1) == -1).int().argmax(-1).cpu().numpy() + num_accepted_tokens = ( + ( + torch.cat( + [ + output_token_ids, + torch.full( + (output_token_ids.size(0), 1), + -1, + device=output_token_ids.device, + ), + ], + dim=1, + ) + == -1 + ) + .int() + .argmax(-1) + .cpu() + .numpy() + ) for i, num_tokens in enumerate(num_accepted_tokens): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens def _init_mrope_positions(self, req_state: CachedRequestState): - image_grid_thw = [] - video_grid_thw = [] - second_per_grid_ts = [] - audio_feature_lengths = [] - use_audio_in_video = False - for mm_feature in req_state.mm_features: - mm_item = mm_feature.data - if mm_item is None: - continue - mm_input = mm_item.get_data() - if (t := mm_input.get("image_grid_thw")) is not None: - image_grid_thw.append(t.tolist()) - if (t := mm_input.get("video_grid_thw")) is not None: - video_grid_thw.append(t.tolist()) - if (t := mm_input.get("second_per_grid_ts")) is not None: - second_per_grid_ts.append(t) - if (t := mm_input.get("audio_feature_lengths")) is not None: - audio_feature_lengths.append(t) - if mm_input.get("use_audio_in_video") is True: - use_audio_in_video = True - - if supports_mrope(self.model): - req_state.mrope_positions, req_state.mrope_position_delta = \ - self.model.get_mrope_input_positions( - req_state.prompt_token_ids, - hf_config=self.model_config.hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - else: - req_state.mrope_positions, req_state.mrope_position_delta = \ - MRotaryEmbedding.get_input_positions_tensor( - req_state.prompt_token_ids, - hf_config=self.model_config.hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - 
second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) + model = self.get_model() + assert supports_mrope(model), "M-RoPE support is not implemented." + assert req_state.prompt_token_ids is not None, ( + "M-RoPE requires prompt_token_ids to be available." + ) + mrope_model = cast(SupportsMRoPE, model) + + req_state.mrope_positions, req_state.mrope_position_delta = ( + mrope_model.get_mrope_input_positions( + req_state.prompt_token_ids, + req_state.mm_features, + ) + ) + + def _init_xdrope_positions(self, req_state: CachedRequestState): + model = self.get_model() + xdrope_model = cast(SupportsXDRoPE, model) + assert req_state.prompt_token_ids is not None, ( + "XD-RoPE requires prompt_token_ids to be available." + ) + assert supports_xdrope(model), "XD-RoPE support is not implemented." + + req_state.xdrope_positions = xdrope_model.get_xdrope_input_positions( + req_state.prompt_token_ids, + req_state.mm_features, + ) def _extract_mm_kwargs( self, @@ -818,10 +1126,10 @@ def _extract_mm_kwargs( model = cast(SupportsMultiModal, self.model) mm_kwargs_combined: BatchedTensorInputs = {} for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, ): mm_kwargs_combined.update(mm_kwargs_group) @@ -840,7 +1148,7 @@ def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: def _get_cumsum_and_arange( self, num_tokens: np.ndarray, - cumsum_dtype: Optional[np.dtype] = None, + cumsum_dtype: np.dtype | None = None, ) -> tuple[np.ndarray, np.ndarray]: """Get the cumulative sum and batched arange of the given array. 
# E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) @@ -857,10 +1165,14 @@ def _get_cumsum_and_arange( return cu_num_tokens, arange - def _prepare_input_ids(self, total_num_scheduled_tokens: int, - cu_num_tokens: np.ndarray) -> None: + def _prepare_input_ids( + self, + scheduler_output: "SchedulerOutput", + total_num_scheduled_tokens: int, + cu_num_tokens: np.ndarray, + ) -> None: """Prepare the input IDs for the current batch. - + Carefully handles the `prev_sampled_token_ids` which can be cached from the previous engine iteration, in which case those tokens on the GPU need to be copied into the corresponding slots into input_ids.""" @@ -878,21 +1190,43 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, # on the GPU from prev_sampled_token_ids. prev_req_id_to_index = self.input_batch.prev_req_id_to_index assert prev_req_id_to_index is not None - flattened_indices = [] - prev_common_req_indices = [] + sample_flattened_indices: list[int] = [] + spec_flattened_indices: list[int] = [] + prev_common_req_indices: list[int] = [] + prev_draft_token_indices: list[int] = [] indices_match = True max_flattened_index = -1 + total_num_spec_tokens = 0 + scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens + for req_id, cur_index in self.input_batch.req_id_to_index.items(): if (prev_index := prev_req_id_to_index.get(req_id)) is not None: prev_common_req_indices.append(prev_index) # We need to compute the flattened input_ids index of the # last token in each common request. 
+ draft_len = len(scheduled_spec_tokens.get(req_id, ())) + total_num_spec_tokens += draft_len flattened_index = cu_num_tokens[cur_index].item() - 1 - flattened_indices.append(flattened_index) - indices_match &= (prev_index == flattened_index) + # example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2] + # sample_flattened_indices = [0, 2, 5] + # spec_flattened_indices = [1, 3, 4, 6, 7] + sample_flattened_indices.append(flattened_index - draft_len) + spec_flattened_indices.extend( + range(flattened_index - draft_len + 1, flattened_index + 1) + ) + start = prev_index * self.num_spec_tokens + # prev_draft_token_indices is used to find which draft_tokens_id + # should be copied to input_ids + # example: prev draft_tokens_id [[1,2], [3,4], [5, 6]] + # flatten draft_tokens_id [1,2,3,4,5,6] + # draft_len of each request [1, 2, 1] + # then prev_draft_token_indices is [0, 2, 3, 4] + prev_draft_token_indices.extend(range(start, start + draft_len)) + indices_match &= prev_index == flattened_index max_flattened_index = max(max_flattened_index, flattened_index) - num_commmon_tokens = len(flattened_indices) - if num_commmon_tokens < total_num_scheduled_tokens: + num_commmon_tokens = len(sample_flattened_indices) + total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens + if num_commmon_tokens < total_without_spec: # If not all requests are decodes from the last iteration, # We need to copy the input_ids_cpu to the GPU first. self.input_ids.copy_to_gpu(total_num_scheduled_tokens) @@ -901,7 +1235,7 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens) if num_commmon_tokens == 0: # No requests in common with the previous iteration - # So input_ids_cpu will have all the input ids. + # So input_ids.cpu will have all the input ids. 
return if indices_match and max_flattened_index == (num_commmon_tokens - 1): # Common-case optimization: the batch is unchanged @@ -909,57 +1243,95 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, # The indices are both the same permutation of 0..N-1 so # we can copy directly using a single slice. self.input_ids.gpu[:num_commmon_tokens].copy_( - self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, - 0], - non_blocking=True) + self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0], + non_blocking=True, + ) if self.enable_prompt_embeds: self.is_token_ids.gpu[:num_commmon_tokens] = True return - # Upload the index tensors asynchronously - # so the scatter can be non-blocking. - input_ids_index_tensor = torch.tensor(flattened_indices, - dtype=torch.int64, - pin_memory=self.pin_memory).to( - self.device, - non_blocking=True) + # Upload the index tensors asynchronously so the scatter can be non-blocking. + sampled_tokens_index_tensor = torch.tensor( + sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) prev_common_req_indices_tensor = torch.tensor( - prev_common_req_indices, - dtype=torch.int64, - pin_memory=self.pin_memory).to(self.device, non_blocking=True) + prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) self.input_ids.gpu.scatter_( dim=0, - index=input_ids_index_tensor, + index=sampled_tokens_index_tensor, src=self.input_batch.prev_sampled_token_ids[ - prev_common_req_indices_tensor, 0]) + prev_common_req_indices_tensor, 0 + ], + ) + + # Scatter the draft tokens after the sampled tokens are scattered. 
+ if self._draft_token_ids is None or not spec_flattened_indices: + return + + assert isinstance(self._draft_token_ids, torch.Tensor) + draft_tokens_index_tensor = torch.tensor( + spec_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) + prev_draft_token_indices_tensor = torch.tensor( + prev_draft_token_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) + + # because input_ids dtype is torch.int32, + # so convert draft_token_ids to torch.int32 here. + draft_token_ids = self._draft_token_ids.to(dtype=torch.int32) + self._draft_token_ids = None + + self.input_ids.gpu.scatter_( + dim=0, + index=draft_tokens_index_tensor, + src=draft_token_ids.flatten()[prev_draft_token_indices_tensor], + ) def _get_encoder_seq_lens( self, - scheduler_output: "SchedulerOutput", + num_scheduled_tokens: dict[str, int], kv_cache_spec: KVCacheSpec, num_reqs: int, - ) -> Optional[np.ndarray]: + ) -> tuple[torch.Tensor | None, np.ndarray | None]: if not isinstance(kv_cache_spec, CrossAttentionSpec): - return None + return None, None + # Zero out buffer for padding requests that are not actually scheduled (CGs) + self.encoder_seq_lens.np[:num_reqs] = 0 # Build encoder_seq_lens array mapping request indices to # encoder lengths for inputs scheduled in this batch - encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32) - for req_id in scheduler_output.scheduled_encoder_inputs: + for req_id in num_scheduled_tokens: req_index = self.input_batch.req_id_to_index[req_id] - encoder_seq_lens[req_index] = self.max_encoder_len + req_state = self.requests[req_id] + if req_state.mm_features is None: + self.encoder_seq_lens.np[req_index] = 0 + continue + + # Get the total number of encoder input tokens for running encoder requests + # whether encoding is finished or not so that cross-attention knows how + # many encoder tokens to attend to. 
+ encoder_input_tokens = sum( + feature.mm_position.length for feature in req_state.mm_features + ) + self.encoder_seq_lens.np[req_index] = encoder_input_tokens - return encoder_seq_lens + self.encoder_seq_lens.copy_to_gpu(num_reqs) + encoder_seq_lens = self.encoder_seq_lens.gpu[:num_reqs] + encoder_seq_lens_cpu = self.encoder_seq_lens.np[:num_reqs] + + return encoder_seq_lens, encoder_seq_lens_cpu def _prepare_inputs( - self, scheduler_output: "SchedulerOutput" - ) -> tuple[PerLayerAttnMetadata, torch.Tensor, - Optional[SpecDecodeMetadata], np.ndarray, - Optional[CommonAttentionMetadata], int, Optional[UBatchSlices], - Optional[torch.Tensor]]: + self, + scheduler_output: "SchedulerOutput", + num_scheduled_tokens: np.ndarray, + ) -> tuple[ + torch.Tensor, + SpecDecodeMetadata | None, + ]: """ :return: tuple[ - attn_metadata: layer-to-attention_metadata mapping, - logits_indices, spec_decode_metadata + logits_indices, spec_decode_metadata, ] """ total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens @@ -971,55 +1343,58 @@ def _prepare_inputs( # This way, we can overlap the copy with the following CPU operations. self.input_batch.block_table.commit_block_table(num_reqs) - # Get the number of scheduled tokens for each request. - req_ids = self.input_batch.req_ids - tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - num_scheduled_tokens = np.array(tokens, dtype=np.int32) - max_num_scheduled_tokens = max(tokens) - # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens) + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - cu_num_tokens, arange = self._get_cumsum_and_arange( - num_scheduled_tokens) + cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) # Get positions. 
positions_np = self.positions.np[:total_num_scheduled_tokens] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np, + ) # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: self._calc_mrope_positions(scheduler_output) + # Calculate XD-RoPE positions. + # Only relevant for models using XD-RoPE (e.g, HunYuan-VL) + if self.uses_xdrope_dim > 0: + self._calc_xdrope_positions(scheduler_output) + # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] # where M is the max_model_len. - token_indices = (positions_np + - req_indices * self.input_batch.token_ids_cpu.shape[1]) + token_indices = ( + positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] + ) token_indices_tensor = torch.from_numpy(token_indices) # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. 
- torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - token_indices_tensor, - out=self.input_ids.cpu[:total_num_scheduled_tokens]) + torch.index_select( + self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + token_indices_tensor, + out=self.input_ids.cpu[:total_num_scheduled_tokens], + ) if self.enable_prompt_embeds: - is_token_ids = self.input_batch.is_token_ids.flatten() + is_token_ids = self.input_batch.is_token_ids_tensor.flatten() torch.index_select( is_token_ids, 0, token_indices_tensor, - out=self.is_token_ids.cpu[:total_num_scheduled_tokens]) + out=self.is_token_ids.cpu[:total_num_scheduled_tokens], + ) # Because we did not pre-allocate a massive prompt_embeds CPU tensor on # the InputBatch, we need to fill in the prompt embeds into the expected @@ -1053,78 +1428,66 @@ def _prepare_inputs( actual_num_sched = actual_end - start_pos if actual_num_sched > 0: - self.inputs_embeds.cpu[output_idx:output_idx + - actual_num_sched].copy_( - req_embeds[start_pos:actual_end] - ) + self.inputs_embeds.cpu[ + output_idx : output_idx + actual_num_sched + ].copy_(req_embeds[start_pos:actual_end]) output_idx += num_sched - self.input_batch.block_table.compute_slot_mapping( - req_indices, positions_np) - self.input_batch.block_table.commit_slot_mapping( - total_num_scheduled_tokens) + self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) # Prepare the attention metadata. 
self.query_start_loc.np[0] = 0 - self.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens + self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens # Note: pad query_start_loc to be non-decreasing, as kernels # like FlashAttention requires that - self.query_start_loc.np[num_reqs + 1:].fill(cu_num_tokens[-1]) + self.query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1]) self.query_start_loc.copy_to_gpu() - query_start_loc = self.query_start_loc.gpu[:num_reqs + 1] - - num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens - num_tokens_padded = num_tokens_unpadded + self.get_local_padding( - num_tokens_unpadded) - uniform_decode = \ - (max_num_scheduled_tokens == self.uniform_decode_query_len) and \ - (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) - ubatch_slices, num_tokens_after_padding = \ - ubatch_split(num_scheduled_tokens, - num_tokens_unpadded, - num_tokens_padded, - uniform_decode=uniform_decode, - vllm_config=self.vllm_config) + query_start_loc = self.query_start_loc.gpu[: num_reqs + 1] self.seq_lens.np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens) + self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens + ) # Fill unused with 0 for full cuda graph mode. 
self.seq_lens.np[num_reqs:].fill(0) self.seq_lens.copy_to_gpu() - seq_lens = self.seq_lens.gpu[:num_reqs] - max_seq_len = self.seq_lens.np[:num_reqs].max().item() - num_tokens = [ - self.requests[r].num_tokens for r in self.input_batch.req_ids - ] + + num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids] num_tokens_np = np.array(num_tokens, dtype=np.int32) - # Record the index of requests that should not be sampled, + # Record which requests should not be sampled, # so that we could clear the sampled tokens before returning - discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np - discard_request_indices = np.nonzero(discard_requests_mask)[0] - self.num_discarded_requests = len(discard_request_indices) - self.discard_request_indices.np[:self.num_discarded_requests] = ( - discard_request_indices) - - self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) + self.discard_request_mask.np[:num_reqs] = ( + self.seq_lens.np[:num_reqs] < num_tokens_np + ) + self.discard_request_mask.copy_to_gpu(num_reqs) # Copy the tensors to the GPU. 
- self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens) + self._prepare_input_ids( + scheduler_output, + total_num_scheduled_tokens, + cu_num_tokens, + ) if self.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( self.mrope_positions.cpu[:, :total_num_scheduled_tokens], - non_blocking=True) + non_blocking=True, + ) + elif self.uses_xdrope_dim > 0: + # Only relevant for models using XD-RoPE (e.g, HunYuan-VL) + self.xdrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( + self.xdrope_positions.cpu[:, :total_num_scheduled_tokens], + non_blocking=True, + ) else: # Common case (1D positions) self.positions.copy_to_gpu(total_num_scheduled_tokens) - use_spec_decode = len( - scheduler_output.scheduled_spec_decode_tokens) > 0 + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: # NOTE(woosuk): Due to chunked prefills, the batch may contain # partial requests. While we should not sample any token @@ -1134,6 +1497,7 @@ def _prepare_inputs( logits_indices = query_start_loc[1:] - 1 num_draft_tokens = None spec_decode_metadata = None + num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) else: # Get the number of draft tokens for each request. # Iterate over the dictionary rather than all requests since not all @@ -1142,163 +1506,307 @@ def _prepare_inputs( # For chunked prefills, use -1 as mask rather than 0, as guided # decoding may rollback speculative tokens. 
num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32) - for req_id, draft_token_ids in ( - scheduler_output.scheduled_spec_decode_tokens.items()): + for ( + req_id, + draft_token_ids, + ) in scheduler_output.scheduled_spec_decode_tokens.items(): req_idx = self.input_batch.req_id_to_index[req_id] num_draft_tokens[req_idx] = len(draft_token_ids) - num_decode_draft_tokens[req_idx] = (len(draft_token_ids) if ( - self.input_batch.num_computed_tokens_cpu[req_idx] - >= self.input_batch.num_prompt_tokens[req_idx]) else -1) + num_decode_draft_tokens[req_idx] = ( + len(draft_token_ids) + if ( + self.input_batch.num_computed_tokens_cpu[req_idx] + >= self.input_batch.num_prompt_tokens[req_idx] + ) + else -1 + ) spec_decode_metadata = self._calc_spec_decode_metadata( - num_draft_tokens, cu_num_tokens) + num_draft_tokens, cu_num_tokens + ) logits_indices = spec_decode_metadata.logits_indices - + num_sampled_tokens = num_draft_tokens + 1 # For DECODE only cuda graph of some attention backends (e.g., GDN). 
- self.num_decode_draft_tokens.np[: - num_reqs] = num_decode_draft_tokens + self.num_decode_draft_tokens.np[:num_reqs] = num_decode_draft_tokens self.num_decode_draft_tokens.np[num_reqs:].fill(-1) self.num_decode_draft_tokens.copy_to_gpu() - logits_indices_padded = None - if self.cache_config.kv_sharing_fast_prefill: - logits_indices_padded = self._prepare_kv_sharing_fast_prefill( - logits_indices) + # Hot-Swap lora model + if self.lora_config: + assert ( + np.sum(num_sampled_tokens) + <= self.vllm_config.scheduler_config.max_num_batched_tokens + ) + self.set_active_loras( + self.input_batch, num_scheduled_tokens, num_sampled_tokens + ) + + return ( + logits_indices, + spec_decode_metadata, + ) + + def _build_attention_metadata( + self, + num_tokens: int, + num_reqs: int, + max_query_len: int, + num_tokens_padded: int | None = None, + num_reqs_padded: int | None = None, + ubatch_slices: UBatchSlices | None = None, + logits_indices: torch.Tensor | None = None, + use_spec_decode: bool = False, + for_cudagraph_capture: bool = False, + num_scheduled_tokens: dict[str, int] | None = None, + cascade_attn_prefix_lens: list[list[int]] | None = None, + ) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]: + """ + :return: tuple[attn_metadata, spec_decode_common_attn_metadata] + """ + # Attention metadata is not needed for attention free models + if len(self.kv_cache_config.kv_cache_groups) == 0: + return {}, None + + num_tokens_padded = num_tokens_padded or num_tokens + num_reqs_padded = num_reqs_padded or num_reqs + assert num_reqs_padded is not None and num_tokens_padded is not None attn_metadata: PerLayerAttnMetadata = {} if ubatch_slices is not None: attn_metadata = [dict() for _ in range(len(ubatch_slices))] - # Used in the below loop. 
- query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] - seq_lens_cpu = self.seq_lens.cpu[:num_reqs] - num_computed_tokens_cpu = ( - self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) - spec_decode_common_attn_metadata = None + if for_cudagraph_capture: + # For some attention backends (e.g. FA) with sliding window models we need + # to make sure the backend see a max_seq_len that is larger to the sliding + # window size when capturing to make sure the correct kernel is selected. + max_seq_len = self.max_model_len + else: + max_seq_len = self.seq_lens.np[:num_reqs].max().item() + if use_spec_decode: self.num_accepted_tokens.np[:num_reqs] = ( - self.input_batch.num_accepted_tokens_cpu[:num_reqs]) + self.input_batch.num_accepted_tokens_cpu[:num_reqs] + ) self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.copy_to_gpu() - # Prepare the attention metadata for each KV cache group and make layers - # in the same group share the same metadata. - for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): - encoder_seq_lens = self._get_encoder_seq_lens( - scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs) - - if isinstance(kv_cache_group_spec.kv_cache_spec, - EncoderOnlyAttentionSpec): - # Encoder-only layers do not have KV cache, so we need to - # create a dummy block table and slot mapping for them. 
+ kv_cache_groups = self.kv_cache_config.kv_cache_groups + + def _get_block_table_and_slot_mapping(kv_cache_gid: int): + assert num_reqs_padded is not None and num_tokens_padded is not None + kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec + if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): blk_table_tensor = torch.zeros( - (num_reqs, 1), + (num_reqs_padded, 1), dtype=torch.int32, device=self.device, ) slot_mapping = torch.zeros( - (total_num_scheduled_tokens, ), + (num_tokens_padded,), dtype=torch.int64, device=self.device, ) - num_common_prefix_blocks = 0 else: - blk_table = self.input_batch.block_table[kv_cache_group_id] - blk_table_tensor = blk_table.get_device_tensor(num_reqs) - slot_mapping = blk_table.slot_mapping.gpu[: - total_num_scheduled_tokens] - - # Fill unused with -1. Needed for reshape_and_cache in full cuda - # graph mode. - blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_( - -1) - num_common_prefix_blocks = ( - scheduler_output. - num_common_prefix_blocks[kv_cache_group_id]) - - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=query_start_loc, - query_start_loc_cpu=query_start_loc_cpu, - seq_lens=seq_lens, - seq_lens_cpu=seq_lens_cpu, - num_computed_tokens_cpu=num_computed_tokens_cpu, - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - max_seq_len=max_seq_len, - block_table_tensor=blk_table_tensor, - slot_mapping=slot_mapping, - logits_indices_padded=logits_indices_padded, - num_logits_indices=logits_indices.size(0), - causal=True, - encoder_seq_lens=encoder_seq_lens, + blk_table = self.input_batch.block_table[kv_cache_gid] + blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded) + slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded] + + # Fill unused with -1. Needed for reshape_and_cache in full cuda + # graph mode. 
`blk_table_tensor` -1 to match mamba PAD_SLOT_ID + slot_mapping[num_tokens:num_tokens_padded].fill_(-1) + blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1) + + return blk_table_tensor, slot_mapping + + block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0) + cm_base = CommonAttentionMetadata( + query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], + query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], + seq_lens=self.seq_lens.gpu[:num_reqs_padded], + _seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded], + _num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs_padded + ], + num_reqs=num_reqs_padded, + num_actual_tokens=num_tokens_padded, + max_query_len=max_query_len, + max_seq_len=max_seq_len, + block_table_tensor=block_table_gid_0, + slot_mapping=slot_mapping_gid_0, + causal=True, + ) + + if self.dcp_world_size > 1: + self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens( + self.seq_lens.cpu[:num_reqs], + self.dcp_world_size, + self.dcp_rank, + self.parallel_config.cp_kv_cache_interleave_size, + ) + self.dcp_local_seq_lens.cpu[num_reqs:].fill_(0) + self.dcp_local_seq_lens.copy_to_gpu(num_reqs_padded) + + cm_base.dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs_padded] + cm_base.dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[ + :num_reqs_padded + ] + + if logits_indices is not None and self.cache_config.kv_sharing_fast_prefill: + cm_base.num_logits_indices = logits_indices.size(0) + cm_base.logits_indices_padded = self._prepare_kv_sharing_fast_prefill( + logits_indices + ) + + def _build_attn_group_metadata( + kv_cache_gid: int, + attn_gid: int, + common_attn_metadata: CommonAttentionMetadata, + ubid: int | None = None, + ) -> None: + attn_group = self.attn_groups[kv_cache_gid][attn_gid] + cascade_attn_prefix_len = ( + cascade_attn_prefix_lens[kv_cache_gid][attn_gid] + if cascade_attn_prefix_lens + else 0 + ) + + builder = attn_group.get_metadata_builder(ubid 
or 0) + extra_attn_metadata_args = {} + if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): + assert ubid is None, "UBatching not supported with GDN yet" + extra_attn_metadata_args = dict( + num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs_padded], + num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[ + :num_reqs_padded + ], + ) + + if for_cudagraph_capture: + attn_metadata_i = builder.build_for_cudagraph_capture( + common_attn_metadata + ) + else: + attn_metadata_i = builder.build( + common_prefix_len=cascade_attn_prefix_len, + common_attn_metadata=common_attn_metadata, + **extra_attn_metadata_args, + ) + + if ubid is None: + assert isinstance(attn_metadata, dict) + attn_metadata_dict = attn_metadata + else: + assert isinstance(attn_metadata, list) + attn_metadata_dict = attn_metadata[ubid] + + for layer_name in attn_group.layer_names: + attn_metadata_dict[layer_name] = attn_metadata_i + + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + spec_decode_common_attn_metadata = None + for kv_cache_gid, kv_cache_group in enumerate(kv_cache_groups): + cm = copy(cm_base) # shallow copy + + # Basically only the encoder seq_lens, block_table and slot_mapping change + # for each kv_cache_group. 
+ cm.encoder_seq_lens, cm.encoder_seq_lens_cpu = self._get_encoder_seq_lens( + num_scheduled_tokens or {}, + kv_cache_group.kv_cache_spec, + num_reqs_padded, ) + if kv_cache_gid > 0: + cm.block_table_tensor, cm.slot_mapping = ( + _get_block_table_and_slot_mapping(kv_cache_gid) + ) - if (self.speculative_config - and spec_decode_common_attn_metadata is None): + if self.speculative_config and spec_decode_common_attn_metadata is None: if isinstance(self.drafter, EagleProposer): - if (self.drafter.attn_layer_names[0] - in kv_cache_group_spec.layer_names): - spec_decode_common_attn_metadata = common_attn_metadata + if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: + spec_decode_common_attn_metadata = cm else: - spec_decode_common_attn_metadata = common_attn_metadata - - for attn_group in self.attn_groups[kv_cache_group_id]: - # Prepare for cascade attention if enabled & beneficial. - common_prefix_len = 0 - builder = attn_group.get_metadata_builder() - #### TODO(lms): support common prefix but needs know num of sms - # if self.cascade_attn_enabled: - # common_prefix_len = self._compute_cascade_attn_prefix_len( - # num_scheduled_tokens, - # num_common_prefix_blocks, - # attn_group.kv_cache_spec, - # builder, - # ) - - extra_attn_metadata_args = {} - if use_spec_decode and isinstance(builder, - GDNAttentionMetadataBuilder): - extra_attn_metadata_args = dict( - num_accepted_tokens=self.num_accepted_tokens. - gpu[:num_reqs], - num_decode_draft_tokens_cpu=self. 
- num_decode_draft_tokens.cpu[:num_reqs], - ) + spec_decode_common_attn_metadata = cm + for attn_gid in range(len(self.attn_groups[kv_cache_gid])): if ubatch_slices is not None: - common_attn_metadata_list = split_attn_metadata( - ubatch_slices, common_attn_metadata) - for ubid, common_attn_metadata in enumerate( - common_attn_metadata_list): - attn_metadata_i = (attn_group.get_metadata_builder( - ubatch_id=ubid).build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata)) - for layer_name in kv_cache_group_spec.layer_names: - assert type(attn_metadata) is list - attn_metadata[ubid][layer_name] = attn_metadata_i + for ubid, _cm in enumerate(split_attn_metadata(ubatch_slices, cm)): + _build_attn_group_metadata(kv_cache_gid, attn_gid, _cm, ubid) + else: - assert isinstance(attn_metadata, dict) - attn_metadata_i = builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - **extra_attn_metadata_args) - for layer_name in attn_group.layer_names: - attn_metadata[layer_name] = attn_metadata_i + _build_attn_group_metadata(kv_cache_gid, attn_gid, cm) - # Hot-Swap lora model - if self.lora_config: - self.set_active_loras(self.input_batch, num_scheduled_tokens) + if self.is_mm_prefix_lm: + req_doc_ranges = {} + for req_id in self.input_batch.req_ids: + image_doc_ranges = [] + req_state = self.requests[req_id] + for mm_feature in req_state.mm_features: + pos_info = mm_feature.mm_position + img_doc_range = pos_info.extract_embeds_range() + image_doc_ranges.extend(img_doc_range) + req_idx = self.input_batch.req_id_to_index[req_id] + req_doc_ranges[req_idx] = image_doc_ranges + + if isinstance(attn_metadata, list): + for ub_metadata in attn_metadata: + for _metadata in ub_metadata.values(): + _metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined] + else: + for _metadata in attn_metadata.values(): + _metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined] + + if 
spec_decode_common_attn_metadata is not None and ( + num_reqs != num_reqs_padded or num_tokens != num_tokens_padded + ): + # Currently the drafter still only uses piecewise cudagraphs (and modifies + # the attention metadata in directly), and therefore does not want to use + # padded attention metadata. + spec_decode_common_attn_metadata = ( + spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs) + ) + + return attn_metadata, spec_decode_common_attn_metadata + + def _compute_cascade_attn_prefix_lens( + self, + num_scheduled_tokens: np.ndarray, + num_computed_tokens: np.ndarray, + num_common_prefix_blocks: list[int], + ) -> list[list[int]] | None: + """ + :return: Optional[cascade_attn_prefix_lens] + cascade_attn_prefix_lens is 2D: ``[kv_cache_group_id][attn_group_idx]``, + None if we should not use cascade attention + """ + + use_cascade_attn = False + num_kv_cache_groups = len(self.kv_cache_config.kv_cache_groups) + cascade_attn_prefix_lens: list[list[int]] = [ + [] for _ in range(num_kv_cache_groups) + ] + + for kv_cache_gid in range(num_kv_cache_groups): + for attn_group in self.attn_groups[kv_cache_gid]: + if isinstance(attn_group.kv_cache_spec, EncoderOnlyAttentionSpec): + cascade_attn_prefix_len = 0 + else: + # 0 if cascade attention should not be used + cascade_attn_prefix_len = self._compute_cascade_attn_prefix_len( + num_scheduled_tokens, + num_computed_tokens, + num_common_prefix_blocks[kv_cache_gid], + attn_group.kv_cache_spec, + attn_group.get_metadata_builder(), + ) + cascade_attn_prefix_lens[kv_cache_gid].append(cascade_attn_prefix_len) + use_cascade_attn |= cascade_attn_prefix_len > 0 + + return cascade_attn_prefix_lens if use_cascade_attn else None - return (attn_metadata, logits_indices, spec_decode_metadata, - num_scheduled_tokens, spec_decode_common_attn_metadata, - max_num_scheduled_tokens, ubatch_slices, - num_tokens_after_padding) def _compute_cascade_attn_prefix_len( self, num_scheduled_tokens: np.ndarray, + num_computed_tokens: 
np.ndarray, num_common_prefix_blocks: int, kv_cache_spec: KVCacheSpec, attn_metadata_builder: AttentionMetadataBuilder, @@ -1364,21 +1872,22 @@ def _compute_cascade_attn_prefix_len( # and the second kernel will get an empty input. While this is not # a fundamental problem, our current implementation does not support # this case. - num_reqs = len(num_scheduled_tokens) - common_prefix_len = min( - common_prefix_len, - self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) + common_prefix_len = min(common_prefix_len, num_computed_tokens.min()) # common_prefix_len should be a multiple of the block size. - common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * - kv_cache_spec.block_size) - use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or - (isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.sliding_window is not None)) - use_local_attention = ( - isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) - or (isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.attention_chunk_size is not None)) + common_prefix_len = ( + common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size + ) + use_sliding_window = isinstance(kv_cache_spec, SlidingWindowSpec) or ( + isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.sliding_window is not None + ) + use_local_attention = isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) or ( + isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.attention_chunk_size is not None + ) assert isinstance(kv_cache_spec, AttentionSpec) + + ###TODO(lms): fix use of num_sms use_cascade = attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, @@ -1388,6 +1897,7 @@ def _compute_cascade_attn_prefix_len( use_sliding_window=use_sliding_window, use_local_attention=use_local_attention, num_sms=self.num_sms, + dcp_world_size=self.dcp_world_size, ) return common_prefix_len if use_cascade else 0 @@ 
-1397,18 +1907,15 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): req = self.requests[req_id] assert req.mrope_positions is not None - num_computed_tokens = \ - self.input_batch.num_computed_tokens_cpu[index] - num_scheduled_tokens = \ - scheduler_output.num_scheduled_tokens[req_id] + num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] num_prompt_tokens = length_from_prompt_token_ids_or_embeds( - req.prompt_token_ids, req.prompt_embeds) + req.prompt_token_ids, req.prompt_embeds + ) if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: - prompt_part_len = max(0, - num_prompt_tokens - num_computed_tokens) - completion_part_len = max( - 0, num_scheduled_tokens - prompt_part_len) + prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens) + completion_part_len = max(0, num_scheduled_tokens - prompt_part_len) else: prompt_part_len = num_scheduled_tokens completion_part_len = 0 @@ -1422,8 +1929,9 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): src_start = num_computed_tokens src_end = num_computed_tokens + prompt_part_len - self.mrope_positions.cpu[:, dst_start:dst_end] = ( - req.mrope_positions[:, src_start:src_end]) + self.mrope_positions.cpu[:, dst_start:dst_end] = req.mrope_positions[ + :, src_start:src_end + ] mrope_pos_ptr += prompt_part_len if completion_part_len > 0: @@ -1431,6 +1939,7 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): dst_start = mrope_pos_ptr dst_end = mrope_pos_ptr + completion_part_len + assert req.mrope_position_delta is not None MRotaryEmbedding.get_next_input_positions_tensor( out=self.mrope_positions.np, out_offset=dst_start, @@ -1441,19 +1950,66 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): mrope_pos_ptr += completion_part_len - def _calc_spec_decode_metadata( - self, - num_draft_tokens: np.ndarray, - cu_num_scheduled_tokens: 
np.ndarray, - ) -> SpecDecodeMetadata: - # Inputs: - # cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209] - # num_draft_tokens: [ 3, 0, 2, 0, 1] - # Outputs: - # cu_num_draft_tokens: [ 3, 3, 5, 5, 6] - # logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106, - # 206, 207, 208] - # target_logits_indices: [ 0, 1, 2, 5, 6, 9] + def _calc_xdrope_positions(self, scheduler_output: "SchedulerOutput"): + xdrope_pos_ptr = 0 + for index, req_id in enumerate(self.input_batch.req_ids): + req = self.requests[req_id] + assert req.xdrope_positions is not None + + num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + req.prompt_token_ids, req.prompt_embeds + ) + + if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: + prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens) + completion_part_len = max(0, num_scheduled_tokens - prompt_part_len) + else: + prompt_part_len = num_scheduled_tokens + completion_part_len = 0 + + assert num_scheduled_tokens == prompt_part_len + completion_part_len + + if prompt_part_len > 0: + # prompt's xdrope_positions are pre-computed + dst_start = xdrope_pos_ptr + dst_end = xdrope_pos_ptr + prompt_part_len + src_start = num_computed_tokens + src_end = num_computed_tokens + prompt_part_len + + self.xdrope_positions.cpu[:, dst_start:dst_end] = req.xdrope_positions[ + :, src_start:src_end + ] + xdrope_pos_ptr += prompt_part_len + + if completion_part_len > 0: + # compute completion's xdrope_positions on-the-fly + dst_start = xdrope_pos_ptr + dst_end = xdrope_pos_ptr + completion_part_len + + XDRotaryEmbedding.get_next_input_positions_tensor( + out=self.xdrope_positions.np, + out_offset=dst_start, + context_len=num_computed_tokens + prompt_part_len, + num_new_tokens=completion_part_len, + ) + + xdrope_pos_ptr += completion_part_len + + def _calc_spec_decode_metadata( + self, + 
num_draft_tokens: np.ndarray, + cu_num_scheduled_tokens: np.ndarray, + ) -> SpecDecodeMetadata: + # Inputs: + # cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209] + # num_draft_tokens: [ 3, 0, 2, 0, 1] + # Outputs: + # cu_num_draft_tokens: [ 3, 3, 5, 5, 6] + # logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106, + # 206, 207, 208] + # target_logits_indices: [ 0, 1, 2, 5, 6, 9] # bonus_logits_indices: [ 3, 4, 7, 8, 10] # Compute the logits indices. @@ -1463,10 +2019,12 @@ def _calc_spec_decode_metadata( # Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11] # arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] cu_num_sampled_tokens, arange = self._get_cumsum_and_arange( - num_sampled_tokens, cumsum_dtype=np.int32) + num_sampled_tokens, cumsum_dtype=np.int32 + ) # Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] logits_indices = np.repeat( - cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) + cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens + ) # Step 3. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] logits_indices += arange @@ -1477,29 +2035,38 @@ def _calc_spec_decode_metadata( # cu_num_draft_tokens: [3, 3, 5, 5, 6] # arange: [0, 1, 2, 0, 1, 0] cu_num_draft_tokens, arange = self._get_cumsum_and_arange( - num_draft_tokens, cumsum_dtype=np.int32) + num_draft_tokens, cumsum_dtype=np.int32 + ) # [0, 0, 0, 5, 5, 9] target_logits_indices = np.repeat( - cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens) + cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens + ) # [0, 1, 2, 5, 6, 9] target_logits_indices += arange # TODO: Optimize the CPU -> GPU copy. 
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( - self.device, non_blocking=True) - logits_indices = torch.from_numpy(logits_indices).to(self.device, - non_blocking=True) + self.device, non_blocking=True + ) + cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to( + self.device, non_blocking=True + ) + logits_indices = torch.from_numpy(logits_indices).to( + self.device, non_blocking=True + ) target_logits_indices = torch.from_numpy(target_logits_indices).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) # Compute the draft token ids. # draft_token_indices: [ 1, 2, 3, 105, 106, 208] draft_token_ids = self.input_ids.gpu[logits_indices] draft_token_ids = draft_token_ids[target_logits_indices + 1] - metadata = SpecDecodeMetadata( + return SpecDecodeMetadata( draft_token_ids=draft_token_ids, num_draft_tokens=num_draft_tokens.tolist(), cu_num_draft_tokens=cu_num_draft_tokens, @@ -1507,7 +2074,6 @@ def _calc_spec_decode_metadata( bonus_logits_indices=bonus_logits_indices, logits_indices=logits_indices, ) - return metadata def _prepare_kv_sharing_fast_prefill( self, @@ -1516,23 +2082,26 @@ def _prepare_kv_sharing_fast_prefill( assert self.kv_sharing_fast_prefill_logits_indices is not None num_logits = logits_indices.shape[0] assert num_logits > 0 - self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_( - logits_indices) + self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(logits_indices) # There might have leftover indices in logits_indices[num_logits:] # from previous iterations, whose values may be greater than the # batch size in the current iteration. To ensure indices are always # valid, we fill the padded indices with the last index. 
self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( - logits_indices[-1].item()) - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_logits <= self.cudagraph_batch_sizes[-1]): + logits_indices[-1].item() + ) + if ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and num_logits <= self.cudagraph_batch_sizes[-1] + ): # Use piecewise CUDA graphs. # Add padding to the batch size. num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits) else: num_logits_padded = num_logits - logits_indices_padded = ( - self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]) + logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[ + :num_logits_padded + ] return logits_indices_padded def _batch_mm_kwargs_from_scheduler( @@ -1562,19 +2131,24 @@ def _batch_mm_kwargs_from_scheduler( for mm_input_id in encoder_input_ids: mm_feature = req_state.mm_features[mm_input_id] + if mm_feature.data is None: + continue mm_hash = mm_feature.identifier mm_kwargs.append(mm_feature.data) mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) return mm_kwargs, mm_hashes_pos - def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): + def _execute_mm_encoder( + self, scheduler_output: "SchedulerOutput" + ) -> list[torch.Tensor]: # Batch the multi-modal inputs using the helper method. mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( - scheduler_output) + scheduler_output + ) if not mm_kwargs: - return + return [] # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, @@ -1584,30 +2158,42 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): # multimodal inputs. The proper solution should be reordering the # encoder outputs. 
model = cast(SupportsMultiModal, self.model) - encoder_outputs = [] + encoder_outputs: list[torch.Tensor] = [] for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, ): + curr_group_outputs: list[torch.Tensor] = [] + + # EVS-related change. # (ekhvedchenia): Temporary hack to limit peak memory usage when - # processing multimodal data.This solves the issue with scheduler + # processing multimodal data. This solves the issue with scheduler # putting too many video samples into a single batch. Scheduler # uses pruned vision tokens count to compare it versus compute # budget which is incorrect (Either input media size or non-pruned # output vision tokens count should be considered) - curr_group_outputs = [] - - if self.is_multimodal_pruning_enabled and modality == "video": - micro_batch_size = 1 - for i in range(0, num_items, micro_batch_size): - micro_batch_mm_inputs = dict( - (k, v[i:i + micro_batch_size]) - for k, v in mm_kwargs_group.items()) + # TODO(ywang96): Fix memory profiling to take EVS into account and + # remove this hack. + if ( + self.is_multimodal_pruning_enabled + and modality == "video" + and num_items > 1 + ): + for video_mm_kwargs_item in filter( + lambda item: item.modality == "video", mm_kwargs + ): + _, _, micro_batch_mm_inputs = next( + group_mm_kwargs_by_modality( + [video_mm_kwargs_item], + device=self.device, + pin_memory=self.pin_memory, + ) + ) - micro_batch_outputs = model.get_multimodal_embeddings( - **micro_batch_mm_inputs) + micro_batch_outputs = model.embed_multimodal( + **micro_batch_mm_inputs + ) curr_group_outputs.extend(micro_batch_outputs) else: @@ -1618,8 +2204,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): # 2. 
A list or tuple (length: num_items) of tensors, # each of shape (feature_size, hidden_size) in case the feature # size is dynamic depending on the input multimodal items. - curr_group_outputs = model.get_multimodal_embeddings( - **mm_kwargs_group) + curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) # type: ignore[assignment] sanity_check_mm_encoder_outputs( curr_group_outputs, @@ -1629,26 +2214,34 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): # Cache the encoder outputs by mm_hash for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): - self.encoder_cache[mm_hash] = scatter_mm_placeholders( - output, - is_embed=pos_info.is_embed, - ) + self.encoder_cache[mm_hash] = output + logger.debug("Finish execute for mm hash %s", mm_hash) + self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash) + + return encoder_outputs def _gather_mm_embeddings( self, scheduler_output: "SchedulerOutput", shift_computed_tokens: int = 0, - ) -> list[torch.Tensor]: + ) -> tuple[list[torch.Tensor], torch.Tensor]: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + + mm_embeds = list[torch.Tensor]() + is_mm_embed = self.is_mm_embed.cpu + is_mm_embed[:total_num_scheduled_tokens] = False + + req_start_idx = 0 should_sync_mrope_positions = False - mm_embeds: list[torch.Tensor] = [] + should_sync_xdrope_positions = False + for req_id in self.input_batch.req_ids: mm_embeds_req: list[torch.Tensor] = [] - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] req_state = self.requests[req_id] - num_computed_tokens = \ - req_state.num_computed_tokens + shift_computed_tokens + num_computed_tokens = req_state.num_computed_tokens + shift_computed_tokens + for mm_feature in req_state.mm_features: pos_info = mm_feature.mm_position start_pos = pos_info.offset @@ -1672,22 +2265,32 @@ def _gather_mm_embeddings( num_encoder_tokens, ) assert 
start_idx < end_idx + curr_embeds_start, curr_embeds_end = ( + pos_info.get_embeds_indices_in_range(start_idx, end_idx) + ) + # If there are no embeddings in the current range, we skip + # gathering the embeddings. + if curr_embeds_start == curr_embeds_end: + continue mm_hash = mm_feature.identifier encoder_output = self.encoder_cache.get(mm_hash, None) - assert encoder_output is not None,\ - f"Encoder cache miss for {mm_hash}." + assert encoder_output is not None, f"Encoder cache miss for {mm_hash}." if (is_embed := pos_info.is_embed) is not None: is_embed = is_embed[start_idx:end_idx] + mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end] + else: + mm_embeds_item = encoder_output[start_idx:end_idx] - mm_embeds_item = gather_mm_placeholders( - encoder_output[start_idx:end_idx], - is_embed=is_embed, + req_start_pos = req_start_idx + start_pos - num_computed_tokens + is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = ( + True if is_embed is None else is_embed ) mm_embeds_req.append(mm_embeds_item) if self.is_multimodal_pruning_enabled and self.uses_mrope: + assert req_state.mrope_positions is not None should_sync_mrope_positions = True mm_embeds_req, new_mrope_positions, new_delta = ( self.model.recompute_mrope_positions( @@ -1695,54 +2298,29 @@ def _gather_mm_embeddings( multimodal_embeddings=mm_embeds_req, mrope_positions=req_state.mrope_positions, num_computed_tokens=req_state.num_computed_tokens, - )) - assert req_state.mrope_positions is not None + ) + ) req_state.mrope_positions.copy_(new_mrope_positions) req_state.mrope_position_delta = new_delta mm_embeds.extend(mm_embeds_req) + req_start_idx += num_scheduled_tokens + + is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens) if should_sync_mrope_positions: self._calc_mrope_positions(scheduler_output) - self.mrope_positions.copy_to_gpu( - scheduler_output.total_num_scheduled_tokens) - - return mm_embeds - - def _extract_encoder_inputs( - self, - scheduler_output: 
"SchedulerOutput", - ) -> dict[str, torch.Tensor]: - """Extract encoder inputs for encoder-decoder models. - - This method extracts multimodal input features from scheduled encoder - inputs and formats them for the encoder-decoder model forward pass. - """ - # Batch the multi-modal inputs using the helper method. - mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output) + self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens) - if not mm_kwargs: - return {} - - # Group MM kwargs by modality and extract features - model = cast(SupportsMultiModal, self.model) - encoder_features = {} - for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, - ): - # Add the grouped features to encoder_features dict - # This allows the model to receive them as kwargs (e.g., - # input_features=...) - encoder_features.update(mm_kwargs_group) + if should_sync_xdrope_positions: + self._calc_xdrope_positions(scheduler_output) + self.xdrope_positions.copy_to_gpu(total_num_scheduled_tokens) - return encoder_features + return mm_embeds, is_mm_embed def get_model(self) -> nn.Module: # get raw model out of the cudagraph wrapper. - if isinstance(self.model, (GraphWrapper)): #, UBatchWrapper)): + if isinstance(self.model, (GraphWrapper, UBatchWrapper)): return self.model.unwrap() return self.model @@ -1768,21 +2346,11 @@ def get_supported_pooling_tasks(self) -> list[PoolingTask]: supported_tasks = list(model.pooler.get_supported_tasks()) - if (self.scheduler_config.chunked_prefill_enabled - and "encode" in supported_tasks): - supported_tasks.remove("encode") - - logger.debug_once("Chunked prefill is not supported with " - "encode task which using ALL pooling. 
" - "Please turn off chunked prefill by " - "`--no-enable-chunked-prefill` before using it.") - if "score" in supported_tasks: num_labels = getattr(self.model_config.hf_config, "num_labels", 0) if num_labels != 1: supported_tasks.remove("score") - logger.debug_once( - "Score API is only enabled for num_labels == 1.") + logger.debug_once("Score API is only enabled for num_labels == 1.") return supported_tasks @@ -1797,9 +2365,11 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return tuple(tasks) def sync_and_slice_intermediate_tensors( - self, num_tokens: int, intermediate_tensors: IntermediateTensors, - sync_self: bool) -> IntermediateTensors: - + self, + num_tokens: int, + intermediate_tensors: IntermediateTensors | None, + sync_self: bool, + ) -> IntermediateTensors: assert self.intermediate_tensors is not None tp = self.vllm_config.parallel_config.tensor_parallel_size @@ -1811,21 +2381,21 @@ def sync_and_slice_intermediate_tensors( assert intermediate_tensors is not None for k, v in intermediate_tensors.items(): is_scattered = k == "residual" and is_rs - copy_len = num_tokens // tp if is_scattered else \ - num_tokens + copy_len = num_tokens // tp if is_scattered else num_tokens self.intermediate_tensors[k][:copy_len].copy_( - v[:copy_len], non_blocking=True) - - return IntermediateTensors({ - k: - v[:num_tokens // - tp] if k == "residual" and is_rs else v[:num_tokens] - for k, v in self.intermediate_tensors.items() - }) - - def eplb_step(self, - is_dummy: bool = False, - is_profile: bool = False) -> None: + v[:copy_len], non_blocking=True + ) + + return IntermediateTensors( + { + k: v[: num_tokens // tp] + if k == "residual" and is_rs + else v[:num_tokens] + for k, v in self.intermediate_tensors.items() + } + ) + + def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None: """ Step for the EPLB (Expert Parallelism Load Balancing) state. 
""" @@ -1836,97 +2406,28 @@ def eplb_step(self, model = self.get_model() assert is_mixture_of_experts(model) self.eplb_state.step( - model, is_dummy, is_profile, log_stats=self.parallel_config.eplb_config.log_balancedness, ) - def get_dp_padding(self, - num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: - """ - Determines the total number of tokens that each rank will run. - All ranks will be padded out so that they run with the same number - of tokens - - Returns: tuple[ - num_pad_tokens: The number of tokens that will be added to the batch - num_tokens_after_padding: A tensor containing the total number of - tokens for each DP rank including padding. - ] - """ - dp_size = self.vllm_config.parallel_config.data_parallel_size - dp_rank = self.vllm_config.parallel_config.data_parallel_rank - - # For DP: Don't pad when setting enforce_eager. - # This lets us set enforce_eager on the prefiller in a P/D setup and - # still use CUDA graphs (enabled by this padding) on the decoder. - # - # TODO(tms) : There are many cases where padding is enabled for - # prefills, causing unnecessary and excessive padding of activations. - - if dp_size == 1 or self.vllm_config.model_config.enforce_eager: - # Early exit. - return 0, None - - num_tokens_across_dp = DPMetadata.num_tokens_across_dp( - num_tokens, dp_size, dp_rank) - max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item() - num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * - dp_size, - device="cpu", - dtype=torch.int32) - return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding - - def get_local_padding(self, num_tokens_unpadded: int) -> int: - - num_tokens_padded = num_tokens_unpadded - - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]): - # Use piecewise CUDA graphs. - # Add padding to the batch size. 
- num_tokens_padded = self.vllm_config.pad_for_cudagraph( - num_tokens_unpadded) - else: - # Eager mode. - # Pad tokens to multiple of tensor_parallel_size when - # enabled collective fusion for SP - tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if self.vllm_config.compilation_config.pass_config. \ - enable_sequence_parallelism and tp_size > 1: - num_tokens_padded = round_up(num_tokens_unpadded, tp_size) - - num_pad_tokens = num_tokens_padded - num_tokens_unpadded - return num_pad_tokens - - # This is where the second ubatch is adjusted to account for the padding. - # Should be called after attention metadata creation. This just pads - # the second ubatch slice out to the total number of tokens - # (num_tokens + padding) - def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, - num_total_tokens: int): - padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start, - num_total_tokens) - ubatch_slices[1] = UBatchSlice(padded_second_ubatch_slice, - padded_second_ubatch_slice) - def _pool( self, hidden_states: torch.Tensor, num_scheduled_tokens: int, num_scheduled_tokens_np: np.ndarray, ) -> ModelRunnerOutput: - assert self.input_batch.num_reqs ==\ - len(self.input_batch.pooling_params), \ - "Either all or none of the requests in" \ - " a batch must be pooling request" + assert self.input_batch.num_reqs == len(self.input_batch.pooling_params), ( + "Either all or none of the requests in a batch must be pooling request" + ) hidden_states = hidden_states[:num_scheduled_tokens] + seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs] + pooling_metadata = self.input_batch.get_pooling_metadata() - pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(), - device=hidden_states.device) - seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs] + pooling_metadata.build_pooling_cursor( + num_scheduled_tokens_np.tolist(), seq_lens_cpu, device=hidden_states.device + ) model = cast(VllmModelForPooling, self.model) 
raw_pooler_output: PoolerOutput = model.pooler( @@ -1934,15 +2435,15 @@ def _pool( pooling_metadata=pooling_metadata, ) raw_pooler_output = json_map_leaves( - lambda x: x.to("cpu", non_blocking=True), + lambda x: x.to("cpu", non_blocking=True) if x is not None else x, raw_pooler_output, ) self._sync_device() - pooler_output: list[Optional[torch.Tensor]] = [] + pooler_output: list[torch.Tensor | None] = [] for raw_output, seq_len, prompt_len in zip( - raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens): - + raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens + ): output = raw_output if seq_len == prompt_len else None pooler_output.append(output) @@ -1955,65 +2456,55 @@ def _pool( pooler_output=pooler_output, ) - def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH - and hasattr(self, "cudagraph_batch_sizes") - and self.cudagraph_batch_sizes - and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): - # Use CUDA graphs. - # Add padding to the batch size. - return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens) - - # Eager mode. 
+ def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int: # Pad tokens to multiple of tensor_parallel_size when # enabled collective fusion for SP tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if (self.compilation_config.pass_config.enable_sequence_parallelism - and tp_size > 1): + if self.compilation_config.pass_config.enable_sp and tp_size > 1: return round_up(num_scheduled_tokens, tp_size) return num_scheduled_tokens def _preprocess( self, scheduler_output: "SchedulerOutput", - intermediate_tensors: Optional[IntermediateTensors] = None, - ubatch_slices: Optional[UBatchSlices] = None, - num_tokens_after_padding: Optional[torch.Tensor] = None, - ) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], torch.Tensor, - Optional[IntermediateTensors], dict[str, Any]]: - + num_input_tokens: int, # Padded + intermediate_tensors: IntermediateTensors | None = None, + ) -> tuple[ + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor, + IntermediateTensors | None, + dict[str, Any], + ECConnectorOutput | None, + ]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - if ubatch_slices: - assert num_tokens_after_padding is not None - num_input_tokens = int(num_tokens_after_padding[0].item() * 2) - self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) - elif ubatch_slices is None: - num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens) - num_pad, num_tokens_after_padding = self.get_dp_padding( - num_input_tokens) - num_input_tokens += num_pad + is_first_rank = get_pp_group().is_first_rank + is_encoder_decoder = self.model_config.is_encoder_decoder # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order - if (self.supports_mm_inputs and get_pp_group().is_first_rank - and not self.model_config.is_encoder_decoder): + ec_connector_output = None + + if self.supports_mm_inputs and is_first_rank and not 
is_encoder_decoder: # Run the multimodal encoder if any. - self._execute_mm_encoder(scheduler_output) - mm_embeds = self._gather_mm_embeddings(scheduler_output) + with self.maybe_get_ec_connector_output( + scheduler_output, + encoder_cache=self.encoder_cache, + ) as ec_connector_output: + self._execute_mm_encoder(scheduler_output) + mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output) # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. - inputs_embeds_scheduled = self.model.get_input_embeddings( - input_ids=self.input_ids.gpu[:num_scheduled_tokens], - multimodal_embeddings=mm_embeds or None, + inputs_embeds_scheduled = self.model.embed_input_ids( + self.input_ids.gpu[:num_scheduled_tokens], + multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, ) # TODO(woosuk): Avoid the copy. Optimize. - self.inputs_embeds.gpu[:num_scheduled_tokens].copy_( - inputs_embeds_scheduled) + self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled) input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] @@ -2021,7 +2512,7 @@ def _preprocess( **self._init_model_kwargs(num_scheduled_tokens), **self._extract_mm_kwargs(scheduler_output), } - elif self.enable_prompt_embeds and get_pp_group().is_first_rank: + elif self.enable_prompt_embeds and is_first_rank: # Get the input embeddings for the tokens that are not input embeds, # then put them into the appropriate positions. # TODO(qthequartermasterman): Since even when prompt embeds are @@ -2034,14 +2525,15 @@ def _preprocess( # If a batch only has token ids, then including the embedding layer # in the CUDA graph will be more performant (like in the else case # below). 
- token_ids_idx = self.is_token_ids.gpu[:num_scheduled_tokens] \ - .nonzero(as_tuple=False) \ + token_ids_idx = ( + self.is_token_ids.gpu[:num_scheduled_tokens] + .nonzero(as_tuple=False) .squeeze(1) + ) # Some tokens ids may need to become embeds if token_ids_idx.numel() > 0: token_ids = self.input_ids.gpu[token_ids_idx] - tokens_to_embeds = self.model.get_input_embeddings( - input_ids=token_ids) + tokens_to_embeds = self.model.embed_input_ids(input_ids=token_ids) self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] @@ -2055,92 +2547,90 @@ def _preprocess( input_ids = self.input_ids.gpu[:num_input_tokens] inputs_embeds = None model_kwargs = self._init_model_kwargs(num_input_tokens) + if self.uses_mrope: positions = self.mrope_positions.gpu[:, :num_input_tokens] + elif self.uses_xdrope_dim > 0: + positions = self.xdrope_positions.gpu[:, :num_input_tokens] else: positions = self.positions.gpu[:num_input_tokens] - if get_pp_group().is_first_rank: + if is_first_rank: intermediate_tensors = None else: + assert intermediate_tensors is not None intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_input_tokens, intermediate_tensors, True) + num_input_tokens, intermediate_tensors, True + ) - if (self.model_config.is_encoder_decoder - and scheduler_output.scheduled_encoder_inputs): - encoder_inputs = self._extract_encoder_inputs(scheduler_output) - model_kwargs.update(encoder_inputs) + if is_encoder_decoder and scheduler_output.scheduled_encoder_inputs: + # Run the encoder, just like we do with other multimodal inputs. + # For an encoder-decoder model, our processing here is a bit + # simpler, because the outputs are just passed to the decoder. + # We are not doing any prompt replacement. We also will only + # ever have a single encoder input. 
+ encoder_outputs = self._execute_mm_encoder(scheduler_output) + model_kwargs.update({"encoder_outputs": encoder_outputs}) return ( - num_scheduled_tokens, - num_input_tokens, - num_tokens_after_padding, input_ids, inputs_embeds, positions, intermediate_tensors, model_kwargs, + ec_connector_output, ) def _sample( - self, logits: Optional[torch.Tensor], - spec_decode_metadata: Optional[SpecDecodeMetadata] + self, + logits: torch.Tensor | None, + spec_decode_metadata: SpecDecodeMetadata | None, ) -> SamplerOutput: # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata if spec_decode_metadata is None: - sampler_output = self.sampler( + # Update output token ids with tokens sampled in last step + # if async scheduling and required by current sampling params. + self.input_batch.update_async_output_token_ids() + return self.sampler( logits=logits, sampling_metadata=sampling_metadata, ) - else: - # When indexing with a tensor (bonus_logits_indices), PyTorch - # creates a new tensor with separate storage from the original - # logits tensor. This means any in-place operations on bonus_logits - # won't affect the original logits tensor. - assert logits is not None - bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] - sampler_output = self.sampler( - logits=bonus_logits, - sampling_metadata=sampling_metadata, - ) - bonus_token_ids = sampler_output.sampled_token_ids - - # Just like `bonus_logits`, `target_logits` is a new tensor with - # separate storage from the original `logits` tensor. Therefore, - # it is safe to update `target_logits` in place. 
- target_logits = logits[spec_decode_metadata.target_logits_indices] - output_token_ids = self.rejection_sampler( - spec_decode_metadata, - None, # draft_probs - target_logits, - bonus_token_ids, - sampling_metadata, - ) - sampler_output.sampled_token_ids = output_token_ids - self._update_states_after_model_execute(output_token_ids) + sampler_output = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + logits, + sampling_metadata, + ) + self._update_states_after_model_execute(sampler_output.sampled_token_ids) return sampler_output def _bookkeeping_sync( - self, scheduler_output: "SchedulerOutput", - sampler_output: SamplerOutput, logits: Optional[torch.Tensor], - hidden_states: torch.Tensor, num_scheduled_tokens: int + self, + scheduler_output: "SchedulerOutput", + sampler_output: SamplerOutput, + logits: torch.Tensor | None, + hidden_states: torch.Tensor, + num_scheduled_tokens: int, + spec_decode_metadata: SpecDecodeMetadata | None, ) -> tuple[ - dict[str, int], - Optional[LogprobsLists], - list[list[int]], - dict[str, Optional[LogprobsTensors]], - list[str], - dict[str, int], - list[int], + dict[str, int], + LogprobsLists | None, + list[list[int]], + dict[str, LogprobsTensors | None], + list[str], + dict[str, int], + list[int], ]: num_nans_in_logits = {} if envs.VLLM_COMPUTE_NANS_IN_LOGITS: num_nans_in_logits = self._get_nans_in_logits(logits) - discard_sampled_tokens_req_indices = \ - self.discard_request_indices.np[:self.num_discarded_requests] + num_reqs = self.input_batch.num_reqs + discard_sampled_tokens_req_indices = np.nonzero( + self.discard_request_mask.np[:num_reqs] + )[0] for i in discard_sampled_tokens_req_indices: gen = self.input_batch.generators.get(int(i)) if gen is not None: @@ -2149,52 +2639,42 @@ def _bookkeeping_sync( # Copy some objects so they don't get modified after returning. # This is important when using async scheduling. 
req_ids_output_copy = self.input_batch.req_ids.copy() - req_id_to_index_output_copy = \ - self.input_batch.req_id_to_index.copy() - - # NOTE: GPU -> CPU Sync happens here. - # Move as many CPU operations as possible before this sync point. - logprobs_tensors = sampler_output.logprobs_tensors - logprobs_lists = logprobs_tensors.tolists() \ - if logprobs_tensors is not None else None - - # Compute prompt logprobs if needed. - prompt_logprobs_dict = self._get_prompt_logprobs_dict( - hidden_states[:num_scheduled_tokens], - scheduler_output.num_scheduled_tokens, - ) + req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] sampled_token_ids = sampler_output.sampled_token_ids + logprobs_tensors = sampler_output.logprobs_tensors invalid_req_indices = [] + cu_num_tokens: list[int] | None = None if not self.use_async_scheduling: # Get the valid generated tokens. max_gen_len = sampled_token_ids.shape[-1] if max_gen_len == 1: # No spec decode tokens. valid_sampled_token_ids = self._to_list(sampled_token_ids) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[int(i)].clear() else: # Includes spec decode tokens. - valid_sampled_token_ids = self.rejection_sampler.parse_output( + valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output( sampled_token_ids, self.input_batch.vocab_size, + discard_sampled_tokens_req_indices, + return_cu_num_tokens=logprobs_tensors is not None, ) - # Mask out the sampled tokens that should not be sampled. - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[int(i)].clear() else: valid_sampled_token_ids = [] invalid_req_indices = discard_sampled_tokens_req_indices.tolist() invalid_req_indices_set = set(invalid_req_indices) - assert sampled_token_ids.shape[-1] == 1 # Cache the sampled tokens on the GPU and avoid CPU sync. 
# These will be copied into input_ids in the next step # when preparing inputs. - self.input_batch.prev_sampled_token_ids = \ - sampled_token_ids - self.input_batch.prev_sampled_token_ids_invalid_indices = \ - invalid_req_indices_set + # With spec decoding, this is done in propose_draft_token_ids(). + if self.input_batch.prev_sampled_token_ids is None: + assert sampled_token_ids.shape[-1] == 1 + self.input_batch.prev_sampled_token_ids = sampled_token_ids self.input_batch.prev_req_id_to_index = { req_id: i for i, req_id in enumerate(self.input_batch.req_ids) @@ -2209,22 +2689,24 @@ def _bookkeeping_sync( req_ids = self.input_batch.req_ids for req_idx in range(num_sampled_tokens): if self.use_async_scheduling: - sampled_ids = [-1] if \ - req_idx not in invalid_req_indices_set else None + sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None else: sampled_ids = valid_sampled_token_ids[req_idx] + + num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0 + if not sampled_ids: continue start_idx = self.input_batch.num_tokens_no_spec[req_idx] - end_idx = start_idx + len(sampled_ids) + end_idx = start_idx + num_sampled_ids assert end_idx <= self.max_model_len, ( "Sampled token IDs exceed the max model length. " f"Total number of tokens: {end_idx} > max_model_len: " - f"{self.max_model_len}") + f"{self.max_model_len}" + ) - self.input_batch.token_ids_cpu[req_idx, - start_idx:end_idx] = sampled_ids + self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx @@ -2233,6 +2715,18 @@ def _bookkeeping_sync( req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) + logprobs_lists = ( + logprobs_tensors.tolists(cu_num_tokens) + if not self.use_async_scheduling and logprobs_tensors is not None + else None + ) + + # Compute prompt logprobs if needed. 
+ prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states[:num_scheduled_tokens], + scheduler_output.num_scheduled_tokens, + ) + return ( num_nans_in_logits, logprobs_lists, @@ -2258,75 +2752,368 @@ def synchronize_input_prep(self): finally: self.prepare_inputs_event.record() + def _model_forward( + self, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **model_kwargs: dict[str, Any], + ) -> Any: + """Helper method to call the model forward pass. + + This method can be overridden by subclasses for model execution. + Motivation: We can inspect only this method versus + the whole execute_model, which has additional logic. + + Args: + input_ids: Input token IDs + positions: Token positions + intermediate_tensors: Tensors from previous pipeline stages + inputs_embeds: Input embeddings (alternative to input_ids) + **model_kwargs: Additional model arguments + + Returns: + Model output tensor + """ + return self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + + def _determine_batch_execution_and_padding( + self, + num_tokens: int, + num_reqs: int, + num_scheduled_tokens_np: np.ndarray, + max_num_scheduled_tokens: int, + use_cascade_attn: bool, + allow_microbatching: bool = True, + force_eager: bool = False, + # For cudagraph capture TODO(lucas): Refactor how we capture cudagraphs (will + # be improved in model runner v2) + force_uniform_decode: bool | None = None, + force_has_lora: bool | None = None, + num_encoder_reqs: int = 0, + ) -> tuple[ + CUDAGraphMode, + BatchDescriptor, + bool, + torch.Tensor | None, + CUDAGraphStat | None, + ]: + num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens) + uniform_decode = ( + ( + (max_num_scheduled_tokens == self.uniform_decode_query_len) + and (num_tokens_padded == 
max_num_scheduled_tokens * num_reqs) + ) + if force_uniform_decode is None + else force_uniform_decode + ) + # Encoder-decoder models only support CG for decoder_step > 0 (no enc_output + # is present). Also, chunked-prefill is disabled, so batch are uniform. + has_encoder_output = ( + self.model_config.is_encoder_decoder and num_encoder_reqs > 0 + ) + + has_lora = ( + len(self.input_batch.lora_id_to_lora_request) > 0 + if force_has_lora is None + else force_has_lora + ) + + dispatch_cudagraph = ( + lambda num_tokens, disable_full: self.cudagraph_dispatcher.dispatch( + num_tokens=num_tokens, + has_lora=has_lora, + uniform_decode=uniform_decode, + disable_full=disable_full, + ) + if not force_eager + else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded)) + ) + + cudagraph_mode, batch_descriptor = dispatch_cudagraph( + num_tokens_padded, use_cascade_attn or has_encoder_output + ) + num_tokens_padded = batch_descriptor.num_tokens + + # Extra coordination when running data-parallel since we need to coordinate + # across ranks + should_ubatch, num_tokens_across_dp = False, None + if self.vllm_config.parallel_config.data_parallel_size > 1: + # Disable DP padding when running eager to avoid excessive padding when + # running prefills. This lets us set cudagraph_mode="NONE" on the prefiller + # in a P/D setup and still use CUDA graphs (enabled by this padding) on the + # decoder. 
+ allow_dp_padding = ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ) + + should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = ( + coordinate_batch_across_dp( + num_tokens_unpadded=num_tokens, + parallel_config=self.parallel_config, + allow_microbatching=allow_microbatching, + allow_dp_padding=allow_dp_padding, + num_tokens_padded=num_tokens_padded, + uniform_decode=uniform_decode, + num_scheduled_tokens_per_request=num_scheduled_tokens_np, + cudagraph_mode=cudagraph_mode.value, + ) + ) + + # Extract DP-synced values + if num_tokens_across_dp is not None: + dp_rank = self.parallel_config.data_parallel_rank + num_tokens_padded = int(num_tokens_across_dp[dp_rank].item()) + # Re-dispatch with DP padding so we have the correct batch_descriptor + cudagraph_mode, batch_descriptor = dispatch_cudagraph( + num_tokens_padded, + disable_full=synced_cudagraph_mode <= CUDAGraphMode.PIECEWISE.value, + ) + # Assert to make sure the agreed upon token count is correct otherwise + # num_tokens_across_dp will no-longer be valid + assert batch_descriptor.num_tokens == num_tokens_padded + + cudagraph_stats = None + if self.vllm_config.observability_config.cudagraph_metrics: + cudagraph_stats = CUDAGraphStat( + num_unpadded_tokens=num_tokens, + num_padded_tokens=batch_descriptor.num_tokens, + num_paddings=batch_descriptor.num_tokens - num_tokens, + runtime_mode=str(cudagraph_mode), + ) + + return ( + cudagraph_mode, + batch_descriptor, + should_ubatch, + num_tokens_across_dp, + cudagraph_stats, + ) + + def _register_layerwise_nvtx_hooks(self) -> None: + """ + Register layerwise NVTX hooks if --enable-layerwise-nvtx-tracing is enabled + to trace detailed information of each layer or module in the model. 
+ """ + + if ( + self.vllm_config.observability_config.enable_layerwise_nvtx_tracing + and not self.layerwise_nvtx_hooks_registered + ): + if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + logger.debug_once( + "layerwise NVTX tracing is not supported when CUDA graph is " + "turned off; you may observe part or all of the model " + "missing NVTX markers" + ) + + # In STOCK_TORCH_COMPILE mode, after registering hooks here, + # the __call__ function of nn.module will be recompiled with + # fullgraph=True. Since nvtx.range_push/pop are not traceable + # by torch dynamo, we can't register hook functions here + # because hook functions will also be traced by torch dynamo. + if ( + self.vllm_config.compilation_config.mode + == CompilationMode.STOCK_TORCH_COMPILE + ): + logger.debug_once( + "layerwise NVTX tracing is not supported when " + "CompilationMode is STOCK_TORCH_COMPILE, skipping " + "function hooks registration" + ) + else: + pyt_hooks = PytHooks() + pyt_hooks.register_hooks(self.model, self.model.__class__.__name__) + self.layerwise_nvtx_hooks_registered = True + @torch.inference_mode() def execute_model( self, scheduler_output: "SchedulerOutput", - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: - with record_function_or_nullcontext("Preprocess"): + intermediate_tensors: IntermediateTensors | None = None, + ) -> ModelRunnerOutput | IntermediateTensors | None: + if self.execute_model_state is not None: + raise RuntimeError( + "State error: sample_tokens() must be called " + "after execute_model() returns None." + ) + + # self._draft_token_ids is None when `input_fits_in_drafter=False` + # and there is no draft tokens scheduled. so it need to update the + # spec_decoding info in scheduler_output with async_scheduling. + # use deepcopy to avoid the modification has influence on the + # scheduler_output in engine core process. 
+ # TODO(Ronald1995): deepcopy is expensive when there is a large + # number of requests, optimize it later. + if ( + self.use_async_scheduling + and self.num_spec_tokens + and self._draft_token_ids is None + ): + scheduler_output = deepcopy(scheduler_output) + + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + with record_function_or_nullcontext("gpu_model_runner: preprocess"): with self.synchronize_input_prep(): # Update persistent batch states. self._update_states(scheduler_output) - if not scheduler_output.total_num_scheduled_tokens: + if has_ec_transfer() and get_ec_transfer().is_producer: + with self.maybe_get_ec_connector_output( + scheduler_output, + encoder_cache=self.encoder_cache, + ) as ec_connector_output: + self._execute_mm_encoder(scheduler_output) + return make_empty_encoder_model_runner_output(scheduler_output) + + if not num_scheduled_tokens: + if ( + self.parallel_config.distributed_executor_backend + == "external_launcher" + and self.parallel_config.data_parallel_size > 1 + ): + # this is a corner case when both external launcher + # and DP are enabled, num_scheduled_tokens could be + # 0, and has_unfinished_requests in the outer loop + # returns True. before returning early here we call + # dummy run to ensure coordinate_batch_across_dp + # is called into to avoid out of sync issues. + self._dummy_run(1) if not has_kv_transfer_group(): # Return empty ModelRunnerOutput if no work to do. 
return EMPTY_MODEL_RUNNER_OUTPUT return self.kv_connector_no_forward( - scheduler_output, self.vllm_config) + scheduler_output, self.vllm_config + ) if self.cache_config.kv_sharing_fast_prefill: - assert not self.input_batch.num_prompt_logprobs, ( + assert not self.num_prompt_logprobs, ( "--kv-sharing-fast-prefill produces incorrect " "logprobs for prompt tokens, tokens, please disable " - "it when the requests need prompt logprobs") + "it when the requests need prompt logprobs" + ) + + num_reqs = self.input_batch.num_reqs + req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens_np = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) + num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens + + ( + logits_indices, + spec_decode_metadata, + ) = self._prepare_inputs( + scheduler_output, + num_scheduled_tokens_np, + ) + + cascade_attn_prefix_lens = None + # Disable cascade attention when using microbatching (DBO) + if self.cascade_attn_enabled and not self.parallel_config.enable_dbo: + # Pre-compute cascade attention prefix lengths + cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens( + num_scheduled_tokens_np, + self.input_batch.num_computed_tokens_cpu[:num_reqs], + scheduler_output.num_common_prefix_blocks, + ) + + ( + cudagraph_mode, + batch_desc, + should_ubatch, + num_tokens_across_dp, + cudagraph_stats, + ) = self._determine_batch_execution_and_padding( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs, + num_scheduled_tokens_np=num_scheduled_tokens_np, + max_num_scheduled_tokens=max_num_scheduled_tokens, + use_cascade_attn=cascade_attn_prefix_lens is not None, + num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs), + ) - # Prepare the decoder inputs. 
- (attn_metadata, logits_indices, spec_decode_metadata, - num_scheduled_tokens_np, spec_decode_common_attn_metadata, - max_query_len, ubatch_slices, num_tokens_after_padding - ) = self._prepare_inputs(scheduler_output) + logger.debug( + "Running batch with cudagraph_mode: %s, batch_descriptor: %s, " + "should_ubatch: %s, num_tokens_across_dp: %s", + cudagraph_mode, + batch_desc, + should_ubatch, + num_tokens_across_dp, + ) + + num_tokens_padded = batch_desc.num_tokens + num_reqs_padded = ( + batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + ) + ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( + should_ubatch, + num_scheduled_tokens_np, + num_tokens_padded, + num_reqs_padded, + ) + + pad_attn = cudagraph_mode == CUDAGraphMode.FULL + + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 + ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices + + (attn_metadata, spec_decode_common_attn_metadata) = ( + self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_tokens_padded=num_tokens_padded if pad_attn else None, + num_reqs=num_reqs, + num_reqs_padded=num_reqs_padded if pad_attn else None, + max_query_len=max_num_scheduled_tokens, + ubatch_slices=ubatch_slices_attn, + logits_indices=logits_indices, + use_spec_decode=use_spec_decode, + num_scheduled_tokens=scheduler_output.num_scheduled_tokens, + cascade_attn_prefix_lens=cascade_attn_prefix_lens, + ) + ) ( - num_scheduled_tokens, - num_input_tokens, - num_tokens_across_dp, input_ids, inputs_embeds, positions, intermediate_tensors, model_kwargs, - ) = self._preprocess(scheduler_output, intermediate_tensors, - ubatch_slices, num_tokens_after_padding) - - uniform_decode = (max_query_len - == self.uniform_decode_query_len) and ( - num_scheduled_tokens - == self.input_batch.num_reqs * max_query_len) - batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, - uniform_decode=uniform_decode) - cudagraph_runtime_mode, 
batch_descriptor = \ - self.cudagraph_dispatcher.dispatch(batch_descriptor) - - # This is currently to get around the assert in the DPMetadata - # where it wants `num_tokens_across_dp` to align with `num_tokens` - if ubatch_slices is not None: - num_input_tokens = ubatch_slices[0].num_tokens + ec_connector_output, + ) = self._preprocess( + scheduler_output, num_tokens_padded, intermediate_tensors + ) + + # Set cudagraph mode to none if calc_kv_scales is true. + # KV scales calculation involves dynamic operations that are incompatible + # with CUDA graph capture. + if self.calculate_kv_scales: + cudagraph_mode = CUDAGraphMode.NONE + # Mark KV scales as calculated after the first forward pass + self.calculate_kv_scales = False # Run the model. # Use persistent buffers for CUDA graphs. - with (set_forward_context( + with ( + set_forward_context( attn_metadata, self.vllm_config, - num_tokens=num_input_tokens, + num_tokens=num_tokens_padded, num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ubatch_slices=ubatch_slices, - ), record_function_or_nullcontext("Forward"), - self.maybe_get_kv_connector_output(scheduler_output) as - kv_connector_output): - model_output = self.model( + cudagraph_runtime_mode=cudagraph_mode, + batch_descriptor=batch_desc, + ubatch_slices=ubatch_slices_padded, + ), + record_function_or_nullcontext("gpu_model_runner: forward"), + self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, + ): + model_output = self._model_forward( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, @@ -2334,7 +3121,7 @@ def execute_model( **model_kwargs, ) - with record_function_or_nullcontext("Postprocess"): + with record_function_or_nullcontext("gpu_model_runner: postprocess"): if self.use_aux_hidden_state_outputs: # True when EAGLE 3 is used. 
hidden_states, aux_hidden_states = model_output @@ -2349,12 +3136,14 @@ def execute_model( # Return the intermediate tensors. assert isinstance(hidden_states, IntermediateTensors) hidden_states.kv_connector_output = kv_connector_output + self.kv_connector_output = kv_connector_output return hidden_states if self.is_pooling_model: # Return the pooling output. - output = self._pool(hidden_states, num_scheduled_tokens, - num_scheduled_tokens_np) + output = self._pool( + hidden_states, num_scheduled_tokens, num_scheduled_tokens_np + ) output.kv_connector_output = kv_connector_output return output @@ -2364,42 +3153,95 @@ def execute_model( # Rare case. assert not self.is_pooling_model + sample_hidden_states = hidden_states[logits_indices] if not get_pp_group().is_last_rank: all_gather_tensors = { - "residual": - not is_residual_scattered_for_sp( - self.vllm_config, num_input_tokens) + "residual": not is_residual_scattered_for_sp( + self.vllm_config, num_tokens_padded + ) } get_pp_group().send_tensor_dict( hidden_states.tensors, all_gather_group=get_tp_group(), - all_gather_tensors=all_gather_tensors) + all_gather_tensors=all_gather_tensors, + ) logits = None else: - sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states) - model_output_broadcast_data = {} + model_output_broadcast_data: dict[str, Any] = {} if logits is not None: model_output_broadcast_data["logits"] = logits.contiguous() + broadcasted = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 + ) + assert broadcasted is not None + logits = broadcasted["logits"] + + self.execute_model_state = ExecuteModelState( + scheduler_output, + logits, + spec_decode_metadata, + spec_decode_common_attn_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + ec_connector_output, + cudagraph_stats, + ) + self.kv_connector_output = kv_connector_output + return None + + @torch.inference_mode + def 
sample_tokens( + self, grammar_output: "GrammarOutput | None" + ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None + + if self.execute_model_state is None: + # Nothing to do (PP non-final rank case), output isn't used. + if not kv_connector_output: + return None # type: ignore[return-value] + + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output + if kv_connector_output.is_empty(): + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output + return output - model_output_broadcast_data = get_pp_group( - ).broadcast_tensor_dict(model_output_broadcast_data, - src=len(get_pp_group().ranks) - 1) - assert model_output_broadcast_data is not None - logits = model_output_broadcast_data["logits"] - - # Apply structured output bitmasks if present - if scheduler_output.grammar_bitmask is not None: - apply_grammar_bitmask(scheduler_output, self.input_batch, - logits, self.device) + # Unpack ephemeral state. + ( + scheduler_output, + logits, + spec_decode_metadata, + spec_decode_common_attn_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + ec_connector_output, + cudagraph_stats, + ) = self.execute_model_state + # Clear ephemeral state. + self.execute_model_state = None + + # Apply structured output bitmasks if present. 
+ if grammar_output is not None: + apply_grammar_bitmask( + scheduler_output, grammar_output, self.input_batch, logits + ) - with record_function_or_nullcontext("Sample"): + with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) + self.input_batch.prev_sampled_token_ids = None + def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None - with record_function_or_nullcontext("Draft"): + with record_function_or_nullcontext("gpu_model_runner: draft"): self._draft_token_ids = self.propose_draft_token_ids( scheduler_output, sampled_token_ids, @@ -2411,28 +3253,51 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_common_attn_metadata, ) - use_padded_batch_for_eagle = self.speculative_config and \ - self.speculative_config.use_eagle() and \ - not self.speculative_config.disable_padded_drafter_batch + spec_config = self.speculative_config + use_padded_batch_for_eagle = ( + spec_config is not None + and spec_config.use_eagle() + and not spec_config.disable_padded_drafter_batch + ) effective_drafter_max_model_len = self.max_model_len if effective_drafter_max_model_len is None: effective_drafter_max_model_len = self.model_config.max_model_len - if (self.speculative_config - and self.speculative_config.draft_model_config is not None - and self.speculative_config.draft_model_config.max_model_len - is not None): + if ( + spec_config is not None + and spec_config.draft_model_config is not None + and spec_config.draft_model_config.max_model_len is not None + ): effective_drafter_max_model_len = ( - self.speculative_config.draft_model_config.max_model_len) + spec_config.draft_model_config.max_model_len + ) input_fits_in_drafter = spec_decode_common_attn_metadata and ( - spec_decode_common_attn_metadata.seq_lens.max() + - self.speculative_config.num_speculative_tokens - <= effective_drafter_max_model_len) - if use_padded_batch_for_eagle and 
input_fits_in_drafter: - # EAGLE speculative decoding can use the GPU sampled tokens - # as inputs, and does not need to wait for bookkeeping to finish. - propose_draft_token_ids(sampler_output.sampled_token_ids) - - with record_function_or_nullcontext("Bookkeep"): + spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens + <= effective_drafter_max_model_len + ) + if use_padded_batch_for_eagle: + assert self.speculative_config is not None + assert isinstance(self.drafter, EagleProposer) + sampled_token_ids = sampler_output.sampled_token_ids + if input_fits_in_drafter: + # EAGLE speculative decoding can use the GPU sampled tokens + # as inputs, and does not need to wait for bookkeeping to finish. + propose_draft_token_ids(sampled_token_ids) + elif self.valid_sampled_token_count_event is not None: + assert spec_decode_common_attn_metadata is not None + next_token_ids, valid_sampled_tokens_count = ( + self.drafter.prepare_next_token_ids_padded( + spec_decode_common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_mask.gpu, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + + with record_function_or_nullcontext("gpu_model_runner: bookkeep"): ( num_nans_in_logits, logprobs_lists, @@ -2441,42 +3306,69 @@ def propose_draft_token_ids(sampled_token_ids): req_ids_output_copy, req_id_to_index_output_copy, invalid_req_indices, - ) = self._bookkeeping_sync(scheduler_output, sampler_output, - logits, hidden_states, - num_scheduled_tokens) + ) = self._bookkeeping_sync( + scheduler_output, + sampler_output, + logits, + hidden_states, + scheduler_output.total_num_scheduled_tokens, + spec_decode_metadata, + ) - if (self.speculative_config and not use_padded_batch_for_eagle - and input_fits_in_drafter): + if ( + self.speculative_config + and not use_padded_batch_for_eagle + and input_fits_in_drafter + ): # ngram and other speculative decoding methods use the sampled # tokens on 
the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) - with record_function_or_nullcontext("EPLB"): + with record_function_or_nullcontext("gpu_model_runner: eplb"): self.eplb_step() + with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): + output = ModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, + sampled_token_ids=valid_sampled_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + kv_connector_output=kv_connector_output, + ec_connector_output=ec_connector_output + if self.supports_mm_inputs + else None, + num_nans_in_logits=num_nans_in_logits, + cudagraph_stats=cudagraph_stats, + ) - output = ModelRunnerOutput( - req_ids=req_ids_output_copy, - req_id_to_index=req_id_to_index_output_copy, - sampled_token_ids=valid_sampled_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=[], - kv_connector_output=kv_connector_output, - num_nans_in_logits=num_nans_in_logits, - ) - - ### TODO(lms): support async schedule + ### TODO(lms): abstract async schedule for all hardware if not self.use_async_scheduling: return output + with record_function_or_nullcontext( + "gpu_model_runner: AsyncGPUModelRunnerOutput" + ): + async_output = AsyncGPUModelRunnerOutput( + model_runner_output=output, + sampled_token_ids=sampler_output.sampled_token_ids, + logprobs_tensors=sampler_output.logprobs_tensors, + invalid_req_indices=invalid_req_indices, + async_output_copy_stream=self.async_output_copy_stream, + vocab_size=self.input_batch.vocab_size, + ) + with record_function_or_nullcontext( + "gpu_model_runner: set_async_sampled_token_ids" + ): + # Save ref of sampled_token_ids CPU tensor if the batch contains + # any requests with sampling params that require output ids. 
+ self.input_batch.set_async_sampled_token_ids( + async_output.sampled_token_ids_cpu, + async_output.async_copy_ready_event, + ) - return AsyncGPUModelRunnerOutput( - model_runner_output=output, - sampled_token_ids=sampler_output.sampled_token_ids, - invalid_req_indices=invalid_req_indices, - async_output_copy_stream=self.async_output_copy_stream, - ) + return async_output - def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + def take_draft_token_ids(self) -> DraftTokenIds | None: if self._draft_token_ids is None: return None req_ids = self.input_batch.req_ids @@ -2487,27 +3379,66 @@ def take_draft_token_ids(self) -> Optional[DraftTokenIds]: self._draft_token_ids = None return DraftTokenIds(req_ids, draft_token_ids) + def _copy_valid_sampled_token_count( + self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor + ) -> None: + if self.valid_sampled_token_count_event is None: + return + + default_stream = current_platform.torch_device_fn.current_stream() + # Initialize a new stream to overlap the copy operation with + # prepare_input of draft model. 
+ with torch.cuda.stream(self.valid_sampled_token_count_copy_stream): + self.valid_sampled_token_count_copy_stream.wait_stream(default_stream) # type: ignore + counts = valid_sampled_tokens_count + counts_cpu = self.valid_sampled_token_count_cpu + counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True) + self.valid_sampled_token_count_event.record() + + self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1) + + def _get_valid_sampled_token_count(self) -> list[int]: + # Wait until valid_sampled_tokens_count is copied to cpu, + prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids + if ( + self.valid_sampled_token_count_event is None + or prev_sampled_token_ids is None + ): + return [] + + counts_cpu = self.valid_sampled_token_count_cpu + self.valid_sampled_token_count_event.synchronize() + return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist() + def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", - sampled_token_ids: Union[torch.Tensor, list[list[int]]], + sampled_token_ids: torch.Tensor | list[list[int]], sampling_metadata: SamplingMetadata, hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor, - aux_hidden_states: Optional[list[torch.Tensor]], - spec_decode_metadata: Optional[SpecDecodeMetadata], + aux_hidden_states: list[torch.Tensor] | None, + spec_decode_metadata: SpecDecodeMetadata | None, common_attn_metadata: CommonAttentionMetadata, - ) -> Union[list[list[int]], torch.Tensor]: + ) -> list[list[int]] | torch.Tensor: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - if self.speculative_config.method == "ngram": + spec_config = self.speculative_config + assert spec_config is not None + if spec_config.method == "ngram": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.drafter.propose( - sampled_token_ids, self.input_batch.req_ids, + sampled_token_ids, + self.input_batch.req_ids, 
self.input_batch.num_tokens_no_spec, self.input_batch.token_ids_cpu, - self.input_batch.spec_decode_unsupported_reqs) - elif self.speculative_config.method == "medusa": + self.input_batch.spec_decode_unsupported_reqs, + ) + elif spec_config.method == "suffix": + assert isinstance(sampled_token_ids, list) + assert isinstance(self.drafter, SuffixDecodingProposer) + draft_token_ids = self.drafter.propose(self.input_batch, sampled_token_ids) + elif spec_config.method == "medusa": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) @@ -2517,10 +3448,12 @@ def propose_draft_token_ids( else: indices = [] offset = 0 - assert spec_decode_metadata is not None + assert spec_decode_metadata is not None, ( + "No spec decode metadata for medusa" + ) for num_draft, tokens in zip( - spec_decode_metadata.num_draft_tokens, - sampled_token_ids): + spec_decode_metadata.num_draft_tokens, sampled_token_ids + ): indices.append(offset + len(tokens) - 1) offset += num_draft + 1 indices = torch.tensor(indices, device=self.device) @@ -2530,79 +3463,101 @@ def propose_draft_token_ids( target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, ) - elif self.speculative_config.use_eagle(): + elif spec_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) - if self.speculative_config.disable_padded_drafter_batch: + if spec_config.disable_padded_drafter_batch: # When padded-batch is disabled, the sampled_token_ids should be # the cpu-side list[list[int]] of valid sampled tokens for each # request, with invalid requests having empty lists. - assert isinstance(sampled_token_ids, list), \ - "sampled_token_ids should be a python list when" \ + assert isinstance(sampled_token_ids, list), ( + "sampled_token_ids should be a python list when" "padded-batch is disabled." 
+ ) next_token_ids = self.drafter.prepare_next_token_ids_cpu( - sampled_token_ids, self.requests, self.input_batch, - scheduler_output.num_scheduled_tokens) + sampled_token_ids, + self.requests, + self.input_batch, + scheduler_output.num_scheduled_tokens, + ) else: # When using padded-batch, the sampled_token_ids should be # the gpu tensor of sampled tokens for each request, of shape # (num_reqs, num_spec_tokens + 1) with rejected tokens having # value -1. - assert isinstance(sampled_token_ids, torch.Tensor), \ - "sampled_token_ids should be a torch.Tensor when" \ + assert isinstance(sampled_token_ids, torch.Tensor), ( + "sampled_token_ids should be a torch.Tensor when" "padded-batch is enabled." - next_token_ids, valid_sampled_tokens_count = \ + ) + next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( common_attn_metadata, sampled_token_ids, self.requests, self.input_batch, - self.discard_request_indices.gpu, - self.num_discarded_requests + self.discard_request_mask.gpu, ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) if spec_decode_metadata is None: token_indices_to_sample = None # input_ids can be None for multimodal models. target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] - # TODO(woosuk): Support M-RoPE. 
- target_positions = self.positions.gpu[:num_scheduled_tokens] + target_positions = self._get_positions(num_scheduled_tokens) if self.use_aux_hidden_state_outputs: assert aux_hidden_states is not None target_hidden_states = torch.cat( - [h[:num_scheduled_tokens] for h in aux_hidden_states], - dim=-1) + [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1 + ) else: target_hidden_states = hidden_states[:num_scheduled_tokens] else: - if self.speculative_config.disable_padded_drafter_batch: + if spec_config.disable_padded_drafter_batch: token_indices_to_sample = None - common_attn_metadata, token_indices =\ - self.drafter.prepare_inputs( - common_attn_metadata, - sampled_token_ids, - spec_decode_metadata.num_draft_tokens) + common_attn_metadata, token_indices = self.drafter.prepare_inputs( + common_attn_metadata, + sampled_token_ids, + spec_decode_metadata.num_draft_tokens, + ) + target_token_ids = self.input_ids.gpu[token_indices] + target_positions = self._get_positions(token_indices) + if self.use_aux_hidden_state_outputs: + assert aux_hidden_states is not None + target_hidden_states = torch.cat( + [h[token_indices] for h in aux_hidden_states], dim=-1 + ) + else: + target_hidden_states = hidden_states[token_indices] else: - common_attn_metadata, token_indices, \ - token_indices_to_sample =\ + common_attn_metadata, token_indices_to_sample = ( self.drafter.prepare_inputs_padded( common_attn_metadata, spec_decode_metadata, - valid_sampled_tokens_count) + valid_sampled_tokens_count, + ) + ) + total_num_tokens = common_attn_metadata.num_actual_tokens + # When padding the batch, token_indices is just a range + target_token_ids = self.input_ids.gpu[:total_num_tokens] + target_positions = self._get_positions(total_num_tokens) + if self.use_aux_hidden_state_outputs: + assert aux_hidden_states is not None + target_hidden_states = torch.cat( + [h[:total_num_tokens] for h in aux_hidden_states], dim=-1 + ) + else: + target_hidden_states = 
hidden_states[:total_num_tokens] - target_token_ids = self.input_ids.gpu[token_indices] - # TODO(woosuk): Support M-RoPE. - target_positions = self.positions.gpu[token_indices] - if self.use_aux_hidden_state_outputs: - assert aux_hidden_states is not None - target_hidden_states = torch.cat( - [h[token_indices] for h in aux_hidden_states], dim=-1) - else: - target_hidden_states = hidden_states[token_indices] - mm_embeds = None if self.supports_mm_inputs: - mm_embeds = self._gather_mm_embeddings(scheduler_output, - shift_computed_tokens=1) + mm_embed_inputs = self._gather_mm_embeddings( + scheduler_output, + shift_computed_tokens=1, + ) + else: + mm_embed_inputs = None draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, @@ -2612,16 +3567,18 @@ def propose_draft_token_ids( last_token_indices=token_indices_to_sample, sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, - mm_embeds=mm_embeds, + mm_embed_inputs=mm_embed_inputs, ) + return draft_token_ids def update_config(self, overrides: dict[str, Any]) -> None: allowed_config_names = {"load_config", "model_config"} for config_name, config_overrides in overrides.items(): - assert config_name in allowed_config_names, \ - f"Config `{config_name}` not supported. " \ + assert config_name in allowed_config_names, ( + f"Config `{config_name}` not supported. " f"Allowed configs: {allowed_config_names}" + ) config = getattr(self, config_name) new_config = update_config(config, config_overrides) setattr(self, config_name, new_config) @@ -2631,121 +3588,208 @@ def load_model(self, eep_scale_up: bool = False) -> None: Args: eep_scale_up: the model loading is for elastic EP scale up. 
""" - logger.info("Starting to load model %s...", self.model_config.model) - if eep_scale_up: - from vllm.distributed.parallel_state import get_ep_group - num_local_physical_experts = torch.empty(1, - dtype=torch.int32, - device="cpu") - torch.distributed.broadcast(num_local_physical_experts, - group=get_ep_group().cpu_group, - group_src=0) - num_local_physical_experts = int(num_local_physical_experts.item()) - new_ep_size = get_ep_group().world_size - global_expert_load, old_global_expert_indices = ( - EplbState.recv_state()) - num_logical_experts = global_expert_load.shape[1] - self.parallel_config.eplb_config.num_redundant_experts = ( - num_local_physical_experts * new_ep_size - num_logical_experts) - assert old_global_expert_indices.shape[ - 1] % num_local_physical_experts == 0 - old_ep_size = old_global_expert_indices.shape[ - 1] // num_local_physical_experts - rank_mapping = { - old_ep_rank: old_ep_rank - for old_ep_rank in range(old_ep_size) - } - else: - global_expert_load = None - old_global_expert_indices = None - rank_mapping = None - - with DeviceMemoryProfiler() as m: - time_before_load = time.perf_counter() - model_loader = get_model_loader(self.load_config) - logger.info("Loading model from scratch...") - self.model = model_loader.load_model( - vllm_config=self.vllm_config, model_config=self.model_config) - if self.lora_config: - self.model = self.load_lora_model(self.model, self.vllm_config, - self.device) - if hasattr(self, "drafter"): - logger.info("Loading drafter model...") - self.drafter.load_model(self.model) - if self.use_aux_hidden_state_outputs: - if supports_eagle3(self.model): - self.model.set_aux_hidden_state_layers( - self.model.get_eagle3_aux_hidden_state_layers()) - else: - raise RuntimeError( - "Model does not support EAGLE3 interface but " - "aux_hidden_state_outputs was requested") - time_after_load = time.perf_counter() - self.model_memory_usage = m.consumed_memory - logger.info("Model loading took %.4f GiB and %.6f seconds", - 
self.model_memory_usage / GiB_bytes, - time_after_load - time_before_load) + logger.info_once( + "Starting to load model %s...", + self.model_config.model, + scope="global", + ) + global_expert_loads, old_global_expert_indices_per_model, rank_mapping = ( + EplbState.get_eep_state(self.parallel_config) + if eep_scale_up + else (None, None, None) + ) + + if self.parallel_config.enable_eplb: + self.eplb_state = EplbState(self.parallel_config, self.device) + eplb_models = 0 + try: + with DeviceMemoryProfiler() as m: + time_before_load = time.perf_counter() + model_loader = get_model_loader(self.load_config) + self.model = model_loader.load_model( + vllm_config=self.vllm_config, model_config=self.model_config + ) + if self.lora_config: + self.model = self.load_lora_model( + self.model, self.vllm_config, self.device + ) + if hasattr(self, "drafter"): + logger.info_once("Loading drafter model...") + self.drafter.load_model(self.model) + if ( + hasattr(self.drafter, "model") + and is_mixture_of_experts(self.drafter.model) + and self.parallel_config.enable_eplb + ): + spec_config = self.vllm_config.speculative_config + assert spec_config is not None + assert spec_config.draft_model_config is not None + logger.info_once( + "EPLB is enabled for drafter model %s.", + spec_config.draft_model_config.model, + ) + + global_expert_load = ( + global_expert_loads[eplb_models] + if global_expert_loads + else None + ) + old_global_expert_indices = ( + old_global_expert_indices_per_model[eplb_models] + if old_global_expert_indices_per_model + else None + ) + if self.eplb_state is None: + self.eplb_state = EplbState( + self.parallel_config, self.device + ) + self.eplb_state.add_model( + self.drafter.model, + spec_config.draft_model_config, + global_expert_load, + old_global_expert_indices, + rank_mapping, + ) + eplb_models += 1 + + if self.use_aux_hidden_state_outputs: + if not supports_eagle3(self.get_model()): + raise RuntimeError( + "Model does not support EAGLE3 interface but " + 
"aux_hidden_state_outputs was requested" + ) + + # Try to get auxiliary layers from speculative config, + # otherwise use model's default layers + aux_layers = self._get_eagle3_aux_layers_from_config() + if aux_layers: + logger.info( + "Using auxiliary layers from speculative config: %s", + aux_layers, + ) + else: + aux_layers = self.model.get_eagle3_aux_hidden_state_layers() + + self.model.set_aux_hidden_state_layers(aux_layers) + time_after_load = time.perf_counter() + self.model_memory_usage = m.consumed_memory + except current_platform.torch_device_fn.OutOfMemoryError as e: + msg = ( + "Failed to load model - not enough GPU memory. " + "Try lowering --gpu-memory-utilization to free memory for weights, " + "increasing --tensor-parallel-size, or using --quantization. " + "See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ " + "for more tips." + ) + combined_msg = f"{msg} (original error: {e})" + logger.error(combined_msg) + raise e + logger.info_once( + "Model loading took %.4f GiB memory and %.6f seconds", + self.model_memory_usage / GiB_bytes, + time_after_load - time_before_load, + scope="local", + ) prepare_communication_buffer_for_model(self.model) + if (drafter := getattr(self, "drafter", None)) and ( + drafter_model := getattr(drafter, "model", None) + ): + prepare_communication_buffer_for_model(drafter_model) + mm_config = self.model_config.multimodal_config + self.is_multimodal_pruning_enabled = ( + supports_multimodal_pruning(self.get_model()) + and mm_config is not None + and mm_config.is_multimodal_pruning_enabled() + ) - self.is_multimodal_pruning_enabled = (supports_multimodal_pruning( - self.model) and self.model_config.multimodal_config. 
- is_multimodal_pruning_enabled()) - - if is_mixture_of_experts( - self.model) and self.parallel_config.enable_eplb: - logger.info("EPLB is enabled for model %s.", - self.model_config.model) - self.eplb_state = EplbState.build( + if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb: + logger.info_once("EPLB is enabled for model %s.", self.model_config.model) + global_expert_load = ( + global_expert_loads[eplb_models] if global_expert_loads else None + ) + old_global_expert_indices = ( + old_global_expert_indices_per_model[eplb_models] + if old_global_expert_indices_per_model + else None + ) + assert self.eplb_state is not None + self.eplb_state.add_model( self.model, - self.device, - self.parallel_config, + self.model_config, global_expert_load, old_global_expert_indices, rank_mapping, ) + if self.eplb_state.is_async: + self.eplb_state.start_async_loop(rank_mapping=rank_mapping) if ( - self.vllm_config.compilation_config.level == \ - CompilationLevel.DYNAMO_AS_IS and supports_dynamo() + self.vllm_config.compilation_config.mode + == CompilationMode.STOCK_TORCH_COMPILE + and supports_dynamo() ): - backend = self.vllm_config.compilation_config.init_backend( - self.vllm_config) - compilation_counter.dynamo_as_is_count += 1 + backend = self.vllm_config.compilation_config.init_backend(self.vllm_config) + compilation_counter.stock_torch_compile_count += 1 self.model.compile(fullgraph=True, backend=backend) return - # for other compilation levels, cudagraph behavior is controlled by + # for other compilation modes, cudagraph behavior is controlled by # CudagraphWraper and CudagraphDispatcher of vllm. # wrap the model with full cudagraph wrapper if needed. 
- if self.compilation_config.cudagraph_mode.has_full_cudagraphs() \ - and not self.parallel_config.enable_dbo: - self.model = GraphWrapper(self.model, - self.vllm_config, - runtime_mode=CUDAGraphMode.FULL) - ### TODO(lms): support batch wrapper - # elif self.parallel_config.enable_dbo: - # if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - # self.model = UBatchWrapper(self.model, self.vllm_config, - # CUDAGraphMode.FULL, self.device) - # else: - # self.model = UBatchWrapper(self.model, self.vllm_config, - # CUDAGraphMode.NONE, self.device) + cudagraph_mode = self.compilation_config.cudagraph_mode + assert cudagraph_mode is not None + if self.compilation_config.cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo: + self.model = GraphWrapper( + self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) + elif self.parallel_config.enable_dbo: + if cudagraph_mode.has_full_cudagraphs(): + self.model = UBatchWrapper( + self.model, self.vllm_config, CUDAGraphMode.FULL, self.device + ) + else: + self.model = UBatchWrapper( + self.model, self.vllm_config, CUDAGraphMode.NONE, self.device + ) + + def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None: + """Extract Eagle3 auxiliary layer indices from speculative config. + + These indices specify which hidden states from the base model should + be used as auxiliary inputs for the Eagle3 drafter model during + speculative decoding. + + Returns: + Tuple of layer indices if found in draft model config, + None otherwise. 
+ """ + if not (self.speculative_config and self.speculative_config.draft_model_config): + return None + + hf_config = self.speculative_config.draft_model_config.hf_config + if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"): + return None + + layer_ids = hf_config.eagle_aux_hidden_state_layer_ids + if layer_ids and isinstance(layer_ids, (list, tuple)): + return tuple(layer_ids) + + return None def reload_weights(self) -> None: - assert getattr(self, "model", None) is not None, \ + assert getattr(self, "model", None) is not None, ( "Cannot reload weights before model is loaded." + ) model_loader = get_model_loader(self.load_config) logger.info("Reloading weights inplace...") - model = self.get_model() - model_loader.load_weights(model, model_config=self.model_config) + model_loader.load_weights(self.get_model(), model_config=self.model_config) def save_tensorized_model( self, tensorizer_config: "TensorizerConfig", ) -> None: - model = self.get_model() TensorizerLoader.save_model( - model, + self.get_model(), tensorizer_config=tensorizer_config, model_config=self.model_config, ) @@ -2754,19 +3798,22 @@ def _get_prompt_logprobs_dict( self, hidden_states: torch.Tensor, num_scheduled_tokens: dict[str, int], - ) -> dict[str, Optional[LogprobsTensors]]: - num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs + ) -> dict[str, LogprobsTensors | None]: + num_prompt_logprobs_dict = self.num_prompt_logprobs if not num_prompt_logprobs_dict: return {} in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu - prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} + prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {} # Since prompt logprobs are a rare feature, prioritize simple, # maintainable loop over optimal performance. 
completed_prefill_reqs = [] for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items(): - num_tokens = num_scheduled_tokens[req_id] + num_tokens = num_scheduled_tokens.get(req_id) + if num_tokens is None: + # This can happen if the request was preempted in prefill stage. + continue # Get metadata for this request. request = self.requests[req_id] @@ -2776,7 +3823,8 @@ def _get_prompt_logprobs_dict( num_prompt_tokens = len(request.prompt_token_ids) prompt_token_ids = torch.tensor(request.prompt_token_ids).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) # Set up target LogprobsTensors object. logprobs_tensors = in_progress_dict.get(req_id) @@ -2784,7 +3832,8 @@ def _get_prompt_logprobs_dict( # Create empty logprobs CPU tensors for the entire prompt. # If chunked, we'll copy in slice by slice. logprobs_tensors = LogprobsTensors.empty_cpu( - num_prompt_tokens - 1, num_prompt_logprobs + 1) + num_prompt_tokens - 1, num_prompt_logprobs + 1 + ) in_progress_dict[req_id] = logprobs_tensors # Determine number of logits to retrieve. @@ -2814,27 +3863,29 @@ def _get_prompt_logprobs_dict( # then there is prompt logprob generated for each index. req_idx = self.input_batch.req_id_to_index[req_id] offset = self.query_start_loc.np[req_idx].item() - prompt_hidden_states = hidden_states[offset:offset + num_logits] + prompt_hidden_states = hidden_states[offset : offset + num_logits] logits = self.model.compute_logits(prompt_hidden_states) # Get the "target" tokens for each index. For prompt at index i, # the token at prompt index i+1 is the "sampled" token we want # to gather the logprob for. - tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits] + tgt_token_ids = prompt_token_ids[start_tok : start_tok + num_logits] # Compute prompt logprobs. 
logprobs = self.sampler.compute_logprobs(logits) token_ids, logprobs, ranks = self.sampler.gather_logprobs( - logprobs, num_prompt_logprobs, tgt_token_ids) + logprobs, num_prompt_logprobs, tgt_token_ids + ) # Transfer GPU->CPU async. chunk_slice = slice(start_idx, start_idx + num_logits) logprobs_tensors.logprob_token_ids[chunk_slice].copy_( - token_ids, non_blocking=True) - logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, - non_blocking=True) + token_ids, non_blocking=True + ) + logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, non_blocking=True) logprobs_tensors.selected_token_ranks[chunk_slice].copy_( - ranks, non_blocking=True) + ranks, non_blocking=True + ) # Remove requests that have completed prefill from the batch # num_prompt_logprobs_dict. @@ -2850,7 +3901,7 @@ def _get_prompt_logprobs_dict( def _get_nans_in_logits( self, - logits: Optional[torch.Tensor], + logits: torch.Tensor | None, ) -> dict[str, int]: try: if logits is None: @@ -2862,26 +3913,29 @@ def _get_nans_in_logits( req_index = self.input_batch.req_id_to_index[req_id] num_nans_in_logits[req_id] = ( int(num_nans_for_index[req_index]) - if num_nans_for_index is not None - and req_index < logits.shape[0] else 0) + if num_nans_for_index is not None and req_index < logits.shape[0] + else 0 + ) return num_nans_in_logits except IndexError: return {} @contextmanager - def maybe_randomize_inputs(self, input_ids: torch.Tensor): + def maybe_randomize_inputs( + self, input_ids: torch.Tensor | None, inputs_embeds: torch.Tensor | None + ): """ Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set. 
This is to help balance expert-selection - during profile_run - during DP rank dummy run """ + dp_size = self.vllm_config.parallel_config.data_parallel_size randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1 if not randomize_inputs: yield - else: - import functools + elif input_ids is not None: @functools.cache def rand_input_ids() -> torch.Tensor: @@ -2889,13 +3943,27 @@ def rand_input_ids() -> torch.Tensor: self.input_ids.gpu, low=0, high=self.model_config.get_vocab_size(), - dtype=input_ids.dtype) + ) - logger.debug_once("Randomizing dummy data for DP Rank") - input_ids.copy_(rand_input_ids()[:input_ids.size(0)], - non_blocking=True) + logger.debug_once("Randomizing dummy input_ids for DP Rank") + input_ids.copy_(rand_input_ids()[: input_ids.size(0)], non_blocking=True) yield input_ids.fill_(0) + else: + + @functools.cache + def rand_inputs_embeds() -> torch.Tensor: + return torch.randn_like( + self.inputs_embeds.gpu, + ) + + assert inputs_embeds is not None + logger.debug_once("Randomizing dummy inputs_embeds for DP Rank") + inputs_embeds.copy_( + rand_inputs_embeds()[: inputs_embeds.size(0)], non_blocking=True + ) + yield + inputs_embeds.fill_(0) def _get_mm_dummy_batch( self, @@ -2917,20 +3985,20 @@ def _get_mm_dummy_batch( dummy_mm_item = dummy_mm_data[modality][0] dummy_mm_items = [dummy_mm_item] * max_items_per_batch - model = cast(SupportsMultiModal, self.model) - return next(mm_kwargs_group - for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - dummy_mm_items, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, - )) + return next( + mm_kwargs_group + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + dummy_mm_items, + device=self.device, + pin_memory=self.pin_memory, + ) + ) @torch.inference_mode() def _dummy_run( self, num_tokens: int, - cudagraph_runtime_mode: Optional[CUDAGraphMode] = None, + cudagraph_runtime_mode: CUDAGraphMode | None = None, force_attention: 
bool = False, uniform_decode: bool = False, allow_microbatching: bool = True, @@ -2938,6 +4006,8 @@ def _dummy_run( is_profile: bool = False, create_mixed_batch: bool = False, remove_lora: bool = True, + activate_lora: bool = False, + is_graph_capturing: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ Run a dummy forward pass to warm up/profile run or capture the @@ -2960,10 +4030,12 @@ def _dummy_run( create_mixed_batch: If True, create a mixed batch with both decode (1 token) and prefill (multiple tokens) requests. remove_lora: If False, dummy LoRAs are not destroyed after the run + activate_lora: If False, dummy_run is performed without LoRAs. """ - assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in { - CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL - } + assert ( + cudagraph_runtime_mode is None + or cudagraph_runtime_mode.valid_runtime_modes() + ) # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.separate_routine(). This means that we are using @@ -2978,8 +4050,7 @@ def _dummy_run( # When setting max_query_len = 1, we switch to and capture the optimized # routine of FA2 for pure decode, i.e., Flashdecode + an optimization # for GQA/MQA. 
- max_query_len = self.uniform_decode_query_len if uniform_decode else \ - num_tokens + max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively @@ -2990,19 +4061,17 @@ def _dummy_run( assert not uniform_decode # Create mixed batch: # first half decode tokens, second half one prefill - num_decode_tokens = num_tokens // 2 + num_decode_tokens = min(max_num_reqs - 1, num_tokens // 2) num_prefill_tokens = num_tokens - num_decode_tokens num_reqs = num_decode_tokens + 1 # Create decode requests (1 token each) followed by prefill request - num_scheduled_tokens_list = [1] * num_decode_tokens + [ - num_prefill_tokens - ] + num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens] # Note: Overriding max_query_len to be the prefill tokens max_query_len = num_prefill_tokens elif uniform_decode: assert not create_mixed_batch - num_reqs = cdiv(num_tokens, max_query_len) + num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len)) num_scheduled_tokens_list = [max_query_len] * num_reqs if num_tokens % max_query_len != 0: num_scheduled_tokens_list[-1] = num_tokens % max_query_len @@ -3014,130 +4083,109 @@ def _dummy_run( assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs - num_scheduled_tokens = np.array(num_scheduled_tokens_list, - dtype=np.int32) - total_num_scheduled_tokens = int(num_scheduled_tokens.sum()) - - ubatch_slices = None - num_tokens_after_padding = None - - # We currently only microbatch if the number of tokens is - # over a certain threshold. 
- if self.parallel_config.enable_dbo and allow_microbatching: - ubatch_slices, ubatch_num_tokens_after_padding = ubatch_split( - num_scheduled_tokens, - total_num_scheduled_tokens, - total_num_scheduled_tokens, - uniform_decode=uniform_decode, - vllm_config=self.vllm_config, + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + num_tokens_unpadded = int(num_scheduled_tokens.sum()) + + num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) + + _cudagraph_mode, batch_desc, should_ubatch, num_tokens_across_dp, _ = ( + self._determine_batch_execution_and_padding( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs, + num_scheduled_tokens_np=num_scheduled_tokens, + max_num_scheduled_tokens=max_query_len, + use_cascade_attn=False, + allow_microbatching=allow_microbatching, + force_eager=is_profile + or (cudagraph_runtime_mode == CUDAGraphMode.NONE), + # `force_uniform_decode` is used for cudagraph capture; because for + # capturing mixed prefill-decode batches, we sometimes use + # num_tokens == num_reqs which looks like a uniform decode batch to the + # dispatcher; but we actually want to capture a piecewise cudagraph + force_uniform_decode=uniform_decode, + # `force_has_lora` is used for cudagraph capture; because LoRA is + # activated later in the context manager, but we need to know the + # LoRA state when determining the batch descriptor for capture + force_has_lora=activate_lora, ) - # Currently when DBO is enabled `ubatch_split` returns - # the num_tokens_after_padding for a single ubatch, but we have 2 - # TODO(sage,lucas): this is cruft that should be addressed in the - # padding refactor. 
- if ubatch_num_tokens_after_padding is not None: - num_tokens_after_padding = ubatch_num_tokens_after_padding * 2 - + ) - # If we failed to microbatch, currently need to resynchronize - # TODO(lucas,sage): we should be able to avoid this second sync by - # refactoring `get_dp_padding_ubatch` and `get_dp_padding` into - # a single `coordinate_batch_across_dp` function. - if num_tokens_after_padding is None: - num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) - num_tokens_after_padding = num_tokens + num_pad + if cudagraph_runtime_mode is None: + cudagraph_runtime_mode = _cudagraph_mode else: - num_tokens_across_dp = num_tokens_after_padding - num_tokens_after_padding = int(num_tokens_after_padding[0].item()) + assert cudagraph_runtime_mode == _cudagraph_mode, ( + f"Cudagraph runtime mode mismatch in dummy_run. " + f"Expected {_cudagraph_mode}, but got {cudagraph_runtime_mode}." + ) - attn_metadata: Optional[PerLayerAttnMetadata] = None + num_tokens_padded = batch_desc.num_tokens + num_reqs_padded = ( + batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + ) + ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( + should_ubatch, num_scheduled_tokens, num_tokens_padded, num_reqs_padded + ) + + attn_metadata: PerLayerAttnMetadata | None = None # If force_attention is True, we always capture attention. Otherwise, # it only happens for cudagraph_runtime_mode=FULL. if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: - attn_metadata = {} - # if ubatch_slices is not None: - # attn_metadata = [dict() for _ in range(len(ubatch_slices))] - if create_mixed_batch: # In the mixed batch mode (used for FI warmup), we use # shorter sequence lengths to run faster. 
# TODO(luka) better system for describing dummy batches seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] else: - seq_lens = max_query_len + seq_lens = max_query_len # type: ignore[assignment] self.seq_lens.np[:num_reqs] = seq_lens self.seq_lens.np[num_reqs:] = 0 self.seq_lens.copy_to_gpu() - cum_num_tokens, _ = self._get_cumsum_and_arange( - num_scheduled_tokens) - self.query_start_loc.np[1:num_reqs + 1] = cum_num_tokens + cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens self.query_start_loc.copy_to_gpu() - for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + - 1], - seq_lens=self.seq_lens.gpu[:num_reqs], - seq_lens_cpu=self.seq_lens.cpu[:num_reqs], - num_computed_tokens_cpu=self.input_batch. - num_computed_tokens_cpu_tensor[:num_reqs], - num_reqs=num_reqs, - num_actual_tokens=num_tokens, - max_query_len=max_query_len, - max_seq_len=self.max_model_len, - block_table_tensor=self.input_batch. 
- block_table[kv_cache_group_id].get_device_tensor(num_reqs), - slot_mapping=self.input_batch.block_table[ - kv_cache_group_id].slot_mapping.gpu[:num_tokens], - causal=True) - for attn_group in self.attn_groups[kv_cache_group_id]: - # if ubatch_slices is not None: - # common_attn_metadata_list = split_attn_metadata( - # ubatch_slices, common_attn_metadata) - # for ubid, common_attn_metadata in enumerate( - # common_attn_metadata_list): - # assert common_attn_metadata.max_query_len == 1 - # attn_metadata_i = (attn_group\ - # .get_metadata_builder(ubatch_id=ubid)\ - # .build_for_cudagraph_capture(common_attn_metadata)) - # for layer_name in attn_group.layer_names: - # assert type(attn_metadata) is list - # attn_metadata[ubid][ - # layer_name] = attn_metadata_i - # else: - assert type(attn_metadata) is dict - attn_metadata_i = attn_group.get_metadata_builder()\ - .build_for_cudagraph_capture(common_attn_metadata) - for layer_name in attn_group.layer_names: - attn_metadata[layer_name] = attn_metadata_i - - with self.maybe_dummy_run_with_lora(self.lora_config, - num_scheduled_tokens, remove_lora): - model_kwargs = self._init_model_kwargs(num_tokens) - if (self.supports_mm_inputs - and not self.model_config.is_encoder_decoder): + pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL + attn_metadata, _ = self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs_padded, + max_query_len=max_query_len, + ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices, + for_cudagraph_capture=is_graph_capturing, + ) + + with self.maybe_dummy_run_with_lora( + self.lora_config, + num_scheduled_tokens, + num_sampled_tokens, + activate_lora, + remove_lora, + ): + # Make sure padding doesn't exceed max_num_tokens + assert num_tokens_padded <= self.max_num_tokens + model_kwargs = self._init_model_kwargs(num_tokens_padded) + if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: input_ids = None - inputs_embeds = 
self.inputs_embeds.gpu[:num_tokens] + inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] model_kwargs = { **model_kwargs, **self._dummy_mm_kwargs(num_reqs), } elif self.enable_prompt_embeds: input_ids = None - inputs_embeds = self.inputs_embeds.gpu[:num_tokens] - model_kwargs = self._init_model_kwargs(num_tokens) + inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] + model_kwargs = self._init_model_kwargs(num_tokens_padded) else: - input_ids = self.input_ids.gpu[:num_tokens] + input_ids = self.input_ids.gpu[:num_tokens_padded] inputs_embeds = None if self.uses_mrope: - positions = self.mrope_positions.gpu[:, :num_tokens] + positions = self.mrope_positions.gpu[:, :num_tokens_padded] + elif self.uses_xdrope_dim > 0: + positions = self.xdrope_positions.gpu[:, :num_tokens_padded] else: - positions = self.positions.gpu[:num_tokens] + positions = self.positions.gpu[:num_tokens_padded] if get_pp_group().is_first_rank: intermediate_tensors = None @@ -3147,42 +4195,34 @@ def _dummy_run( self.model.make_empty_intermediate_tensors( batch_size=self.max_num_tokens, dtype=self.model_config.dtype, - device=self.device)) + device=self.device, + ) + ) intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_tokens, None, False) - - # filter out the valid batch descriptor - _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( - BatchDescriptor(num_tokens=num_tokens_after_padding, - uniform_decode=uniform_decode)) \ - if not is_profile else (CUDAGraphMode.NONE, None) - if cudagraph_runtime_mode is not None: - # we allow forcing NONE when the dispatcher disagrees to support - # warm ups for cudagraph capture - assert cudagraph_runtime_mode == CUDAGraphMode.NONE or \ - cudagraph_runtime_mode == _cg_mode, ( - f"Cudagraph runtime mode mismatch at dummy_run. " - f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.") - else: - cudagraph_runtime_mode = _cg_mode - - # if ubatch_slices is not None: - # # Adjust values to reflect a single ubatch. 
- # # TODO(sage,lucas): this is cruft that should be addressed in - # # the padding refactor. - # num_tokens_after_padding = ubatch_slices[0].num_tokens - # if num_tokens_across_dp is not None: - # num_tokens_across_dp[:] = num_tokens_after_padding + num_tokens_padded, None, False + ) - with self.maybe_randomize_inputs(input_ids), set_forward_context( + if ubatch_slices_padded is not None: + # Adjust values to reflect a single ubatch. + # TODO(sage,lucas): this is cruft that should be addressed in + # the padding refactor. + num_tokens_padded = ubatch_slices_padded[0].num_tokens + if num_tokens_across_dp is not None: + num_tokens_across_dp[:] = num_tokens_padded + + with ( + self.maybe_randomize_inputs(input_ids, inputs_embeds), + set_forward_context( attn_metadata, self.vllm_config, - num_tokens=num_tokens_after_padding, + num_tokens=num_tokens_padded, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ubatch_slices=ubatch_slices): + batch_descriptor=batch_desc, + ubatch_slices=ubatch_slices_padded, + ), + ): outputs = self.model( input_ids=input_ids, positions=positions, @@ -3198,7 +4238,43 @@ def _dummy_run( if self.speculative_config and self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) - self.drafter.dummy_run(num_tokens) + # Eagle currently only supports PIECEWISE cudagraphs. + # Therefore only use cudagraphs if the main model uses PIECEWISE + # NOTE(lucas): this is a hack, need to clean up. + use_cudagraphs = ( + ( + is_graph_capturing + and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE + ) + or ( + not is_graph_capturing + and cudagraph_runtime_mode != CUDAGraphMode.NONE + ) + ) and not self.speculative_config.enforce_eager + + # Note(gnovack) - We need to disable cudagraphs for one of the two + # lora cases when cudagraph_specialize_lora is enabled. 
This is a + # short term mitigation for issue mentioned in + # https://github.com/vllm-project/vllm/issues/28334 + if self.compilation_config.cudagraph_specialize_lora and activate_lora: + use_cudagraphs = False + + self.drafter.dummy_run( + num_tokens, + use_cudagraphs=use_cudagraphs, + is_graph_capturing=is_graph_capturing, + ) + + # We register layerwise NVTX hooks here after the first dynamo tracing is + # done to avoid nvtx operations in hook functions being traced by + # torch dynamo and causing graph breaks. + # Note that for DYNAMO_ONCE and VLLM_COMPILE mode, + # compiled model's dynamo tracing is only done once and the compiled model's + # __call__ function is replaced by calling the compiled function. + # So it's safe to register hooks here. Hooks will be registered to + # both compiled and uncompiled models but they will never + # be called on the compiled model execution path. + self._register_layerwise_nvtx_hooks() # This is necessary to avoid blocking DP. # For dummy runs, we typically skip EPLB since we don't have any real @@ -3211,7 +4287,10 @@ def _dummy_run( self.eplb_step(is_dummy=True, is_profile=is_profile) logit_indices = np.cumsum(num_scheduled_tokens) - 1 - return hidden_states, hidden_states[logit_indices] + logit_indices_device = torch.from_numpy(logit_indices).to( + self.device, non_blocking=True + ) + return hidden_states, hidden_states[logit_indices_device] @torch.inference_mode() def _dummy_sampler_run( @@ -3226,8 +4305,7 @@ def _dummy_sampler_run( logits = self.model.compute_logits(hidden_states) num_reqs = logits.size(0) - dummy_tensors = lambda v: torch.full( - (num_reqs, ), v, device=self.device) + dummy_tensors = lambda v: torch.full((num_reqs,), v, device=self.device) dummy_metadata = SamplingMetadata( temperature=dummy_tensors(0.5), @@ -3243,48 +4321,46 @@ def _dummy_sampler_run( presence_penalties=dummy_tensors(0.1), repetition_penalties=dummy_tensors(0.1), output_token_ids=[[] for _ in range(num_reqs)], + spec_token_ids=[[] 
for _ in range(num_reqs)], allowed_token_ids_mask=None, bad_words_token_ids={}, logitsprocs=LogitsProcessors(), ) - try: - sampler_output = self.sampler(logits=logits, - sampling_metadata=dummy_metadata) + sampler_output = self.sampler( + logits=logits, sampling_metadata=dummy_metadata + ) except RuntimeError as e: - if 'out of memory' in str(e): + if "out of memory" in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up sampler with " f"{num_reqs} dummy requests. Please try lowering " "`max_num_seqs` or `gpu_memory_utilization` when " - "initializing the engine.") from e + "initializing the engine." + ) from e else: raise e if self.speculative_config: draft_token_ids = [[0] for _ in range(num_reqs)] dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy( - draft_token_ids, self.device) + draft_token_ids, self.device + ) num_tokens = sum(len(ids) for ids in draft_token_ids) # draft_probs = torch.randn( # num_tokens, logits.shape[-1], device=self.device, # dtype=logits.dtype) draft_probs = None - target_logits = torch.randn(num_tokens, - logits.shape[-1], - device=self.device, - dtype=logits.dtype) - # NOTE(woosuk): Here, we should use int32 because the sampler uses - # int32 for bonus_token_ids. If the dtype mismatches, re-compilation - # will occur at runtime. 
- bonus_token_ids = torch.zeros(num_reqs, - device=self.device, - dtype=torch.int32) + logits = torch.randn( + num_tokens + num_reqs, + logits.shape[-1], + device=self.device, + dtype=logits.dtype, + ) self.rejection_sampler( dummy_spec_decode_metadata, draft_probs, - target_logits, - bonus_token_ids, + logits, dummy_metadata, ) return sampler_output @@ -3309,9 +4385,9 @@ def _dummy_pooler_run_task( num_scheduled_tokens_list, device="cpu", ) - dummy_token_ids = torch.zeros((num_reqs, req_num_tokens), - dtype=torch.int32, - device=self.device) + dummy_token_ids = torch.zeros( + (num_reqs, req_num_tokens), dtype=torch.int32, device=self.device + ) model = cast(VllmModelForPooling, self.get_model()) dummy_pooling_params = PoolingParams(task=task) @@ -3323,21 +4399,27 @@ def _dummy_pooler_run_task( prompt_lens=dummy_prompt_lens, prompt_token_ids=dummy_token_ids, pooling_params=[dummy_pooling_params] * num_reqs, + pooling_states=[PoolingStates() for i in range(num_reqs)], ) - dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list, - device=hidden_states.device) + dummy_metadata.build_pooling_cursor( + num_scheduled_tokens_list, + seq_lens_cpu=dummy_prompt_lens, + device=hidden_states.device, + ) try: - return model.pooler(hidden_states=hidden_states, - pooling_metadata=dummy_metadata) + return model.pooler( + hidden_states=hidden_states, pooling_metadata=dummy_metadata + ) except RuntimeError as e: - if 'out of memory' in str(e): + if "out of memory" in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up pooler " f"({task=}) with {num_reqs} dummy requests. Please try " "lowering `max_num_seqs` or `gpu_memory_utilization` when " - "initializing the engine.") from e + "initializing the engine." 
+ ) from e else: raise e @@ -3347,8 +4429,18 @@ def _dummy_pooler_run( hidden_states: torch.Tensor, ) -> PoolerOutput: # Find the task that has the largest output for subsequent steps + supported_pooling_tasks = self.get_supported_pooling_tasks() + + if not supported_pooling_tasks: + raise RuntimeError( + f"Model {self.model_config.model} does not support " + "any pooling tasks. See " + "https://docs.vllm.ai/en/latest/models/pooling_models.html " + "to learn more." + ) + output_size = dict[PoolingTask, float]() - for task in self.get_supported_pooling_tasks(): + for task in supported_pooling_tasks: # Run a full batch with each task to ensure none of them OOMs output = self._dummy_pooler_run_task(hidden_states, task) output_size[task] = sum(o.nbytes for o in output) @@ -3360,10 +4452,12 @@ def _dummy_pooler_run( def profile_run(self) -> None: # Profile with multimodal encoder & encoder cache. if self.supports_mm_inputs: - if self.model_config.multimodal_config.skip_mm_profiling: + mm_config = self.model_config.multimodal_config + if mm_config is not None and mm_config.skip_mm_profiling: logger.info( "Skipping memory profiling for multimodal encoder and " - "encoder cache.") + "encoder cache." + ) else: mm_budget = self.mm_budget assert mm_budget is not None @@ -3373,8 +4467,9 @@ def profile_run(self) -> None: # modality with the max possible input tokens even when # it supports multiple. dummy_modality = mm_budget.get_modality_with_max_tokens() - max_mm_items_per_batch = mm_budget \ - .max_items_per_batch_by_modality[dummy_modality] + max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[ + dummy_modality + ] logger.info( "Encoder cache will be initialized with a budget of " @@ -3392,39 +4487,21 @@ def profile_run(self) -> None: ) # Run multimodal encoder. 
- dummy_encoder_outputs = \ - self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) + dummy_encoder_outputs = self.model.embed_multimodal( + **batched_dummy_mm_inputs + ) sanity_check_mm_encoder_outputs( dummy_encoder_outputs, expected_num_items=max_mm_items_per_batch, ) - - # NOTE: This happens when encoder cache needs to store - # the embeddings that encoder outputs are scattered onto. - # In this case we create dummy embeddings of size - # (encode_budget, hidden_size) and scatter encoder - # output into it. - encoder_output_shape = dummy_encoder_outputs[0].shape - if encoder_output_shape[0] < encoder_budget: - expanded_outputs = [] - for output in dummy_encoder_outputs: - expanded = output.new_zeros( - (encoder_budget, encoder_output_shape[-1])) - num_tokens = output.shape[0] - expanded[:num_tokens].copy_(output) - expanded_outputs.append(expanded) - - dummy_encoder_outputs = expanded_outputs - - # Cache the dummy encoder outputs. - self.encoder_cache["tmp"] = dict( - enumerate(dummy_encoder_outputs)) + for i, output in enumerate(dummy_encoder_outputs): + self.encoder_cache[f"tmp_{i}"] = output # Add `is_profile` here to pre-allocate communication buffers - hidden_states, last_hidden_states \ - = self._dummy_run(self.max_num_tokens, is_profile=True) + hidden_states, last_hidden_states = self._dummy_run( + self.max_num_tokens, is_profile=True + ) if get_pp_group().is_last_rank: if self.is_pooling_model: output = self._dummy_pooler_run(hidden_states) @@ -3441,15 +4518,13 @@ def capture_model(self) -> int: if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: logger.warning( "Skipping CUDA graph capture. 
To turn on CUDA graph capture, " - "ensure `cudagraph_mode` was not manually set to `NONE`") + "ensure `cudagraph_mode` was not manually set to `NONE`" + ) return 0 - else: - self.initialize_cudagraph_capture() compilation_counter.num_gpu_runner_capture_triggers += 1 start_time = time.perf_counter() - start_free_gpu_memory = current_platform.torch_device_fn.mem_get_info()[0] @contextmanager def freeze_gc(): @@ -3472,34 +4547,54 @@ def freeze_gc(): # can reuse the memory pool allocated for the large shapes. set_cudagraph_capturing_enabled(True) with freeze_gc(), graph_capture(device=self.device): + start_free_gpu_memory = current_platform.torch_device_fn.mem_get_info()[0] cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None + + if self.lora_config: + if self.compilation_config.cudagraph_specialize_lora: + lora_cases = [True, False] + else: + lora_cases = [True] + else: + lora_cases = [False] if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: cudagraph_runtime_mode = cudagraph_mode.mixed_mode() - - compilation_cases = list(reversed(self.cudagraph_batch_sizes)) + # make sure we capture the largest batch size first + compilation_cases = list( + product(reversed(self.cudagraph_batch_sizes), lora_cases) + ) self._capture_cudagraphs( compilation_cases, cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=False) + uniform_decode=False, + ) # Capture full cudagraph for uniform decode batches if we # don't already have full mixed prefill-decode cudagraphs. 
- if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ - cudagraph_mode.separate_routine(): - max_num_tokens = self.scheduler_config.max_num_seqs * \ - self.uniform_decode_query_len + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and cudagraph_mode.separate_routine() + ): + max_num_tokens = ( + self.scheduler_config.max_num_seqs * self.uniform_decode_query_len + ) decode_cudagraph_batch_sizes = [ - x for x in self.cudagraph_batch_sizes if - x <= max_num_tokens and x >= self.uniform_decode_query_len + x + for x in self.cudagraph_batch_sizes + if max_num_tokens >= x >= self.uniform_decode_query_len ] compilation_cases_decode = list( - reversed(decode_cudagraph_batch_sizes)) + product(reversed(decode_cudagraph_batch_sizes), lora_cases) + ) self._capture_cudagraphs( compilation_cases=compilation_cases_decode, cudagraph_runtime_mode=CUDAGraphMode.FULL, - uniform_decode=True) + uniform_decode=True, + ) + current_platform.torch_device_fn.synchronize() + end_free_gpu_memory = current_platform.torch_device_fn.mem_get_info()[0] # Disable cudagraph capturing globally, so any unexpected cudagraph # capturing will be detected and raise an error after here. # Note: We don't put it into graph_capture context manager because @@ -3507,21 +4602,32 @@ def freeze_gc(): # after here. set_cudagraph_capturing_enabled(False) + # Lock workspace to prevent resizing during execution. + # Max workspace sizes should have been captured during warmup/profiling. + lock_workspace() + end_time = time.perf_counter() - end_free_gpu_memory = current_platform.torch_device_fn.mem_get_info()[0] elapsed_time = end_time - start_time cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory # This usually takes 5~20 seconds. 
- logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", - elapsed_time, cuda_graph_size / (1 << 30)) + logger.info_once( + "Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, + cuda_graph_size / (1 << 30), + scope="local", + ) return cuda_graph_size - def _capture_cudagraphs(self, compilation_cases: list[int], - cudagraph_runtime_mode: CUDAGraphMode, - uniform_decode: bool): - assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \ - cudagraph_runtime_mode in [CUDAGraphMode.FULL, - CUDAGraphMode.PIECEWISE] + def _capture_cudagraphs( + self, + compilation_cases: list[tuple[int, bool]], + cudagraph_runtime_mode: CUDAGraphMode, + uniform_decode: bool, + ): + assert ( + cudagraph_runtime_mode != CUDAGraphMode.NONE + and cudagraph_runtime_mode.valid_runtime_modes() + ), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}" # Only rank 0 should print progress bar during capture if is_global_first_rank(): @@ -3530,22 +4636,26 @@ def _capture_cudagraphs(self, compilation_cases: list[int], disable=not self.load_config.use_tqdm_on_load, desc="Capturing CUDA graphs ({}, {})".format( "decode" if uniform_decode else "mixed prefill-decode", - cudagraph_runtime_mode.name)) + cudagraph_runtime_mode.name, + ), + ) # We skip EPLB here since we don't want to record dummy metrics - for num_tokens in compilation_cases: + for num_tokens, activate_lora in compilation_cases: # We currently only capture ubatched graphs when its a FULL # cudagraph, a uniform decode batch, and the number of tokens # is above the threshold. 
Otherwise we just capture a non-ubatched # version of the graph - allow_microbatching = self.parallel_config.enable_dbo \ - and cudagraph_runtime_mode == CUDAGraphMode.FULL \ - and uniform_decode \ + allow_microbatching = ( + self.parallel_config.enable_dbo + and cudagraph_runtime_mode == CUDAGraphMode.FULL + and uniform_decode and check_ubatch_thresholds( config=self.vllm_config.parallel_config, num_tokens=num_tokens, uniform_decode=uniform_decode, ) + ) for _ in range(self.compilation_config.cudagraph_num_of_warmups): # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. @@ -3553,29 +4663,34 @@ def _capture_cudagraphs(self, compilation_cases: list[int], # if we want to warm up attention or not. This is # different from the case where `FULL` implies capture # attention while `PIECEWISE` implies no attention. - force_attention = ( - cudagraph_runtime_mode == CUDAGraphMode.FULL) - self._dummy_run(num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - force_attention=force_attention, - uniform_decode=uniform_decode, - allow_microbatching=allow_microbatching, - skip_eplb=True, - remove_lora=False) - self._dummy_run(num_tokens, - cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=uniform_decode, - allow_microbatching=allow_microbatching, - skip_eplb=True, - remove_lora=False) + force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL + self._dummy_run( + num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=uniform_decode, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + activate_lora=activate_lora, + ) + self._dummy_run( + num_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=uniform_decode, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + activate_lora=activate_lora, + is_graph_capturing=True, + ) self.maybe_remove_all_loras(self.lora_config) def initialize_attn_backend(self, kv_cache_config: 
KVCacheConfig) -> None: """ Initialize the attention backends and attention metadata builders. """ - assert len(self.attn_groups) == 0, \ - "Attention backends are already initialized" + assert len(self.attn_groups) == 0, "Attention backends are already initialized" class AttentionGroupKey(NamedTuple): attn_backend: type[AttentionBackend] @@ -3583,10 +4698,11 @@ class AttentionGroupKey(NamedTuple): def get_attn_backends_for_group( kv_cache_group_spec: KVCacheGroupSpec, - ) -> dict[AttentionGroupKey, list[str]]: + ) -> tuple[dict[AttentionGroupKey, list[str]], set[type[AttentionBackend]]]: + layer_type = cast(type[Any], AttentionLayerBase) layers = get_layers_from_vllm_config( - self.vllm_config, AttentionLayerBase, - kv_cache_group_spec.layer_names) + self.vllm_config, layer_type, kv_cache_group_spec.layer_names + ) attn_backends = {} attn_backend_layers = defaultdict(list) # Dedupe based on full class name; this is a bit safer than @@ -3600,163 +4716,330 @@ def get_attn_backends_for_group( if layer_name in self.kv_sharing_fast_prefill_eligible_layers: attn_backend = create_fast_prefill_custom_backend( "FastPrefill", - attn_backend, + attn_backend, # type: ignore[arg-type] ) full_cls_name = attn_backend.full_cls_name() layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs): - layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[ - layer_name] + layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name] key = (full_cls_name, layer_kv_cache_spec) - attn_backends[key] = AttentionGroupKey(attn_backend, - layer_kv_cache_spec) + attn_backends[key] = AttentionGroupKey( + attn_backend, layer_kv_cache_spec + ) attn_backend_layers[key].append(layer_name) - return { - attn_backends[k]: v - for k, v in attn_backend_layers.items() - } + return ( + {attn_backends[k]: v for k, v in attn_backend_layers.items()}, + set(group_key.attn_backend for group_key in attn_backends.values()), + ) def 
create_attn_groups( attn_backends_map: dict[AttentionGroupKey, list[str]], + kv_cache_group_id: int, ) -> list[AttentionGroup]: attn_groups: list[AttentionGroup] = [] - for (attn_backend, - kv_cache_spec), layer_names in attn_backends_map.items(): - attn_group = AttentionGroup.create_with_metadata_builders( + for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items(): + attn_group = AttentionGroup( attn_backend, layer_names, kv_cache_spec, - self.vllm_config, - self.device, - num_metadata_builders=1 - if not self.parallel_config.enable_dbo else 2, + kv_cache_group_id, ) attn_groups.append(attn_group) return attn_groups + attention_backend_maps = [] + attention_backend_list = [] for kv_cache_group_spec in kv_cache_config.kv_cache_groups: attn_backends = get_attn_backends_for_group(kv_cache_group_spec) - self.attn_groups.append(create_attn_groups(attn_backends)) + attention_backend_maps.append(attn_backends[0]) + attention_backend_list.append(attn_backends[1]) + + # Resolve cudagraph_mode before actually initialize metadata_builders + self._check_and_update_cudagraph_mode( + attention_backend_list, kv_cache_config.kv_cache_groups + ) + + # Check if attention backend supports PCP&DCP and related features. + check_attention_cp_compatibility(self.vllm_config) + for i, attn_backend_map in enumerate(attention_backend_maps): + self.attn_groups.append(create_attn_groups(attn_backend_map, i)) + + def initialize_metadata_builders( + self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int] + ) -> None: + """ + Create the metadata builders for all KV cache groups and attn groups. 
+ """ + for kv_cache_group_id in range(len(kv_cache_config.kv_cache_groups)): + for attn_group in self.attn_groups[kv_cache_group_id]: + attn_group.create_metadata_builders( + self.vllm_config, + self.device, + kernel_block_sizes[kv_cache_group_id] + if kv_cache_group_id < len(kernel_block_sizes) + else None, + num_metadata_builders=1 + if not self.parallel_config.enable_dbo + else 2, + ) # Calculate reorder batch threshold (if needed) + # Note (tdoublep): do this *after* constructing builders, + # because some of them change the threshold at init time. self.calculate_reorder_batch_threshold() - def initialize_cudagraph_capture(self) -> None: + def _check_and_update_cudagraph_mode( + self, + attention_backends: list[set[type[AttentionBackend]]], + kv_cache_groups: list[KVCacheGroupSpec], + ) -> None: + """ + Resolve the cudagraph_mode when there are multiple attention + groups with potential conflicting CUDA graph support. + Then initialize the cudagraph_dispatcher based on the resolved + cudagraph_mode. 
+ """ min_cg_support = AttentionCGSupport.ALWAYS - min_cg_builder_name = None + min_cg_backend_name = None + + for attn_backend_set, kv_cache_group in zip( + attention_backends, kv_cache_groups + ): + for attn_backend in attn_backend_set: + builder_cls = attn_backend.get_builder_cls() - for attn_group in self._attn_group_iterator(): - builder = attn_group.get_metadata_builder() - if builder.cudagraph_support.value < min_cg_support.value: - min_cg_support = builder.cudagraph_support - min_cg_builder_name = builder.__class__.__name__ + cg_support = builder_cls.get_cudagraph_support( + self.vllm_config, kv_cache_group.kv_cache_spec + ) + if cg_support.value < min_cg_support.value: + min_cg_support = cg_support + min_cg_backend_name = attn_backend.__name__ # Flexible resolve the cudagraph mode cudagraph_mode = self.compilation_config.cudagraph_mode + assert cudagraph_mode is not None # check cudagraph for mixed batch is supported - if cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL \ - and min_cg_support != AttentionCGSupport.ALWAYS: - msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported " - f"with {min_cg_builder_name} backend (support: " - f"{min_cg_support})") + if ( + cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL + and min_cg_support != AttentionCGSupport.ALWAYS + ): + msg = ( + f"CUDAGraphMode.{cudagraph_mode.name} is not supported " + f"with {min_cg_backend_name} backend (support: " + f"{min_cg_support})" + ) if min_cg_support == AttentionCGSupport.NEVER: # if not supported any full cudagraphs, just raise it. 
- msg += "; please try cudagraph_mode=PIECEWISE, and "\ - "make sure compilation level is piecewise" + msg += ( + "; please try cudagraph_mode=PIECEWISE, and " + "make sure compilation mode is VLLM_COMPILE" + ) raise ValueError(msg) # attempt to resolve the full cudagraph related mode if self.compilation_config.splitting_ops_contain_attention(): msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.FULL_AND_PIECEWISE + ) else: msg += "; setting cudagraph_mode=FULL_DECODE_ONLY" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.FULL_DECODE_ONLY + ) logger.warning(msg) # check that if we are doing decode full-cudagraphs it is supported - if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL - and min_cg_support == AttentionCGSupport.NEVER): - msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported " - f"with {min_cg_builder_name} backend (support: " - f"{min_cg_support})") - if (self.compilation_config.level == CompilationLevel.PIECEWISE and - (self.compilation_config.splitting_ops_contain_attention() - or self.compilation_config.use_inductor_graph_partition)): - msg += "; setting cudagraph_mode=PIECEWISE because "\ + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and min_cg_support == AttentionCGSupport.NEVER + ): + msg = ( + f"CUDAGraphMode.{cudagraph_mode.name} is not supported " + f"with {min_cg_backend_name} backend (support: " + f"{min_cg_support})" + ) + if self.compilation_config.mode == CompilationMode.VLLM_COMPILE and ( + self.compilation_config.splitting_ops_contain_attention() + or self.compilation_config.use_inductor_graph_partition + ): + msg += ( + "; setting cudagraph_mode=PIECEWISE because " "attention is compiled piecewise" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + ) + cudagraph_mode = 
self.compilation_config.cudagraph_mode = ( CUDAGraphMode.PIECEWISE + ) else: - msg += "; setting cudagraph_mode=NONE because "\ + msg += ( + "; setting cudagraph_mode=NONE because " "attention is not compiled piecewise" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + ) + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.NONE + ) logger.warning(msg) # check that if we are doing spec-decode + decode full-cudagraphs it is # supported - if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL - and self.uniform_decode_query_len > 1 and min_cg_support.value - < AttentionCGSupport.UNIFORM_BATCH.value): - msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported" - f" with spec-decode for attention backend " - f"{min_cg_builder_name} (support: {min_cg_support})") + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and self.uniform_decode_query_len > 1 + and min_cg_support.value < AttentionCGSupport.UNIFORM_BATCH.value + ): + msg = ( + f"CUDAGraphMode.{cudagraph_mode.name} is not supported" + f" with spec-decode for attention backend " + f"{min_cg_backend_name} (support: {min_cg_support})" + ) if self.compilation_config.splitting_ops_contain_attention(): msg += "; setting cudagraph_mode=PIECEWISE" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.PIECEWISE + ) else: msg += "; setting cudagraph_mode=NONE" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.NONE + ) logger.warning(msg) # double check that we can support full cudagraph if they are requested # even after automatic downgrades - if cudagraph_mode.has_full_cudagraphs() \ - and min_cg_support == AttentionCGSupport.NEVER: - raise ValueError(f"CUDAGraphMode.{cudagraph_mode.name} is not " - f"supported with {min_cg_builder_name} backend (" - f"support:{min_cg_support}) " - "; please try 
cudagraph_mode=PIECEWISE, " - "and make sure compilation level is piecewise") - - # Trigger cudagraph dispatching keys initialization here (after - # initializing attn backends). + if ( + cudagraph_mode.has_full_cudagraphs() + and min_cg_support == AttentionCGSupport.NEVER + ): + raise ValueError( + f"CUDAGraphMode.{cudagraph_mode.name} is not " + f"supported with {min_cg_backend_name} backend (" + f"support:{min_cg_support}) " + "; please try cudagraph_mode=PIECEWISE, " + "and make sure compilation mode is VLLM_COMPILE" + ) + + # if we have dedicated decode cudagraphs, and spec-decode is enabled, + # we need to adjust the cudagraph sizes to be a multiple of the uniform + # decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207 + # temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536 + # Will be removed in the near future when we have separate cudagraph capture + # sizes for decode and mixed prefill-decode. + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and cudagraph_mode.separate_routine() + and self.uniform_decode_query_len > 1 + ): + self.compilation_config.adjust_cudagraph_sizes_for_spec_decode( + self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size + ) + capture_sizes = self.compilation_config.cudagraph_capture_sizes + self.cudagraph_batch_sizes = ( + capture_sizes if capture_sizes is not None else [] + ) + + # Trigger cudagraph dispatching keys initialization after + # resolved cudagraph mode. + self.compilation_config.cudagraph_mode = cudagraph_mode self.cudagraph_dispatcher.initialize_cudagraph_keys( - self.compilation_config.cudagraph_mode, - self.uniform_decode_query_len) + cudagraph_mode, self.uniform_decode_query_len + ) def calculate_reorder_batch_threshold(self) -> None: """ - Check that if any backends reorder batches; that the reordering - is compatible (e.g., decode threshold is the same) + Choose the minimum reorder batch threshold from all attention groups. 
+ Backends should be able to support lower threshold then what they request + just may have a performance penalty due to that backend treating decodes + as prefills. + """ + min_none_high = lambda a, b: a if b is None else b if a is None else min(a, b) + + reorder_batch_thresholds: list[int | None] = [ + group.get_metadata_builder().reorder_batch_threshold + for group in self._attn_group_iterator() + ] + # If there are no attention groups (attention-free model) or no backend + # reports a threshold, leave reordering disabled. + if len(reorder_batch_thresholds) == 0: + self.reorder_batch_threshold = None + return + self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds) # type: ignore[assignment] + + @staticmethod + def select_common_block_size( + kv_manager_block_size: int, attn_groups: list[AttentionGroup] + ) -> int: + """ + Select a block size that is supported by all backends and is a factor of + kv_manager_block_size. + + If kv_manager_block_size is supported by all backends, return it directly. + Otherwise, return the max supported size. 
+ + Args: + kv_manager_block_size: Block size of KV cache + attn_groups: List of attention groups + + Returns: + The selected block size + + Raises: + ValueError: If no valid block size found """ - for group in self._attn_group_iterator(): - attn_metadata_builder_i = group.get_metadata_builder() - - # check that if any backends reorder batches; that the reordering - # is compatible (e.g., decode threshold is the same) - reorder_batch_threshold_i = ( - attn_metadata_builder_i.reorder_batch_threshold) - if reorder_batch_threshold_i is not None: - if self.reorder_batch_threshold is not None: - if reorder_batch_threshold_i != \ - self.reorder_batch_threshold: - raise ValueError( - f"Attention backend reorders decodes with " - f"threshold {reorder_batch_threshold_i} but other " - f"backend uses threshold " - f"{self.reorder_batch_threshold}") - else: - self.reorder_batch_threshold = reorder_batch_threshold_i - def may_reinitialize_input_batch(self, - kv_cache_config: KVCacheConfig) -> None: + def block_size_is_supported( + backends: list[type[AttentionBackend]], block_size: int + ) -> bool: + """ + Check if the block size is supported by all backends. + """ + for backend in backends: + is_supported = False + for supported_size in backend.get_supported_kernel_block_sizes(): + if isinstance(supported_size, int): + if block_size == supported_size: + is_supported = True + elif isinstance(supported_size, MultipleOf): + if block_size % supported_size.base == 0: + is_supported = True + else: + raise ValueError(f"Unknown supported size: {supported_size}") + if not is_supported: + return False + return True + + backends = [group.backend for group in attn_groups] + + # Case 1: if the block_size of kv cache manager is supported by all backends, + # return it directly + if block_size_is_supported(backends, kv_manager_block_size): + return kv_manager_block_size + + # Case 2: otherwise, the block_size must be an `int`-format supported size of + # at least one backend. 
Iterate over all `int`-format supported sizes in + # descending order and return the first one that is supported by all backends. + # Simple proof: + # If the supported size b is in MultipleOf(x_i) format for all attention + # backends i, and b a factor of kv_manager_block_size, then + # kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will + # return kv_manager_block_size in case 1. + all_int_supported_sizes = set( + supported_size + for backend in backends + for supported_size in backend.get_supported_kernel_block_sizes() + if isinstance(supported_size, int) + ) + + for supported_size in sorted(all_int_supported_sizes, reverse=True): + if kv_manager_block_size % supported_size != 0: + continue + if block_size_is_supported(backends, supported_size): + return supported_size + raise ValueError(f"No common block size for {kv_manager_block_size}. ") + + def may_reinitialize_input_batch( + self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int] + ) -> None: """ Re-initialize the input batch if the block sizes are different from `[self.cache_config.block_size]`. This usually happens when there @@ -3764,16 +5047,22 @@ def may_reinitialize_input_batch(self, Args: kv_cache_config: The KV cache configuration. + kernel_block_sizes: The kernel block sizes for each KV cache group. """ block_sizes = [ kv_cache_group.kv_cache_spec.block_size for kv_cache_group in kv_cache_config.kv_cache_groups + if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec) ] - if block_sizes != [self.cache_config.block_size]: + + if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [ + self.cache_config.block_size + ]: assert self.cache_config.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 - "for more details.") + "for more details." 
+ ) self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, max_model_len=max(self.max_model_len, self.max_encoder_len), @@ -3782,16 +5071,17 @@ def may_reinitialize_input_batch(self, pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=block_sizes, + kernel_block_sizes=kernel_block_sizes, is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=self.input_batch.logitsprocs, + logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids, is_pooling_model=self.is_pooling_model, - num_speculative_tokens=( - self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config else 0), + num_speculative_tokens=self.num_spec_tokens, ) def _allocate_kv_cache_tensors( - self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + self, kv_cache_config: KVCacheConfig + ) -> dict[str, torch.Tensor]: """ Initializes the KV cache buffer with the correct size. The buffer needs to be reshaped to the desired shape before being used by the models. @@ -3801,12 +5091,12 @@ def _allocate_kv_cache_tensors( Returns: dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. 
- """ + """ kv_cache_raw_tensors: dict[str, torch.Tensor] = {} for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - tensor = torch.zeros(kv_cache_tensor.size, - dtype=torch.int8, - device=self.device) + tensor = torch.zeros( + kv_cache_tensor.size, dtype=torch.int8, device=self.device + ) for layer_name in kv_cache_tensor.shared_by: kv_cache_raw_tensors[layer_name] = tensor @@ -3816,8 +5106,9 @@ def _allocate_kv_cache_tensors( if layer_name in self.runner_only_attn_layers: continue layer_names.add(layer_name) - assert layer_names == set(kv_cache_raw_tensors.keys( - )), "Some layers are not correctly initialized" + assert layer_names == set(kv_cache_raw_tensors.keys()), ( + "Some layers are not correctly initialized" + ) return kv_cache_raw_tensors def _attn_group_iterator(self) -> Iterator[AttentionGroup]: @@ -3829,10 +5120,54 @@ def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]: for attn_groups in self.attn_groups: yield from attn_groups + def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[int]: + """ + Generate kernel_block_sizes that matches each block_size. + + For attention backends that support virtual block splitting, + use the supported block sizes from the backend. + For other backends (like Mamba), use the same block size (no splitting). + + Args: + kv_cache_config: The KV cache configuration. + + Returns: + list[int]: List of kernel block sizes for each cache group. + """ + kernel_block_sizes = [] + for kv_cache_gid, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): + kv_cache_spec = kv_cache_group.kv_cache_spec + if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs): + # All layers in the UniformTypeKVCacheSpecs have the same type, + # Pick an arbitrary one to dispatch. 
+ kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values())) + if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): + continue + elif isinstance(kv_cache_spec, AttentionSpec): + # This is an attention backend that supports virtual + # block splitting. Get the supported block sizes from + # all backends in the group. + attn_groups = self.attn_groups[kv_cache_gid] + kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size + selected_kernel_size = self.select_common_block_size( + kv_manager_block_size, attn_groups + ) + kernel_block_sizes.append(selected_kernel_size) + elif isinstance(kv_cache_spec, MambaSpec): + # This is likely Mamba or other non-attention cache, + # no splitting. + kernel_block_sizes.append(kv_cache_spec.block_size) + else: + raise NotImplementedError( + f"unknown kv cache spec {kv_cache_group.kv_cache_spec}" + ) + return kernel_block_sizes + def _reshape_kv_cache_tensors( self, kv_cache_config: KVCacheConfig, kv_cache_raw_tensors: dict[str, torch.Tensor], + kernel_block_sizes: list[int], ) -> dict[str, torch.Tensor]: """ Reshape the KV cache tensors to the desired shape and dtype. @@ -3841,6 +5176,7 @@ def _reshape_kv_cache_tensors( kv_cache_config: The KV cache config kv_cache_raw_tensors: The KV cache buffer of each layer, with correct size but uninitialized shape. + kernel_block_sizes: The kernel block sizes for each KV cache group. Returns: Dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. @@ -3850,55 +5186,65 @@ def _reshape_kv_cache_tensors( for group in self._kv_cache_spec_attn_group_iterator(): kv_cache_spec = group.kv_cache_spec attn_backend = group.backend + if group.kv_cache_group_id == len(kernel_block_sizes): + # There may be a last group for layers without kv cache. 
+ continue + kernel_block_size = kernel_block_sizes[group.kv_cache_group_id] for layer_name in group.layer_names: if layer_name in self.runner_only_attn_layers: continue raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 - num_blocks = (raw_tensor.numel() // - kv_cache_spec.page_size_bytes) + num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes if isinstance(kv_cache_spec, AttentionSpec): has_attn = True + num_blocks_per_kv_block = ( + kv_cache_spec.block_size // kernel_block_size + ) + kernel_num_blocks = num_blocks * num_blocks_per_kv_block + kv_cache_shape = attn_backend.get_kv_cache_shape( - num_blocks, - kv_cache_spec.block_size, + kernel_num_blocks, + kernel_block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, - cache_dtype_str=self.cache_config.cache_dtype) + cache_dtype_str=self.cache_config.cache_dtype, + ) dtype = kv_cache_spec.dtype try: - kv_cache_stride_order = \ - attn_backend.get_kv_cache_stride_order() - assert len(kv_cache_stride_order) == len( - kv_cache_shape) + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple( - range(len(kv_cache_shape))) + kv_cache_stride_order = tuple(range(len(kv_cache_shape))) # The allocation respects the backend-defined stride order # to ensure the semantic remains consistent for each # backend. We first obtain the generic kv cache shape and # then permute it according to the stride order which could # result in a non-contiguous tensor. - kv_cache_shape = tuple(kv_cache_shape[i] - for i in kv_cache_stride_order) + kv_cache_shape = tuple( + kv_cache_shape[i] for i in kv_cache_stride_order + ) # Maintain original KV shape view. 
inv_order = [ kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) ] - kv_caches[layer_name] = kv_cache_raw_tensors[ - layer_name].view(dtype).view(kv_cache_shape).permute( - *inv_order) + kv_caches[layer_name] = ( + kv_cache_raw_tensors[layer_name] + .view(dtype) + .view(kv_cache_shape) + .permute(*inv_order) + ) elif isinstance(kv_cache_spec, MambaSpec): has_mamba = True raw_tensor = kv_cache_raw_tensors[layer_name] state_tensors = [] storage_offset_bytes = 0 - for (shape, dtype) in zip(kv_cache_spec.shapes, - kv_cache_spec.dtypes): + for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes): dtype_size = get_dtype_size(dtype) num_element_per_page = ( - kv_cache_spec.page_size_bytes // dtype_size) + kv_cache_spec.page_size_bytes // dtype_size + ) target_shape = (num_blocks, *shape) stride = torch.empty(target_shape).stride() target_stride = (num_element_per_page, *stride[1:]) @@ -3922,7 +5268,8 @@ def _reshape_kv_cache_tensors( return kv_caches def _update_hybrid_attention_mamba_layout( - self, kv_caches: dict[str, torch.Tensor]) -> None: + self, kv_caches: dict[str, torch.Tensor] + ) -> None: """ Update the layout of attention layers from (2, num_blocks, ...) to (num_blocks, 2, ...). @@ -3935,50 +5282,76 @@ def _update_hybrid_attention_mamba_layout( kv_cache_spec = group.kv_cache_spec for layer_name in group.layer_names: kv_cache = kv_caches[layer_name] - if (isinstance(kv_cache_spec, AttentionSpec) - and kv_cache.shape[0] == 2): - assert kv_cache.shape[1] != 2, \ - "Fail to determine whether the layout is " \ - "(2, num_blocks, ...) or (num_blocks, 2, ...) for " \ + if isinstance(kv_cache_spec, AttentionSpec) and kv_cache.shape[0] == 2: + assert kv_cache.shape[1] != 2, ( + "Fail to determine whether the layout is " + "(2, num_blocks, ...) or (num_blocks, 2, ...) 
for " f"a tensor of shape {kv_cache.shape}" + ) hidden_size = kv_cache.shape[2:].numel() - kv_cache.as_strided_(size=kv_cache.shape, - stride=(hidden_size, 2 * hidden_size, - *kv_cache.stride()[2:])) + kv_cache.as_strided_( + size=kv_cache.shape, + stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]), + ) def initialize_kv_cache_tensors( - self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int] + ) -> dict[str, torch.Tensor]: """ Initialize the memory buffer for KV cache. Args: kv_cache_config: The KV cache config + kernel_block_sizes: The kernel block sizes for each KV cache group. + Returns: Dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. """ - # Initialize the memory buffer for KV cache - kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) - # Change the memory buffer to the desired shape - kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, - kv_cache_raw_tensors) + + # Try creating KV caches optimized for kv-connector transfers + cache_dtype = self.cache_config.cache_dtype + if self.use_uniform_kv_cache(self.attn_groups, cache_dtype): + kv_caches, cross_layers_kv_cache, attn_backend = ( + self.allocate_uniform_kv_caches( + kv_cache_config, + self.attn_groups, + cache_dtype, + self.device, + kernel_block_sizes, + ) + ) + self.cross_layers_kv_cache = cross_layers_kv_cache + self.cross_layers_attn_backend = attn_backend + else: + # Fallback to the general case + # Initialize the memory buffer for KV cache + kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) + + # Change the memory buffer to the desired shape + kv_caches = self._reshape_kv_cache_tensors( + kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes + ) # Set up cross-layer KV cache sharing - for layer_name, target_layer_name in self.shared_kv_cache_layers.items( - ): - logger.debug("%s reuses KV cache of %s", 
layer_name, - target_layer_name) + for layer_name, target_layer_name in self.shared_kv_cache_layers.items(): + logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name) kv_caches[layer_name] = kv_caches[target_layer_name] - num_attn_module = 2 \ - if self.model_config.hf_config.model_type == "longcat_flash" else 1 - bind_kv_cache(kv_caches, - self.compilation_config.static_forward_context, - self.kv_caches, num_attn_module) + num_attn_module = ( + 2 if self.model_config.hf_config.model_type == "longcat_flash" else 1 + ) + bind_kv_cache( + kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches, + num_attn_module, + ) return kv_caches def maybe_add_kv_sharing_layers_to_kv_cache_groups( - self, kv_cache_config: KVCacheConfig) -> None: + self, kv_cache_config: KVCacheConfig + ) -> None: """ Add layers that re-use KV cache to KV cache group of its target layer. Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()` @@ -3997,12 +5370,10 @@ def maybe_add_kv_sharing_layers_to_kv_cache_groups( # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other # similar KV sharing setups, only the layers that generate KV caches # are involved in the prefill phase, enabling prefill to early exit. 
- attn_layers = get_layers_from_vllm_config(self.vllm_config, - Attention) + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name in reversed(attn_layers): if layer_name in self.shared_kv_cache_layers: - self.kv_sharing_fast_prefill_eligible_layers.add( - layer_name) + self.kv_sharing_fast_prefill_eligible_layers.add(layer_name) else: break @@ -4015,11 +5386,24 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config - self.may_reinitialize_input_batch(kv_cache_config) self.may_add_encoder_only_layers_to_kv_cache_config() self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.initialize_attn_backend(kv_cache_config) - kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) + # The kernel block size for all KV cache groups. For example, if + # kv_cache_manager uses block_size 256 for a given group, but the attention + # backends for that group only supports block_size 64, we will return + # kernel_block_size 64 and split the 256-token-block to 4 blocks with 64 + # tokens each. 
+ kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config) + + # create metadata builders + self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes) + + # Reinitialize need to after initialize_attn_backend + self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes) + kv_caches = self.initialize_kv_cache_tensors( + kv_cache_config, kernel_block_sizes + ) if self.speculative_config and self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) @@ -4028,30 +5412,22 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.drafter.validate_same_kv_cache_group(kv_cache_config) if has_kv_transfer_group(): - get_kv_transfer_group().register_kv_caches(kv_caches) - if self.device.type == 'xpu': - get_kv_transfer_group().set_host_xfer_buffer_ops( - copy_kv_blocks) - - if self.dcp_world_size > 1: - layer_names = self.attn_groups[0][0].layer_names - layers = get_layers_from_vllm_config(self.vllm_config, - AttentionLayerBase, - layer_names) - for layer in layers.values(): - assert layer.impl.need_to_return_lse_for_decode, ( - "DCP requires attention impls to return" - " the softmax lse for decode, but the impl " - f"{layer.impl.__class__.__name__} " - "does not return the softmax lse for decode.") + kv_transfer_group = get_kv_transfer_group() + if self.cross_layers_kv_cache is not None: + assert self.cross_layers_attn_backend is not None + kv_transfer_group.register_cross_layers_kv_cache( + self.cross_layers_kv_cache, self.cross_layers_attn_backend + ) + else: + kv_transfer_group.register_kv_caches(kv_caches) + kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks) def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """ Add encoder-only layers to the KV cache config. 
""" block_size = self.vllm_config.cache_config.block_size - encoder_only_attn_specs: dict[AttentionSpec, - list[str]] = defaultdict(list) + encoder_only_attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list) attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name, attn_module in attn_layers.items(): if attn_module.attn_type == AttentionType.ENCODER_ONLY: @@ -4059,16 +5435,18 @@ def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype) + dtype=self.kv_cache_dtype, + ) encoder_only_attn_specs[attn_spec].append(layer_name) self.runner_only_attn_layers.add(layer_name) if len(encoder_only_attn_specs) > 0: - assert len( - encoder_only_attn_specs - ) == 1, "Only support one encoder-only attention spec now" + assert len(encoder_only_attn_specs) == 1, ( + "Only support one encoder-only attention spec now" + ) spec, layer_names = encoder_only_attn_specs.popitem() self.kv_cache_config.kv_cache_groups.append( - KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)) + KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec) + ) def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ @@ -4078,15 +5456,15 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: KVCacheSpec: A dictionary mapping layer names to their KV cache format. Layers that do not need KV cache are not included. 
""" - - block_size = self.vllm_config.cache_config.block_size - use_mla = self.vllm_config.model_config.use_mla - cache_dtype_str = self.vllm_config.cache_config.cache_dtype + if has_ec_transfer() and get_ec_transfer().is_producer: + return {} kv_cache_spec: dict[str, KVCacheSpec] = {} - attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + layer_type = cast(type[Any], AttentionLayerBase) + attn_layers = get_layers_from_vllm_config(self.vllm_config, layer_type) for layer_name, attn_module in attn_layers.items(): - if (kv_tgt_layer := - attn_module.kv_sharing_target_layer_name) is not None: + if isinstance(attn_module, Attention) and ( + kv_tgt_layer := attn_module.kv_sharing_target_layer_name + ): # The layer doesn't need its own KV cache and will use that of # the target layer. We skip creating a KVCacheSpec for it, so # that KV cache management logic will act as this layer does @@ -4096,86 +5474,9 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: # or enable more requests to be processed simultaneously. 
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer continue - - # TODO(lucas): move the attention specs into the model layers like - # the attention backends - if attn_module.attn_type == AttentionType.DECODER: - if attn_module.sliding_window is not None: - assert not use_mla, "MLA is not supported for sliding" \ - "window" - kv_cache_spec[layer_name] = SlidingWindowSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - sliding_window=attn_module.sliding_window) - elif use_mla: - kv_cache_spec[layer_name] = MLAAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - cache_dtype_str=cache_dtype_str) - elif self.attention_chunk_size is not None \ - and isinstance(attn_module, ChunkedLocalAttention): - kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - attention_chunk_size=self.attention_chunk_size) - else: - kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype) - elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - kv_cache_spec[layer_name] = CrossAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype) - elif attn_module.attn_type in (AttentionType.ENCODER, - AttentionType.ENCODER_ONLY): - # encoder-only attention does not need KV cache. 
- continue - else: - raise ValueError( - f"Unknown attention type: {attn_module.attn_type}") - - mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) - if len(mamba_layers) > 0: - if (self.vllm_config.speculative_config is not None - and self.vllm_config.model_config.hf_config.model_type - not in ["qwen3_next"]): - raise NotImplementedError( - "Mamba with speculative decoding is not supported yet.") - if self.vllm_config.cache_config.enable_prefix_caching: - raise NotImplementedError( - "Prefix caching is not supported for Mamba yet.") - max_model_len = self.vllm_config.model_config.max_model_len - - page_size_padded = ( - self.vllm_config.cache_config.mamba_page_size_padded) - - # Set block_size to max_model_len, so that mamba model will always - # have only one block in the KV cache. - for layer_name, mamba_module in mamba_layers.items(): - kv_cache_spec[layer_name] = MambaSpec( - shapes=mamba_module.get_state_shape(), - dtypes=mamba_module.get_state_dtype(), - block_size=max_model_len, - page_size_padded=page_size_padded, - mamba_type=mamba_module.mamba_type, - num_speculative_blocks=( - self.speculative_config.num_speculative_tokens - if self.speculative_config else 0), - ) - ds_indexer_layers = get_layers_from_vllm_config( - self.vllm_config, DeepseekV32IndexerCache) - for layer_name, ds_indexer_module in ds_indexer_layers.items(): - kv_cache_spec[layer_name] = ds_indexer_module.get_kv_cache_spec() + # Skip modules that don't need KV cache (eg encoder-only attention) + if spec := attn_module.get_kv_cache_spec(self.vllm_config): + kv_cache_spec[layer_name] = spec return kv_cache_spec @@ -4188,7 +5489,7 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: # this is in the critical path of every single model # forward loop, this has caused perf issue for a disagg # setup. 
- pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]] + pinned = self.sampled_token_ids_pinned_cpu[: sampled_token_ids.shape[0]] pinned.copy_(sampled_token_ids, non_blocking=True) self.transfer_event.record() self.transfer_event.synchronize() diff --git a/vllm_fl/worker/worker.py b/vllm_fl/worker/worker.py index cbeb0251..ce442adb 100644 --- a/vllm_fl/worker/worker.py +++ b/vllm_fl/worker/worker.py @@ -8,27 +8,42 @@ import gc import os from contextlib import AbstractContextManager, nullcontext -from typing import TYPE_CHECKING, Any, Optional, Union +from types import NoneType +from typing import TYPE_CHECKING, Any, Optional, cast +import numpy as np import torch import torch.distributed import torch.nn as nn import vllm.envs as envs -from vllm.config import VllmConfig +from vllm.config import CUDAGraphMode,VllmConfig +from vllm.config.compilation import CompilationMode from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment, - set_custom_all_reduce) -from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized -from vllm.distributed.parallel_state import get_pp_group, get_tp_group + init_distributed_environment) +from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized +from vllm.distributed.kv_transfer import ( + ensure_kv_transfer_initialized, + get_kv_transfer_group, + has_kv_transfer_group, +) +from vllm.distributed.parallel_state import ( + get_pcp_group, + get_pp_group, + get_tp_group, +) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed +from vllm.model_executor.models.interfaces import is_mixture_of_experts from vllm.model_executor.warmup.kernel_warmup import kernel_warmup from vllm.platforms import current_platform +from vllm.profiler.wrapper import TorchProfilerWrapper + from vllm.sequence import IntermediateTensors from vllm.tasks import SupportedTask -from vllm.utils import GiB_bytes, 
MemorySnapshot, memory_profiling +from vllm.utils.mem_utils import GiB_bytes, MemorySnapshot, memory_profiling +from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, @@ -36,6 +51,8 @@ from vllm.v1.utils import report_usage_stats from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.worker_base import WorkerBase +from vllm.v1.worker.workspace import init_workspace_manager + logger = init_logger(__name__) @@ -68,7 +85,7 @@ def __init__( if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules + from vllm.utils.import_utils import init_cached_hf_modules init_cached_hf_modules() # Buffers saved before sleep @@ -76,35 +93,24 @@ def __init__( # Torch profiler. Enabled and configured through env vars: # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace - if envs.VLLM_TORCH_PROFILER_DIR: - torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. 
Traces will be saved to: %s", - torch_profiler_trace_dir) - logger.debug( - "Profiler config: record_shapes=%s," - "profile_memory=%s,with_stack=%s,with_flops=%s", - envs.VLLM_TORCH_PROFILER_RECORD_SHAPES, - envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY, - envs.VLLM_TORCH_PROFILER_WITH_STACK, - envs.VLLM_TORCH_PROFILER_WITH_FLOPS, + self.profiler: Any | None = None + profiler_config = vllm_config.profiler_config + if profiler_config.profiler == "torch": + worker_name = f"{vllm_config.instance_id}-rank-{self.rank}" + self.profiler = TorchProfilerWrapper( + profiler_config, + worker_name=worker_name, + local_rank=self.local_rank, + activities=["CPU", "CUDA"], ) - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES, - profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY, - with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, - with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, use_gzip=True)) else: self.profiler = None + register_oot_ops() - flag_gems.enable(record=False, unused=["index", "index_put_"]) + flag_gems.enable(record=False) #, unused=["index", "index_put_"]) # def sleep(self, level: int = 1) -> None: + # TODO(lms): rewrite CuMemAllocator # from vllm.device_allocator.cumem import CuMemAllocator # free_bytes_before_sleep = torch.cuda.mem_get_info()[0] @@ -164,6 +170,37 @@ def initialize_cache(self, num_gpu_blocks: int, def init_device(self): # This env var set by Ray causes exceptions with graph building. 
+ if ( + self.parallel_config.data_parallel_size > 1 + and self.parallel_config.data_parallel_size_local > 0 + and self.parallel_config.distributed_executor_backend + not in ["ray", "external_launcher"] + and self.vllm_config.parallel_config.data_parallel_backend != "ray" + and self.vllm_config.parallel_config.nnodes_within_dp == 1 + ): + # Use local DP rank if available, otherwise use global DP rank. + dp_local_rank = self.parallel_config.data_parallel_rank_local + if dp_local_rank is None: + dp_local_rank = self.parallel_config.data_parallel_rank + + tp_pp_world_size = ( + self.parallel_config.pipeline_parallel_size + * self.parallel_config.tensor_parallel_size + ) + + # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK + self.local_rank += dp_local_rank * tp_pp_world_size + assert self.local_rank < torch.cuda.device_count(), ( + f"DP adjusted local rank {self.local_rank} is out of bounds. " + ) + visible_device_count = ( + current_platform.torch_device_fn.device_count() if current_platform.torch_device_fn.is_available() else 0 + ) + assert self.parallel_config.local_world_size <= visible_device_count, ( + f"local_world_size ({self.parallel_config.local_world_size}) must " + f"be less than or equal to the number of visible devices " + f"({visible_device_count})." 
+ ) self.device = torch.device(f"{current_platform.device_type}:{self.local_rank}") current_platform.set_device(self.device) @@ -173,7 +210,8 @@ def init_device(self): # memory snapshot # This ensures NCCL buffers are allocated before we measure # available memory - init_worker_distributed_environment(self.vllm_config, self.rank, + init_worker_distributed_environment(self.vllm_config, + self.rank, self.distributed_init_method, self.local_rank, current_platform.dist_backend) @@ -185,7 +223,7 @@ def init_device(self): gc.collect() current_platform.empty_cache() - ### TODO(lms): support MemorySnapshot in other platform + ### TODO(lms): patch MemorySnapshot in other platform # take current memory snapshot self.init_snapshot = MemorySnapshot() self.requested_memory = (self.init_snapshot.total_memory * @@ -201,6 +239,9 @@ def init_device(self): f"{GiB(self.requested_memory)} GiB). Decrease GPU memory " f"utilization or reduce GPU memory used by other processes." ) + # Initialize workspace manager + num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1 + init_workspace_manager(self.device, num_ubatches) # Construct the model runner self.model_runner = ModelRunnerFL( @@ -246,10 +287,10 @@ def determine_available_memory(self) -> int: self.model_runner.profile_run() msg = ( - f"Initial free memory {GiB(self.init_snapshot.free_memory)} " - f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f}GiB memory for " + f"Initial free memory {GiB(self.init_snapshot.free_memory):.2f} " + f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f} GiB memory for " "KV Cache as specified by kv_cache_memory_bytes config and " - "skipped memory profiling. This does does not respect the " + "skipped memory profiling. This does not respect the " "gpu_memory_utilization config. Only use kv_cache_memory_bytes " "config when you want manual control of KV cache memory " "size. 
If OOM'ed, check the difference of initial free " @@ -265,8 +306,8 @@ def determine_available_memory(self) -> int: # Execute a forward pass with dummy inputs to profile the memory usage # of the model. with memory_profiling( - self.init_snapshot, - weights_memory=int(self.model_runner.model_memory_usage), + self.init_snapshot, + weights_memory=int(self.model_runner.model_memory_usage), ) as profile_result: self.model_runner.profile_run() @@ -303,17 +344,39 @@ def determine_available_memory(self) -> int: GiB(free_gpu_memory - unrequested_memory), ) logger.debug(profile_result) - logger.info("Available KV cache memory: %.2f GiB", - GiB(self.available_kv_cache_memory_bytes)) + logger.info_once("Available KV cache memory: %.2f GiB", + GiB(self.available_kv_cache_memory_bytes), + scope="local",) gc.collect() return int(self.available_kv_cache_memory_bytes) + + def get_kv_connector_handshake_metadata(self) -> dict | None: + """Get KV connector metadata from this worker if available.""" + + if not has_kv_transfer_group(): + return None + + connector = get_kv_transfer_group() + # Return None for connectors that don't need to exchange handshake + # metadata across workers. + if (metadata := connector.get_handshake_metadata()) is None: + return None + tp_rank = get_tp_group().rank_in_group + return {tp_rank: metadata} def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: """Allocate GPU KV cache with the specified kv_cache_config.""" + # Init kv cache connector here, because it requires + # `kv_cache_config`. + # NOTE(Kuntai): This need to be done before `initialize_kv_cache`, + # because `initialize_kv_cache` will inject kv cache groups not + # related to kv cache connector (e.g. kv cache sharing layers). 
+ ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config) + ### TODO(lms): # if self.vllm_config.model_config.enable_sleep_mode: # from vllm.device_allocator.cumem import CuMemAllocator @@ -322,21 +385,33 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: # context = allocator.use_memory_pool(tag="kv_cache") # else: # context = nullcontext() - context = nullcontext() - with context: - self.model_runner.initialize_kv_cache(kv_cache_config) + self.model_runner.initialize_kv_cache(kv_cache_config) def compile_or_warm_up_model(self) -> None: - # warm up sizes that are not in cudagraph capture sizes, - # but users still want to compile for better performance, - # e.g. for the max-num-batched token size in chunked prefill. - compile_sizes = self.vllm_config.compilation_config.compile_sizes - warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] - if not self.model_config.enforce_eager: - warmup_sizes = [ - x for x in warmup_sizes if x not in - self.vllm_config.compilation_config.cudagraph_capture_sizes - ] + warmup_sizes = [] + if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE: + # warm up sizes that are not in cudagraph capture sizes, + # but users still want to compile for better performance, + # e.g. for the max-num-batched token size in chunked prefill. 
+ compile_sizes = self.vllm_config.compilation_config.compile_sizes + warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] + cg_capture_sizes: list[int] = [] + + if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes + cg_capture_sizes = [] if cg_sizes is None else cg_sizes + warmup_sizes = [x for x in warmup_sizes if x not in cg_capture_sizes] + + compile_ranges = self.vllm_config.compilation_config.get_compile_ranges() + # For each compile_range, if none of the batch sizes + # in warmup_sizes or cudagraph_capture_sizes are in the range, + # add the end of the range to ensure compilation/warmup. + all_sizes = set(cg_capture_sizes) + all_sizes.update([x for x in warmup_sizes if isinstance(x, int)]) + for compile_range in compile_ranges: + if not any(x in compile_range for x in all_sizes): + warmup_sizes.append(compile_range.end) + # We skip EPLB here since we don't want to record dummy metrics for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) @@ -345,12 +420,67 @@ def compile_or_warm_up_model(self) -> None: remove_lora=False) self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config) - ### NOTE(lms): can add gems kernel autotune here + ### NOTE(lms): can add gems kernel pretune here cuda_graph_memory_bytes = 0 if not self.model_config.enforce_eager: cuda_graph_memory_bytes = self.model_runner.capture_model() - ### TODO(lms): add kvcache limit compute + if self.cache_config.kv_cache_memory_bytes is None and hasattr( + self, "peak_activation_memory" + ): + # Suggests optimal kv cache memory size if we rely on + # memory_profiling to guess the kv cache memory size which + # provides peak_activation_memory and a few other memory + # consumption. `memory_profiling` does not consider + # CUDAGraph memory size and may not utilize all gpu memory. 
+ # Users may want fine-grained control to specify kv cache + # memory size. + GiB = lambda b: round(b / GiB_bytes, 2) + + # empirically observed that the memory profiling may + # slightly underestimate the memory consumption. + # So leave a small buffer (=150MiB) to avoid OOM. + redundancy_buffer_memory = 150 * (1 << 20) + non_kv_cache_memory = ( + self.model_runner.model_memory_usage + + self.peak_activation_memory + + self.non_torch_memory + + cuda_graph_memory_bytes + ) + kv_cache_memory_bytes_to_gpu_limit = ( + self.init_snapshot.free_memory + - non_kv_cache_memory + - redundancy_buffer_memory + ) + kv_cache_memory_bytes_to_requested_limit = ( + int(self.requested_memory) + - non_kv_cache_memory + - redundancy_buffer_memory + ) + + msg = ( + f"Free memory on device " + f"({GiB(self.init_snapshot.free_memory)}/" + f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. " + f"Desired GPU memory utilization is " + f"({self.cache_config.gpu_memory_utilization}, " + f"{GiB(self.requested_memory)} GiB). " + f"Actual usage is {GiB(self.model_runner.model_memory_usage)} " + f"GiB for weight, {GiB(self.peak_activation_memory)} GiB " + f"for peak activation, {GiB(self.non_torch_memory)} GiB " + f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} " + f"GiB for CUDAGraph memory. Replace gpu_memory_utilization " + f"config with `--kv-cache-memory=" + f"{kv_cache_memory_bytes_to_requested_limit}` " + f"({GiB(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit " + f"into requested memory, or `--kv-cache-memory=" + f"{kv_cache_memory_bytes_to_gpu_limit}` " + f"({GiB(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully " + f"utilize gpu memory. Current kv cache memory in use is " + f"{GiB(self.available_kv_cache_memory_bytes)} GiB." 
+ ) + + logger.debug(msg) # Warm up sampler and preallocate memory buffer for logits and other # sampling related tensors of max possible shape to avoid memory @@ -366,6 +496,7 @@ def compile_or_warm_up_model(self) -> None: self.model_runner._dummy_run( num_tokens=max_num_reqs, skip_eplb=True, + cudagraph_runtime_mode=CUDAGraphMode.NONE, ) if self.model_runner.is_pooling_model: self.model_runner._dummy_pooler_run(hidden_states) @@ -377,37 +508,87 @@ def compile_or_warm_up_model(self) -> None: # the model initialization and profiling. set_random_seed(self.model_config.seed) + def reset_mm_cache(self) -> None: + self.model_runner.reset_mm_cache() + def get_model(self) -> nn.Module: return self.model_runner.get_model() def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return self.model_runner.get_supported_tasks() + + def annotate_profile(self, scheduler_output): + # add trace annotation so that we can easily distinguish + # new/cached request numbers in each iteration + if not self.profiler: + return nullcontext() + + self.profiler.step() + + num_new = len(scheduler_output.scheduled_new_reqs) + num_cached = len(scheduler_output.scheduled_cached_reqs.req_ids) + + return self.profiler.annotate_context_manager( + f"execute_new_{num_new}_cached_{num_cached}" + ) + @torch.inference_mode() + def sample_tokens( + self, grammar_output: "GrammarOutput | None" + ) -> ModelRunnerOutput | AsyncModelRunnerOutput: + return self.model_runner.sample_tokens(grammar_output) + @torch.inference_mode() def execute_model( self, scheduler_output: "SchedulerOutput", - ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]: + ) -> ModelRunnerOutput | None: intermediate_tensors = None forward_pass = scheduler_output.total_num_scheduled_tokens > 0 num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - num_input_tokens = self.model_runner._get_num_input_tokens( - num_scheduled_tokens) - all_gather_tensors = { - "residual": - not 
is_residual_scattered_for_sp(self.vllm_config, - num_input_tokens) - } + all_gather_tensors = {} + compilation_config = self.vllm_config.compilation_config + parallel_config = self.vllm_config.parallel_config + + if ( + parallel_config.pipeline_parallel_size > 1 + and compilation_config.pass_config.enable_sp + and forward_pass + ): + num_scheduled_tokens_np = np.array( + list(scheduler_output.num_scheduled_tokens.values()), + dtype=np.int32, + ) + # TODO(lucas): This is pretty gross; ideally we should only ever call + # `_determine_batch_execution_and_padding` once (will get called again + # in `execute_model`) but this requires a larger refactor of PP. + _, batch_desc, _, _, _ = ( + self.model_runner._determine_batch_execution_and_padding( + num_tokens=num_scheduled_tokens, + num_reqs=len(num_scheduled_tokens_np), + num_scheduled_tokens_np=num_scheduled_tokens_np, + max_num_scheduled_tokens=num_scheduled_tokens_np.max(), + use_cascade_attn=False, # TODO(lucas): Handle cascade attention + ) + ) + all_gather_tensors = { + "residual": not is_residual_scattered_for_sp(self.vllm_config, + batch_desc.num_tokens) + } if forward_pass and not get_pp_group().is_first_rank: - intermediate_tensors = IntermediateTensors( - get_pp_group().recv_tensor_dict( - all_gather_group=get_tp_group(), - all_gather_tensors=all_gather_tensors)) + tensor_dict = get_pp_group().recv_tensor_dict( + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors, + ) + assert tensor_dict is not None + intermediate_tensors = IntermediateTensors(tensor_dict) - output = self.model_runner.execute_model(scheduler_output, - intermediate_tensors) - if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)): - return output + with self.annotate_profile(scheduler_output): + output = self.model_runner.execute_model( + scheduler_output, intermediate_tensors + ) + if isinstance(output, (ModelRunnerOutput, NoneType)): + return output assert isinstance(output, IntermediateTensors) 
parallel_config = self.vllm_config.parallel_config @@ -418,34 +599,18 @@ def execute_model( all_gather_group=get_tp_group(), all_gather_tensors=all_gather_tensors) - kv_connector_output = output.kv_connector_output - if not kv_connector_output: - return None - - # In case of PP with kv transfer, we need to pass through the - # kv_connector_output - if (not kv_connector_output.finished_sending - and not kv_connector_output.finished_recving): - return EMPTY_MODEL_RUNNER_OUTPUT - - output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) - output.kv_connector_output = kv_connector_output - return output + return None def take_draft_token_ids(self) -> Optional[DraftTokenIds]: return self.model_runner.take_draft_token_ids() def profile(self, is_start: bool = True): if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") + raise RuntimeError("Profiling is not enabled.") if is_start: self.profiler.start() else: self.profiler.stop() - # only print profiler results on rank 0 - if self.local_rank == 0: - print(self.profiler.key_averages().table( - sort_by="self_cuda_time_total")) def execute_dummy_batch(self) -> None: self.model_runner._dummy_run(1, uniform_decode=True) @@ -479,15 +644,17 @@ def _eplb_before_scale_down(self, old_ep_size: int, assert self.model_runner.eplb_state is not None self.model_runner.eplb_state.rearrange(self.model_runner.model, execute_shuffle=True, - global_expert_load=None, + global_expert_loads=None, rank_mapping=rank_mapping) current_platform.torch_device_fn.synchronize() if get_ep_group().rank == 0: logger.info("[Elastic EP] Expert resharding completed!") def _eplb_after_scale_up( - self, old_ep_size: int, new_ep_size: int, - global_expert_load: Optional[torch.Tensor]) -> None: + self, + old_ep_size: int, + new_ep_size: int, + global_expert_loads: Optional[torch.Tensor]) -> None: from vllm.distributed.parallel_state import get_ep_group if get_ep_group().rank == 0: logger.info("[Elastic EP] Starting expert resharding " @@ -498,9 +665,8 @@ 
def _eplb_after_scale_up( } assert self.model_runner.eplb_state is not None self.model_runner.eplb_state.rearrange( - self.model_runner.model, execute_shuffle=True, - global_expert_load=global_expert_load, + global_expert_loads=global_expert_loads, rank_mapping=rank_mapping) if get_ep_group().rank == 0: logger.info("[Elastic EP] Expert resharding completed!") @@ -511,20 +677,25 @@ def _reconfigure_parallel_config( Update parallel config with provided reconfig_request """ parallel_config = self.vllm_config.parallel_config - parallel_config.data_parallel_size = \ - reconfig_request.new_data_parallel_size - if reconfig_request.new_data_parallel_rank != \ - ReconfigureRankType.KEEP_CURRENT_RANK: - parallel_config.data_parallel_rank = \ - reconfig_request.new_data_parallel_rank - if reconfig_request.new_data_parallel_rank_local != \ - ReconfigureRankType.KEEP_CURRENT_RANK: - parallel_config.data_parallel_rank_local = \ + parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size + if ( + reconfig_request.new_data_parallel_rank + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank + if ( + reconfig_request.new_data_parallel_rank_local + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + parallel_config.data_parallel_rank_local = ( reconfig_request.new_data_parallel_rank_local - parallel_config.data_parallel_master_ip = \ + ) + parallel_config.data_parallel_master_ip = ( reconfig_request.new_data_parallel_master_ip - parallel_config.data_parallel_master_port = \ + ) + parallel_config.data_parallel_master_port = ( reconfig_request.new_data_parallel_master_port + ) def _reconfigure_moe(self, old_ep_size: int, new_ep_size: int) -> Optional[torch.Tensor]: @@ -535,29 +706,58 @@ def _reconfigure_moe(self, old_ep_size: int, otherwise None """ from vllm.distributed.parallel_state import ( - get_dp_group, get_ep_group, prepare_communication_buffer_for_model) + get_dp_group, + get_ep_group, + 
prepare_communication_buffer_for_model) from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEParallelConfig) parallel_config = self.vllm_config.parallel_config - moe_modules = [ - module for module in self.model_runner.model.modules() - if (module.__class__.__name__ == "FusedMoE" - or module.__class__.__name__ == "SharedFusedMoE") - ] - num_local_experts = moe_modules[0].moe_config.num_local_experts - assert all(module.moe_config.num_local_experts == num_local_experts - for module in moe_modules), ( - "All MoE modules must have the same number of experts") - for module in moe_modules: - module.moe_config.num_experts = num_local_experts * new_ep_size - module.global_num_experts = module.moe_config.num_experts - module.moe_parallel_config = FusedMoEParallelConfig.make( - tp_size_=get_tp_group().world_size, - dp_size_=get_dp_group().world_size, - vllm_parallel_config=parallel_config, - ) - module.moe_config.moe_parallel_config = module.moe_parallel_config + def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]: + return [ + module + for module in model.modules() + if ( + module.__class__.__name__ == "FusedMoE" + or module.__class__.__name__ == "SharedFusedMoE" + ) + ] + + def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int): + assert all( + module.moe_config.num_local_experts == num_local_experts + for module in moe_modules + ), "All MoE modules must have the same number of experts" + for module in moe_modules: + module.moe_config.num_experts = num_local_experts * new_ep_size + module.global_num_experts = module.moe_config.num_experts + module.moe_parallel_config = FusedMoEParallelConfig.make( + tp_size_=get_tp_group().world_size, + pcp_size_=get_pcp_group().world_size, + dp_size_=get_dp_group().world_size, + vllm_parallel_config=parallel_config, + ) + module.moe_config.moe_parallel_config = module.moe_parallel_config + return moe_modules + + model_moe_modules = get_moe_modules(self.model_runner.model) + 
num_local_experts = model_moe_modules[0].moe_config.num_local_experts + + update_moe_modules(model_moe_modules, num_local_experts) + drafter_model = None + if hasattr(self.model_runner, "drafter") and hasattr( + self.model_runner.drafter, "model" + ): + drafter_model = self.model_runner.drafter.model + if drafter_model is not None and is_mixture_of_experts(drafter_model): + drafter_moe_modules = get_moe_modules(drafter_model) + # Check if drafter and model have matching configs + assert ( + drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts + ), "Drafter and model configs should be the same" + update_moe_modules(drafter_moe_modules, num_local_experts) + if new_ep_size < old_ep_size: num_local_physical_experts = num_local_experts assert self.model_runner.eplb_state is not None @@ -566,44 +766,55 @@ def _reconfigure_moe(self, old_ep_size: int, parallel_config.eplb_config.num_redundant_experts = ( new_physical_experts - self.model_runner.eplb_state.logical_replica_count.shape[1]) - global_expert_load = None + global_expert_loads = None else: - num_local_physical_experts = torch.tensor([num_local_experts], + num_local_physical_experts_tensor = torch.tensor([num_local_experts], dtype=torch.int32, device="cpu") - torch.distributed.broadcast(num_local_physical_experts, + torch.distributed.broadcast(num_local_physical_experts_tensor, group=get_ep_group().cpu_group, group_src=0) - num_local_physical_experts = num_local_physical_experts.item() + num_local_physical_experts = int(num_local_physical_experts_tensor.item()) new_physical_experts = num_local_physical_experts * new_ep_size assert self.model_runner.eplb_state is not None - global_expert_load = self.model_runner.eplb_state.rearrange( - self.model_runner.model, execute_shuffle=False) + global_expert_loads_any = self.model_runner.eplb_state.rearrange( + execute_shuffle=False + ) + global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any) 
parallel_config.eplb_config.num_redundant_experts = ( - new_physical_experts - global_expert_load.shape[1]) + new_physical_experts - global_expert_loads[0].shape[1]) prepare_communication_buffer_for_model(self.model_runner.model) + if drafter_model is not None: + prepare_communication_buffer_for_model(drafter_model) self.model_runner.model.update_physical_experts_metadata( num_physical_experts=new_physical_experts, num_local_physical_experts=num_local_physical_experts) - return global_expert_load + return global_expert_loads def reinitialize_distributed( self, reconfig_request: ReconfigureDistributedRequest) -> None: from vllm.config import set_current_vllm_config from vllm.distributed.parallel_state import ( - cleanup_dist_env_and_memory, get_ep_group) + cleanup_dist_env_and_memory, + get_ep_group, + ) old_ep_size = get_ep_group().world_size old_ep_rank = get_ep_group().rank - new_ep_size = reconfig_request.new_data_parallel_size * get_tp_group( - ).world_size * get_pp_group().world_size + new_ep_size = ( + reconfig_request.new_data_parallel_size + * get_tp_group().world_size + * get_pp_group().world_size + ) if new_ep_size < old_ep_size: self._eplb_before_scale_down(old_ep_size, new_ep_size) cleanup_dist_env_and_memory() - if reconfig_request.new_data_parallel_rank == \ - ReconfigureRankType.SHUTDOWN_CURRENT_RANK: + if ( + reconfig_request.new_data_parallel_rank + == ReconfigureRankType.SHUTDOWN_CURRENT_RANK + ): assert old_ep_rank >= new_ep_size # shutdown return @@ -611,16 +822,16 @@ def reinitialize_distributed( self._reconfigure_parallel_config(reconfig_request) with set_current_vllm_config(self.vllm_config): - init_worker_distributed_environment(self.vllm_config, self.rank, + init_worker_distributed_environment(self.vllm_config, + self.rank, self.distributed_init_method, self.local_rank) - global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size) + global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size) if new_ep_size > old_ep_size: 
- assert global_expert_load is not None - self._eplb_after_scale_up(old_ep_size, new_ep_size, - global_expert_load) + assert global_expert_loads is not None + self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads) def save_sharded_state( self, @@ -646,6 +857,8 @@ def save_tensorized_model( def shutdown(self) -> None: if runner := getattr(self, "model_runner", None): runner.ensure_kv_transfer_shutdown() + if self.profiler is not None: + self.profiler.shutdown() def init_worker_distributed_environment( @@ -656,15 +869,24 @@ def init_worker_distributed_environment( backend: str = "nccl", ) -> None: """Initialize the distributed environment.""" + attention_config = vllm_config.attention_config parallel_config = vllm_config.parallel_config - # set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + from vllm.model_executor.layers.batch_invariant import init_batch_invariance + + init_batch_invariance(attention_config.backend) - init_distributed_environment(parallel_config.world_size, rank, - distributed_init_method, local_rank, backend) + init_method = distributed_init_method or "env://" + init_distributed_environment( + parallel_config.world_size, rank, init_method, local_rank, backend + ) ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size, + parallel_config.prefill_context_parallel_size, parallel_config.decode_context_parallel_size) - ensure_kv_transfer_initialized(vllm_config) + # Init ec connector here before KV caches caches init + # NOTE: We do not init KV caches for Encoder-only instance in EPD disagg mode + ensure_ec_transfer_initialized(vllm_config) + From 25e1cd44618c97ff20f0231cb8b3bad729c371fe Mon Sep 17 00:00:00 2001 From: mslv Date: Wed, 31 Dec 2025 20:40:11 +0800 Subject: [PATCH 02/34] update readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 28d12bd2..cba4c8fb 100644 --- a/README.md +++ 
b/README.md @@ -36,7 +36,7 @@ A vLLM plugin built on the FlagOS unified multi-chip backend. 3.1 Clone the repository: ```sh git clone GitHub - flagos-ai/FlagCX - git checkout -b v0.3.0 + git checkout -b v0.7.0 ``` 3.2 Build the library with different flags targeting to different platforms: @@ -50,7 +50,7 @@ A vLLM plugin built on the FlagOS unified multi-chip backend. export FLAGCX_PATH=${pwd} ``` - 3.4 Installation FlagGems + 3.4 Installation FlagCX ```sh cd FlagCX/plugin/torch/ python setup.py develop --adaptor nvidia/ascend From 8e62ce23194ed24bf579d2ad490466d85d3596a8 Mon Sep 17 00:00:00 2001 From: mslv Date: Sun, 4 Jan 2026 22:04:01 +0800 Subject: [PATCH 03/34] polish code --- vllm_fl/attention/custom_attention.py | 4 ++-- vllm_fl/platform.py | 34 +++++++++++++++++++++------ vllm_fl/worker/worker.py | 21 ++++++++++------- 3 files changed, 42 insertions(+), 17 deletions(-) diff --git a/vllm_fl/attention/custom_attention.py b/vllm_fl/attention/custom_attention.py index 060871a5..fccd07de 100644 --- a/vllm_fl/attention/custom_attention.py +++ b/vllm_fl/attention/custom_attention.py @@ -5,7 +5,7 @@ def register_attention(): register_backend( - backend=AttentionBackendEnum.FLASH_ATTN, + backend=AttentionBackendEnum.TRITON_ATTN, class_path="vllm_fl.attention.attention.AttentionFLBackend", is_mamba=False, - ) \ No newline at end of file + ) diff --git a/vllm_fl/platform.py b/vllm_fl/platform.py index 7c5c3a70..36a5416c 100644 --- a/vllm_fl/platform.py +++ b/vllm_fl/platform.py @@ -12,10 +12,11 @@ import torch -from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.attention.backends.registry import AttentionBackendEnum, register_backend from vllm.logger import init_logger from vllm.platforms import Platform, PlatformEnum +from vllm.platforms.interface import DeviceCapability if TYPE_CHECKING: from vllm.attention.selector import AttentionSelectorConfig @@ -42,7 +43,10 @@ def _get_backend( raise NotImplementedError("NOT support mla now!") # return 
"vllm_fl.attention.backends.mla.MLAFLBackend" else: - return AttentionBackendEnum.FLASH_ATTN #"vllm_fl.attention.attention.AttentionFLBackend" + if "USE_FLAGGEMS" in os.environ and os.environ["USE_FLAGGEMS"] == "1": + return [AttentionBackendEnum.TRITON_ATTN] #"vllm_fl.attention.attention.AttentionFLBackend" + return [AttentionBackendEnum.FLASH_ATTN] + class PlatformFL(Platform): @@ -60,6 +64,10 @@ class PlatformFL(Platform): def is_cuda_alike(self) -> bool: """Stateless version of [torch.cuda.is_available][].""" return self.device_type == "cuda" + + def is_cuda(self) -> bool: + """Stateless version of [torch.cuda.is_available][].""" + return self.device_type == "cuda" @property def supported_dtypes(self) -> list[torch.dtype]: @@ -163,13 +171,19 @@ def get_attn_backend_cls( ) -> list[str]: from vllm_fl.attention.custom_attention import register_attention register_attention() - backend = _get_backend( - use_mla=False, - device_info=cls.device_info, - ) + device_capability = cls.get_device_capability() + + if selected_backend is None: + backend = _get_backend( + use_mla=False, + device_info=cls.device_info, + )[0] # get the highest priority backend + else: + backend = selected_backend + backend_class = backend.get_class() invalid_reasons = backend_class.validate_configuration( - device_capability=None, + device_capability=device_capability, **attn_selector_config._asdict(), ) reasons_str = ( @@ -271,11 +285,17 @@ def use_custom_allreduce(cls) -> bool: if cls.dist_backend == "flagcx": return False return True + + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) @classmethod def is_fully_connected(cls, physical_device_ids: list[int]) -> bool: try: import pynvml + pynvml.nvmlInit() """ query if the set of gpus are fully connected by nvlink (1 hop) """ diff --git a/vllm_fl/worker/worker.py b/vllm_fl/worker/worker.py index 
ce442adb..d3742c17 100644 --- a/vllm_fl/worker/worker.py +++ b/vllm_fl/worker/worker.py @@ -27,6 +27,7 @@ get_kv_transfer_group, has_kv_transfer_group, ) +from vllm.model_executor.warmup.kernel_warmup import kernel_warmup from vllm.distributed.parallel_state import ( get_pcp_group, get_pp_group, @@ -52,18 +53,14 @@ from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.workspace import init_workspace_manager - +from vllm.v1.core.sched.output import SchedulerOutput logger = init_logger(__name__) if TYPE_CHECKING: from vllm.model_executor.model_loader.tensorizer import TensorizerConfig - from vllm.v1.core.sched.output import SchedulerOutput + from vllm_fl.worker.model_runner import ModelRunnerFL -from vllm_fl.worker.model_runner import ModelRunnerFL -from vllm_fl.ops.custom_ops import register_oot_ops - -import flag_gems class WorkerFL(WorkerBase): @@ -106,8 +103,11 @@ def __init__( else: self.profiler = None - register_oot_ops() - flag_gems.enable(record=False) #, unused=["index", "index_put_"]) + if "USE_FLAGGEMS" in os.environ and os.environ["USE_FLAGGEMS"] == "1": + from vllm_fl.ops.custom_ops import register_oot_ops + import flag_gems + register_oot_ops() + flag_gems.enable(record=False) #, unused=["index", "index_put_"]) # def sleep(self, level: int = 1) -> None: # TODO(lms): rewrite CuMemAllocator @@ -243,6 +243,7 @@ def init_device(self): num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1 init_workspace_manager(self.device, num_ubatches) + from vllm_fl.worker.model_runner import ModelRunnerFL # Construct the model runner self.model_runner = ModelRunnerFL( self.vllm_config, self.device) @@ -421,6 +422,10 @@ def compile_or_warm_up_model(self) -> None: self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config) ### NOTE(lms): can add gems kernel pretune here + # Warmup and tune the kernels used during model execution before + # cuda graph capture. 
+ kernel_warmup(self) + cuda_graph_memory_bytes = 0 if not self.model_config.enforce_eager: cuda_graph_memory_bytes = self.model_runner.capture_model() From cddd6f9356dc72315303abb7039d5d3ed61f2fa7 Mon Sep 17 00:00:00 2001 From: mslv Date: Wed, 7 Jan 2026 16:38:15 +0800 Subject: [PATCH 04/34] comment gems attention --- vllm_fl/ops/fused_moe/layer.py | 1 + vllm_fl/platform.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm_fl/ops/fused_moe/layer.py b/vllm_fl/ops/fused_moe/layer.py index ab120ff1..ebf5caa0 100644 --- a/vllm_fl/ops/fused_moe/layer.py +++ b/vllm_fl/ops/fused_moe/layer.py @@ -232,3 +232,4 @@ def valid_grouping() -> bool: else: zero_expert_result = None return topk_weights, topk_ids, zero_expert_result + FusedMoE.select_experts = select_experts diff --git a/vllm_fl/platform.py b/vllm_fl/platform.py index 36a5416c..dd23227e 100644 --- a/vllm_fl/platform.py +++ b/vllm_fl/platform.py @@ -169,8 +169,8 @@ def get_attn_backend_cls( selected_backend: "AttentionBackendEnum", attn_selector_config: "AttentionSelectorConfig", ) -> list[str]: - from vllm_fl.attention.custom_attention import register_attention - register_attention() + # from vllm_fl.attention.custom_attention import register_attention + # register_attention() device_capability = cls.get_device_capability() if selected_backend is None: From 2864b24c9c8315683920f0dd02c8523558e9ae6a Mon Sep 17 00:00:00 2001 From: mslv Date: Fri, 9 Jan 2026 16:18:26 +0800 Subject: [PATCH 05/34] add qwen3 next --- .../E=512,N=128,device_name=cuda.json | 147 ++ setup.py | 3 +- vllm_fl/__init__.py | 8 + vllm_fl/attention/attention.py | 89 +- vllm_fl/models/qwen3_next.py | 1400 +++++++++++++++++ vllm_fl/ops/fla/__init__.py | 7 + vllm_fl/ops/fla/chunk.py | 149 ++ vllm_fl/ops/fla/fused_recurrent.py | 138 ++ vllm_fl/ops/fused_moe/layer.py | 1 + vllm_fl/platform.py | 1 - vllm_fl/worker/model_runner.py | 11 +- 11 files changed, 1915 insertions(+), 39 deletions(-) create mode 100644 
examples/fuse_moe_fuse/E=512,N=128,device_name=cuda.json create mode 100644 vllm_fl/models/qwen3_next.py create mode 100644 vllm_fl/ops/fla/__init__.py create mode 100644 vllm_fl/ops/fla/chunk.py create mode 100644 vllm_fl/ops/fla/fused_recurrent.py diff --git a/examples/fuse_moe_fuse/E=512,N=128,device_name=cuda.json b/examples/fuse_moe_fuse/E=512,N=128,device_name=cuda.json new file mode 100644 index 00000000..dc478aaa --- /dev/null +++ b/examples/fuse_moe_fuse/E=512,N=128,device_name=cuda.json @@ -0,0 +1,147 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + 
"GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + } +} + diff --git a/setup.py b/setup.py index 6cc79948..18ba3858 100644 --- a/setup.py +++ b/setup.py @@ -75,6 +75,7 @@ def _read_requirements(filename: str) -> List[str]: python_requires=">=3.9", install_requires=get_requirements(), extras_require={}, - entry_points={'vllm.platform_plugins': ["fl = vllm_fl:register"]} + entry_points={'vllm.platform_plugins': ["fl = vllm_fl:register"], + 'vllm.general_plugins': ["fl_model = vllm_fl:register_model"]} ) diff --git a/vllm_fl/__init__.py b/vllm_fl/__init__.py index b185b552..8e46acda 100644 --- a/vllm_fl/__init__.py +++ b/vllm_fl/__init__.py @@ -5,3 +5,11 @@ def register(): return "vllm_fl.platform.PlatformFL" + +def register_model(): + """Register the FL model.""" + from vllm import ModelRegistry + from vllm_fl.models.qwen3_next import Qwen3NextForCausalLM + ModelRegistry.register_model( + "Qwen3NextForCausalLM", + "vllm_fl.models.qwen3_next:Qwen3NextForCausalLM") \ No newline at end of file 
diff --git a/vllm_fl/attention/attention.py b/vllm_fl/attention/attention.py index e00a51b8..96e9ff83 100644 --- a/vllm_fl/attention/attention.py +++ b/vllm_fl/attention/attention.py @@ -11,31 +11,36 @@ import torch from vllm import envs -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionType, + MultipleOf, + is_quantized_kv_cache, +) from vllm.attention.layer import Attention from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states -from vllm.attention.backends.registry import ( - AttentionBackendEnum, - register_backend, -) -from vllm.config import VllmConfig, get_layers_from_vllm_config + +from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config from vllm.config.cache import CacheDType from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, - get_dcp_local_seq_lens, - get_kv_cache_layout) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + get_dcp_local_seq_lens, + get_kv_cache_layout, +) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.platforms.interface import DeviceCapability from flag_gems import flash_attn_varlen_func, reshape_and_cache_flash +# from vllm.attention.utils.fa_utils import flash_attn_varlen_func #reshape_and_cache_flash, +# from flag_gems import reshape_and_cache_flash logger = init_logger(__name__) @@ -45,9 +50,25 @@ class AttentionFLBackend(AttentionBackend): accept_output_buffer: bool = True 
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] - @classmethod - def supports_head_size(cls, head_size: int) -> list[int]: - return head_size % 8 == 0 and head_size <= 256 + @staticmethod + def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: + vllm_config = get_current_vllm_config() + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + if ( + model_config + and model_config.is_hybrid + and ( + cache_config.mamba_ssm_cache_dtype == "float32" + or cache_config.mamba_cache_dtype == "float32" + ) + ): + # NOTE(tdoublep): while in principle, FA supports + # MultipleOf(16), these are the block sizes that do not + # suffer from the NaN propagation problem described here: + # https://github.com/Dao-AILab/flash-attention/issues/1974 + return [16, 32, 64] + return [MultipleOf(16)] @staticmethod def get_name() -> str: @@ -65,9 +86,6 @@ def supports_attn_type(cls, attn_type: str) -> bool: def get_impl_cls() -> type["AttentionFLImpl"]: return AttentionFLImpl - @staticmethod - def get_metadata_cls() -> type["AttentionMetadata"]: - return AttentionFLMetadata @staticmethod def get_builder_cls() -> type["AttentionFLMetadataBuilder"]: @@ -117,6 +135,10 @@ def get_kv_cache_stride_order( raise ValueError(f"Unknown cache layout format {cache_layout}.") return stride_order + @classmethod + def supports_head_size(cls, head_size: int) -> bool: + return head_size % 8 == 0 and head_size <= 256 + @classmethod def supports_combination( cls, @@ -312,7 +334,6 @@ def build( # we only set num_splits when using cuda graphs. 
max_num_splits = self.max_num_splits - use_cascade = common_prefix_len > 0 use_cascade = common_prefix_len > 0 max_dcp_context_kv_len = 0 dcp_context_kv_lens = None @@ -353,11 +374,12 @@ def build( prefix_scheduler_metadata = None scheduler_metadata = None else: - cu_prefix_query_lens = None - prefix_kv_lens = None - suffix_kv_lens = None - prefix_scheduler_metadata = None scheduler_metadata = None + # cu_prefix_query_lens = None + # prefix_kv_lens = None + # suffix_kv_lens = None + # prefix_scheduler_metadata = None + # scheduler_metadata = None # For FA3 + full cudagraph if self.use_full_cuda_graph and scheduler_metadata is not None: @@ -435,7 +457,7 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.attn_type = attn_type - self.vllm_flash_attn_version = 2 #get_flash_attn_version() + self.vllm_flash_attn_version = 3 # 2 #get_flash_attn_version() # Cache the batch invariant result for use in forward passes self.batch_invariant_enabled = vllm_is_batch_invariant() @@ -729,7 +751,8 @@ def _forward_encoder_attention( descale_shape = ( cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr] - self.num_kv_heads) + self.num_kv_heads, + ) # Call flash attention directly on Q, K, V tensors flash_attn_varlen_func( @@ -750,6 +773,7 @@ def _forward_encoder_attention( q_descale=layer._q_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), + # num_splits=0, ) return output @@ -874,10 +898,10 @@ def cascade_attention( q=query, k=key_cache, v=value_cache, - max_seqlen_q=num_tokens, cu_seqlens_q=cu_prefix_query_lens, - max_seqlen_k=common_prefix_len, seqused_k=prefix_kv_lens, + max_seqlen_q=num_tokens, + max_seqlen_k=common_prefix_len, softmax_scale=softmax_scale, causal=False, window_size=sliding_window, @@ -889,6 +913,7 @@ def cascade_attention( q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, k_descale=k_descale.expand(descale_shape) if k_descale 
is not None else None, v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, + # num_splits=0, ) descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) @@ -898,10 +923,10 @@ def cascade_attention( q=query, k=key_cache, v=value_cache, - max_seqlen_q=max_query_len, cu_seqlens_q=cu_query_lens, - max_seqlen_k=max_kv_len - common_prefix_len, seqused_k=suffix_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len - common_prefix_len, softmax_scale=softmax_scale, causal=True, window_size=sliding_window, @@ -913,8 +938,8 @@ def cascade_attention( q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, + # num_splits=0, ) - ### TODO(lms): can specify triton version # Merge prefix and suffix outputs, and store the result in output. - merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse) \ No newline at end of file + merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse) diff --git a/vllm_fl/models/qwen3_next.py b/vllm_fl/models/qwen3_next.py new file mode 100644 index 00000000..a76b9aa5 --- /dev/null +++ b/vllm_fl/models/qwen3_next.py @@ -0,0 +1,1400 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Qwen3Next model.""" + +from collections.abc import Iterable +from itertools import islice +import nvtx + +import torch +from einops import rearrange +from torch import nn +from transformers.activations import ACT2FN + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.layer import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import ( + CacheConfig, + ModelConfig, + SpeculativeConfig, + VllmConfig, + get_current_vllm_config, +) +from vllm.distributed 
import ( + divide, + get_ep_group, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType +from vllm.model_executor.layers.layernorm import ( + GemmaRMSNorm as Qwen3NextRMSNorm, +) +from vllm.model_executor.layers.layernorm import RMSNormGated +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, + causal_conv1d_update, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + sharded_weight_loader, +) +from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP +from vllm.model_executor.models.utils import sequence_parallel_chunk +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Qwen3NextConfig +from vllm.triton_utils import tl, triton +from vllm.utils.torch_utils import direct_register_custom_op +from 
vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata + +from vllm.model_executor.models.interfaces import ( + HasInnerState, + IsHybrid, + MixtureOfExperts, + SupportsLoRA, + SupportsPP, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + +from vllm_fl.ops.fla import ChunkGatedDeltaRuleOp, FusedRecurrentGatedDeltaRuleOp + +logger = init_logger(__name__) + +KVCache = tuple[torch.Tensor, torch.Tensor] + +class Qwen3NextSparseMoeBlock(nn.Module): + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + parallel_config = vllm_config.parallel_config + quant_config = vllm_config.quant_config + + self.tp_size = get_tensor_model_parallel_world_size() + + self.ep_group = get_ep_group().device_group + self.ep_rank = get_ep_group().rank_in_group + self.ep_size = self.ep_group.size() + self.n_routed_experts = config.num_experts + + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}." + ) + + # Load balancing settings. 
+ vllm_config = get_current_vllm_config() + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = parallel_config.enable_eplb + + self.n_logical_experts = self.n_routed_experts + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate", + ) + + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) + + if config.shared_expert_intermediate_size > 0: + self.shared_expert = Qwen3NextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.shared_expert_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + expert_gate=self.shared_expert_gate, + prefix=f"{prefix}.shared_expert", + ) + else: + self.shared_expert = None + + self.experts = SharedFusedMoE( + shared_experts=self.shared_expert, + gate=self.gate, + num_experts=self.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + routing_method_type=RoutingMethodType.Renormalize, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. 
+ orig_shape = hidden_states.shape + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + + if self.experts.is_internal_router: + # In this case, the gate/router runs inside the FusedMoE class + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=hidden_states + ) + else: + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + if self.shared_expert is not None: + final_hidden_states = final_hidden_states[0] + final_hidden_states[1] + + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0 + ) + final_hidden_states = final_hidden_states[:num_tokens] + elif self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 + final_hidden_states + ) + + return final_hidden_states.view(orig_shape) + + +class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): + @property + def mamba_type(self) -> str: + return "gdn_attention" + + def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + self.model_config.dtype, self.cache_config.mamba_cache_dtype + ) + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.gated_delta_net_state_shape( + self.tp_size, + self.num_k_heads, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + self.conv_kernel_size, + self.num_spec, + ) + + def __init__( + self, + config: Qwen3NextConfig, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + speculative_config: SpeculativeConfig | None = None, + prefix: str = "", + ) -> None: + 
super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = extract_layer_index(prefix) + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.layer_norm_epsilon = config.rms_norm_eps + self.prefix = prefix + + self.config = config + self.model_config = model_config + self.cache_config = cache_config + self.quant_config = quant_config + self.speculative_config = speculative_config + self.num_spec = ( + self.speculative_config.num_speculative_tokens + if self.speculative_config + else 0 + ) + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.conv_dim, + bias=False, + prefix=f"{prefix}.conv1d", + ) + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + # projection of the input hidden states + self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 + self.projection_size_ba = self.num_v_heads * 2 + self.in_proj_qkvz = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.projection_size_qkvz, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj_qkvz", + ) + # ba_proj doesn't support blockwise fp8 quantization. 
+ self.in_proj_ba = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.projection_size_ba, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj_ba", + ) + + query_key_settings = (self.key_dim, 0, False) + value_settings = (self.value_dim, 0, False) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, + { + "weight_loader": mamba_v2_sharded_weight_loader( + [ + query_key_settings, + query_key_settings, + value_settings, + ], + self.tp_size, + self.tp_rank, + ) + }, + ) + + # selective projection used to make dt, B and C input dependant + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter( + torch.ones(self.num_v_heads // self.tp_size), + ) + self.A_log = nn.Parameter( + torch.empty( + divide(self.num_v_heads, self.tp_size), + ) + ) + + set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) + + self.norm = RMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + group_size=None, + norm_before_gate=True, + device=current_platform.current_device(), + dtype=config.dtype, + ) + + self.out_proj = RowParallelLinear( + self.value_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + self.chunk_gated_delta_rule = ChunkGatedDeltaRuleOp( + output_final_state = True, + use_qk_l2norm_in_kernel=True, + ) + + self.fused_recurrent_gated_delta_rule_multi_query = FusedRecurrentGatedDeltaRuleOp( + inplace_final_state=True, + use_qk_l2norm_in_kernel=True, + ) + 
self.fused_recurrent_gated_delta_rule_remain_query = FusedRecurrentGatedDeltaRuleOp( + inplace_final_state=True, + use_qk_l2norm_in_kernel=True, + ) + + def fix_query_key_value_ordering( + self, + mixed_qkvz, + mixed_ba, + ): + """ + Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. + """ + new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( + self.num_k_heads // self.tp_size, + ( + self.head_k_dim + + self.head_k_dim + + (self.head_v_dim + self.head_v_dim) + * self.num_v_heads + // self.num_k_heads + ), + ) + new_tensor_shape_ba = mixed_qkvz.size()[:-1] + ( + self.num_k_heads // self.tp_size, + 2 * self.num_v_heads // self.num_k_heads, + ) + + mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + + split_arg_list_qkvz = [ + self.head_k_dim, + self.head_k_dim, + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + ] + split_arg_list_ba = [ + self.num_v_heads // self.num_k_heads, + self.num_v_heads // self.num_k_heads, + ] + + # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)] + # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], + # [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng] + (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2) + (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2) + + # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] + value = value.reshape(value.size(0), -1, self.head_v_dim) + z = z.reshape(z.size(0), -1, self.head_v_dim) + b = b.reshape(b.size(0), self.num_v_heads // self.tp_size) + a = a.reshape(a.size(0), self.num_v_heads // self.tp_size) + + return query, key, value, z, b, a + + def rearrange_mixed_qkv(self, mixed_qkv): + if mixed_qkv is None: + return None, None, None + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim // self.tp_size, + self.key_dim // self.tp_size, + self.value_dim // self.tp_size, + ], + dim=-1, + ) + query, key = map( 
+ lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim), + (query, key), + ) + value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim) + return query.contiguous(), key.contiguous(), value.contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + ): + """ + Forward pass with three parts: + 1. Input projection + 2. Core attention (custom op) + 3. Output projection + """ + num_tokens = hidden_states.size(0) + + # ============================================================ + # Part 1: Input Projection + # ============================================================ + projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) + projected_states_ba, _ = self.in_proj_ba(hidden_states) + query, key, value, z, b, a = self.fix_query_key_value_ordering( + projected_states_qkvz, projected_states_ba + ) + query, key, value = map( + lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value) + ) + mixed_qkv = torch.cat((query, key, value), dim=-1) + + # ============================================================ + # Part 2: Core Attention (Custom Op) + # ============================================================ + # Note: we should not use torch.empty here like other attention backends, + # see discussions in https://github.com/vllm-project/vllm/pull/28182 + core_attn_out = torch.zeros( + (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + torch.ops.vllm.gdn_attention_core( + mixed_qkv, + b, + a, + core_attn_out, + self.prefix, + ) + + # ============================================================ + # Part 3: Output Projection + # ============================================================ + z_shape_og = z.shape + # Reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = 
core_attn_out.reshape(z_shape_og) + core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") + output[:num_tokens], _ = self.out_proj(core_attn_out) + + def _forward_core( + self, + mixed_qkv: torch.Tensor, + b: torch.Tensor, + a: torch.Tensor, + core_attn_out: torch.Tensor, + ): + """ + Core attention computation (called by custom op). + """ + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + + if attn_metadata is None: + # V1 profile run + return + + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, GDNAttentionMetadata) + has_initial_state = attn_metadata.has_initial_state + spec_query_start_loc = attn_metadata.spec_query_start_loc + non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc + spec_sequence_masks = attn_metadata.spec_sequence_masks + spec_token_indx = attn_metadata.spec_token_indx + non_spec_token_indx = attn_metadata.non_spec_token_indx + spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 + non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + num_actual_tokens = attn_metadata.num_actual_tokens + num_accepted_tokens = attn_metadata.num_accepted_tokens + + mixed_qkv = mixed_qkv[:num_actual_tokens] + b = b[:num_actual_tokens] + a = a[:num_actual_tokens] + + # 1. 
Convolution sequence transformation + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) + + if spec_sequence_masks is not None: + if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: + mixed_qkv_spec = mixed_qkv + mixed_qkv_non_spec = None + else: + mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) + mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx) + else: + mixed_qkv_spec = None + mixed_qkv_non_spec = mixed_qkv + + # 1.1: Process the multi-query part + if spec_sequence_masks is not None: + mixed_qkv_spec = causal_conv1d_update( + mixed_qkv_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=spec_state_indices_tensor[:, 0][ + : attn_metadata.num_spec_decodes + ], + num_accepted_tokens=num_accepted_tokens, + query_start_loc=spec_query_start_loc, + max_query_len=spec_state_indices_tensor.size(-1), + validate_data=False, + ) + + # 1.2: Process the remaining part + if attn_metadata.num_prefills > 0: + mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1) + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "state_indices_tensor" + mixed_qkv_non_spec = causal_conv1d_fn( + mixed_qkv_non_spec_T, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=has_initial_state, + cache_indices=non_spec_state_indices_tensor, + query_start_loc=non_spec_query_start_loc, + metadata=attn_metadata, + ).transpose(0, 1) + elif attn_metadata.num_decodes > 0: + mixed_qkv_non_spec = causal_conv1d_update( + mixed_qkv_non_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=non_spec_state_indices_tensor[ + : attn_metadata.num_actual_tokens + ], + validate_data=True, + ) + else: + mixed_qkv_non_spec = None + + query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec) + query_non_spec, key_non_spec, 
value_non_spec = self.rearrange_mixed_qkv( + mixed_qkv_non_spec + ) + + g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias) + + if spec_sequence_masks is not None: + if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: + g_spec = g + beta_spec = beta + g_non_spec = None + beta_non_spec = None + else: + g_spec = g.index_select(1, spec_token_indx) + beta_spec = beta.index_select(1, spec_token_indx) + g_non_spec = g.index_select(1, non_spec_token_indx) + beta_non_spec = beta.index_select(1, non_spec_token_indx) + else: + g_spec = None + beta_spec = None + g_non_spec = g + beta_non_spec = beta + + # 2. Recurrent attention + + # 2.1: Process the multi-query part + if spec_sequence_masks is not None: + core_attn_out_spec, last_recurrent_state = self.fused_recurrent_gated_delta_rule_multi_query( + q=query_spec, + k=key_spec, + v=value_spec, + g=g_spec, + beta=beta_spec, + initial_state=ssm_state, + cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1], + ssm_state_indices=spec_state_indices_tensor, + num_accepted_tokens=num_accepted_tokens, + ) + else: + core_attn_out_spec, last_recurrent_state = None, None + + # 2.2: Process the remaining part + if attn_metadata.num_prefills > 0: + initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() + initial_state[~has_initial_state, ...] 
= 0 + ( + core_attn_out_non_spec, + last_recurrent_state, + ) = self.chunk_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=initial_state, + cu_seqlens=non_spec_query_start_loc, + ) + # Init cache + ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( + ssm_state.dtype + ) + elif attn_metadata.num_decodes > 0: + core_attn_out_non_spec, last_recurrent_state = ( + self.fused_recurrent_gated_delta_rule_remain_query( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=ssm_state, + cu_seqlens=non_spec_query_start_loc[ + : attn_metadata.num_decodes + 1 + ], + ssm_state_indices=non_spec_state_indices_tensor, + ) + ) + else: + core_attn_out_non_spec, last_recurrent_state = None, None + + # 3. Merge core attention output + if spec_sequence_masks is not None and core_attn_out_non_spec is not None: + merged_out = torch.empty( + (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), + dtype=core_attn_out_non_spec.dtype, + device=core_attn_out_non_spec.device, + ) + merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec) + merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec) + core_attn_out[:num_actual_tokens] = merged_out.squeeze(0) + elif spec_sequence_masks is not None: + core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0) + else: + core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0) + + +class Qwen3NextAttention(nn.Module): + def __init__( + self, + config: Qwen3NextConfig, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 
0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.head_dim or (self.hidden_size // self.num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None + ) + self.attn_output_gate = getattr(config, "attn_output_gate", True) + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads * (1 + self.attn_output_gate), + self.total_num_kv_heads, + bias=getattr(config, "qkv_bias", False), + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + head_size=self.head_dim, + max_position=config.max_position_embeddings, + rope_parameters=config.rope_parameters, + dual_chunk_attention_config=self.dual_chunk_attention_config, + ) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + **{ + "layer_idx": extract_layer_index(prefix), + "dual_chunk_attention_config": self.dual_chunk_attention_config, + } + if self.dual_chunk_attention_config + else {}, + ) + + self.q_norm = 
Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + output: torch.Tensor, + hidden_states: torch.Tensor, + ): + qkv, _ = self.qkv_proj(hidden_states) + + if self.attn_output_gate: + q_gate, k, v = qkv.split( + [self.q_size * 2, self.kv_size, self.kv_size], dim=-1 + ) + orig_shape = q_gate.shape[:-1] + q_gate = q_gate.view(*orig_shape, self.num_heads, -1) + q, gate = torch.chunk(q_gate, 2, dim=-1) + q = q.reshape(*orig_shape, -1) + gate = gate.reshape(*orig_shape, -1) + else: + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view( + -1, self.num_heads * self.head_dim + ) + k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view( + -1, self.num_kv_heads * self.head_dim + ) + + q, k = self.rotary_emb(positions, q, k) + + attn_output = self.attn(q, k, v) + + if self.attn_output_gate: + gate = torch.sigmoid(gate) + attn_output = attn_output * gate + + output[:], _ = self.o_proj(attn_output) + + +class Qwen3NextDecoderLayer(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + layer_type: str, + prefix: str = "", + ) -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + speculative_config = vllm_config.speculative_config + + self.layer_type = layer_type + self.layer_idx = extract_layer_index(prefix) + + if self.layer_type == "linear_attention": + self.linear_attn = Qwen3NextGatedDeltaNet( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + speculative_config=speculative_config, + prefix=f"{prefix}.linear_attn", + ) + elif self.layer_type == "full_attention": + self.self_attn = Qwen3NextAttention( + config, + model_config=model_config, + 
cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + else: + raise ValueError(f"Invalid layer_type {self.layer_type}") + + mlp_only_layers = ( + [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + ) + if (self.layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 + and (self.layer_idx + 1) % config.decoder_sparse_step == 0 + ): + self.mlp = Qwen3NextSparseMoeBlock( + vllm_config=vllm_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = Qwen3NextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + + self.input_layernorm = Qwen3NextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = Qwen3NextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + self.layer_scale = getattr(config, "layer_scale", False) + if self.layer_scale: + self.attn_layer_scale = torch.nn.Parameter( + torch.zeros( + 1, + 1, + config.hidden_size, + dtype=config.dtype, + ), + ) + self.ffn_layer_scale = torch.nn.Parameter( + torch.zeros( + 1, + 1, + config.hidden_size, + dtype=config.dtype, + ), + ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + positions: torch.Tensor = None, + **kwargs: object, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + self_attention_output = torch.empty_like(hidden_states) + if self.layer_type == "linear_attention": + self.linear_attn( + hidden_states=hidden_states, + output=self_attention_output, + ) + elif self.layer_type == "full_attention": + self.self_attn( + hidden_states=hidden_states, + output=self_attention_output, + positions=positions, + ) + else: + raise ValueError("Invalid layer_type") + hidden_states = self_attention_output + + if self.layer_scale: 
+ if len(hidden_states.shape) == 2: + hidden_states = hidden_states * ( + self.attn_layer_scale.to(hidden_states.dtype)[0] + 1 + ) + else: + hidden_states = hidden_states * ( + self.attn_layer_scale.to(hidden_states.dtype) + 1 + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + + hidden_states = self.mlp(hidden_states) + + if self.layer_scale: + if len(hidden_states.shape) == 2: + hidden_states = hidden_states * ( + self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1 + ) + else: + assert len(hidden_states.shape) == len(self.ffn_layer_scale.shape), ( + f"shape must be the same {len(hidden_states.shape)}, " + f"{len(self.ffn_layer_scale.shape)}" + ) + hidden_states = hidden_states * ( + self.ffn_layer_scale.to(hidden_states.dtype) + 1 + ) + + return hidden_states, residual + + +@support_torch_compile +class Qwen3NextModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config: Qwen3NextConfig = vllm_config.model_config.hf_config + parallel_config = vllm_config.parallel_config + + eplb_config = parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts + + self.config = config + + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + ) + + def get_layer(prefix: str): + return Qwen3NextDecoderLayer( + vllm_config, + layer_type=config.layer_types[extract_layer_index(prefix)], + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + if get_pp_group().is_last_rank: + self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def embed_input_ids(self, input_ids: 
torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_input_ids(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return SharedFusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + num_redundant_experts=self.num_redundant_experts, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if name.startswith("mtp."): 
+ continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + if "mlp.experts" in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # name = apply_attn_prefix(name, params_dict) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + logger.warning_once( + f"Parameter {name} not found in params_dict, skip loading" + ) + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class QwenNextMixtureOfExperts(MixtureOfExperts): + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for layer in self.model.layers: + if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + def set_moe_parameters(self): + self.expert_weights = [] + + self.moe_layers = [] + example_moe = None + for layer in self.model.layers: + if isinstance(layer, Qwen3NextDecoderLayer) and isinstance( + layer.mlp, Qwen3NextSparseMoeBlock + ): + example_moe = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_moe is None: + raise RuntimeError("No Qwen3Next layer found in the model.layers.") + + # Set MoE hyperparameters + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + 
self.num_redundant_experts = example_moe.n_redundant_experts + + +class Qwen3NextForCausalLM( + nn.Module, + HasInnerState, + SupportsLoRA, + SupportsPP, + QwenNextMixtureOfExperts, + IsHybrid, +): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, ( + "Qwen3Next currently does not support prefix caching" + ) + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.model = Qwen3NextModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + # Set MoE hyperparameters + self.set_moe_parameters() + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ): + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + + return hidden_states + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + vllm_config.model_config.dtype, 
vllm_config.cache_config.mamba_cache_dtype + ) + + @classmethod + def get_mamba_state_shape_from_config( + cls, vllm_config: "VllmConfig" + ) -> tuple[tuple[int, int], tuple[int, int]]: + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + tp_size = parallel_config.tensor_parallel_size + num_spec = ( + vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config + else 0 + ) + return MambaStateShapeCalculator.gated_delta_net_state_shape( + tp_size, + hf_config.linear_num_key_heads, + hf_config.linear_num_value_heads, + hf_config.linear_key_head_dim, + hf_config.linear_value_head_dim, + hf_config.linear_conv_kernel_dim, + num_spec, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.logits_processor(self.lm_head, hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=["mtp."], + ) + return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() + + +def gdn_attention_core( + mixed_qkv: torch.Tensor, + b: torch.Tensor, + a: torch.Tensor, + core_attn_out: torch.Tensor, + layer_name: str, +) -> None: + """ + Custom op for the core attention computation. + Only handles the convolution + recurrent attention part. + Input/output projections are handled outside this op. 
+ """ + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self._forward_core( + mixed_qkv=mixed_qkv, + b=b, + a=a, + core_attn_out=core_attn_out, + ) + + +def gdn_attention_core_fake( + mixed_qkv: torch.Tensor, + b: torch.Tensor, + a: torch.Tensor, + core_attn_out: torch.Tensor, + layer_name: str, +) -> None: + """Fake implementation for torch.compile.""" + return + + +direct_register_custom_op( + op_name="gdn_attention_core", + op_func=gdn_attention_core, + mutates_args=["core_attn_out"], + fake_impl=gdn_attention_core_fake, +) + + +@triton.jit +def fused_gdn_gating_kernel( + g, + beta_output, + A_log, + a, + b, + dt_bias, + seq_len, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off + mask = head_off < NUM_HEADS + blk_A_log = tl.load(A_log + head_off, mask=mask) + blk_a = tl.load(a + off, mask=mask) + blk_b = tl.load(b + off, mask=mask) + blk_bias = tl.load(dt_bias + head_off, mask=mask) + # If the model is loaded in fp16, without the .float() here, A might be -inf + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where( + beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x + ) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + # compute beta_output = sigmoid(b) + blk_beta_output = tl.sigmoid(blk_b.to(tl.float32)) + tl.store( + beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask + ) + + +def fused_gdn_gating( + A_log: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Fused computation of g and beta for Gated 
Delta Net. + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + beta_output = b.sigmoid() + TODO maybe use torch.compile to replace this triton kernel + """ + batch, num_heads = a.shape + seq_len = 1 + grid = (batch, seq_len, triton.cdiv(num_heads, 8)) + g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device) + beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device) + fused_gdn_gating_kernel[grid]( + g, + beta_output, + A_log, + a, + b, + dt_bias, + seq_len, + num_heads, + beta, + threshold, + 8, + num_warps=1, + ) + return g, beta_output diff --git a/vllm_fl/ops/fla/__init__.py b/vllm_fl/ops/fla/__init__.py new file mode 100644 index 00000000..57bf8010 --- /dev/null +++ b/vllm_fl/ops/fla/__init__.py @@ -0,0 +1,7 @@ +from .chunk import ChunkGatedDeltaRuleOp +from .fused_recurrent import FusedRecurrentGatedDeltaRuleOp + +__all__ = [ + "ChunkGatedDeltaRuleOp", + "FusedRecurrentGatedDeltaRuleOp", +] diff --git a/vllm_fl/ops/fla/chunk.py b/vllm_fl/ops/fla/chunk.py new file mode 100644 index 00000000..00f588fe --- /dev/null +++ b/vllm_fl/ops/fla/chunk.py @@ -0,0 +1,149 @@ +# Copyright (c) 2025 BAAI. All rights reserved. 
+ +import warnings +from typing import Optional, Union +import os +import torch + +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd +from vllm.model_executor.layers.fla.ops.utils import input_guard +from flag_gems.fused.FLA import chunk_gated_delta_rule_fwd + + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + @staticmethod + @input_guard + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: torch.LongTensor | None = None, + use_qk_l2norm_in_kernel: bool = False, + ): + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q) + k = l2norm_fwd(k) + + g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + ctx.scale = scale + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + return o.to(q.dtype), final_state + + +@CustomOp.register("chunk_gated_delta_rule") +class ChunkGatedDeltaRuleOp(CustomOp): + def __init__( + self, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + ) -> None: + r""" + Args: + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. 
+ """ + super().__init__() + self.output_final_state = output_final_state + self.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + self.suppress_level = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0")) + + def forward_native( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + cu_seqlens: torch.LongTensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]``. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]`. + beta (torch.Tensor): + betas of shape `[B, T, H]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + """ + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, ( + "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." + ) + assert len(beta.shape) == 3, ( + "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." + ) + + if q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] 
" + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2, + ) + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + self.output_final_state, + cu_seqlens, + self.use_qk_l2norm_in_kernel, + ) + return o, final_state diff --git a/vllm_fl/ops/fla/fused_recurrent.py b/vllm_fl/ops/fla/fused_recurrent.py new file mode 100644 index 00000000..ff73ddcf --- /dev/null +++ b/vllm_fl/ops/fla/fused_recurrent.py @@ -0,0 +1,138 @@ +# Copyright (c) 2025 BAAI. All rights reserved. 
+ +import warnings +from typing import Optional, Union +import os +import torch + +from vllm.model_executor.custom_op import CustomOp +from flag_gems.fused.FLA import fused_recurrent_gated_delta_rule_fwd + + +class FusedRecurrentFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, + ): + o, final_state = fused_recurrent_gated_delta_rule_fwd( + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), + g=g.contiguous(), + beta=beta.contiguous(), + scale=scale, + initial_state=initial_state, + inplace_final_state=inplace_final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + + return o, final_state + + +@CustomOp.register("fused_recurrent_gated_delta_rule") +class FusedRecurrentGatedDeltaRuleOp(CustomOp): + def __init__( + self, + inplace_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + ) -> None: + r""" + Args: + inplace_final_state (Optional[bool]): + Whether to inplace the final state of shape `[N, H, K, V]`. Default: `False`. 
+ """ + super().__init__() + self.inplace_final_state = inplace_final_state + self.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + + def forward_native( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + r""" + `Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, HV, V]`. + GVA is applied if `HV > H`. + g (torch.Tensor): + g (decays) of shape `[B, T, HV]`. + beta (torch.Tensor): + betas of shape `[B, T, HV]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, HV, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + inplace_final_state: bool: + Whether to store the final state in-place to save memory. + Default: `True`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + ssm_state_indices (Optional[torch.Tensor]): + Indices to map the input sequences to the initial/final states. + num_accepted_tokens (Optional[torch.Tensor]): + Number of accepted tokens for each sequence during decoding. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HV, V]`. + final_state (torch.Tensor): + Final state of shape `[N, HV, K, V]`. + """ + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." 
+ f"Please flatten variable-length inputs before processing." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + self.inplace_final_state, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + self.use_qk_l2norm_in_kernel, + ) + return o, final_state diff --git a/vllm_fl/ops/fused_moe/layer.py b/vllm_fl/ops/fused_moe/layer.py index ebf5caa0..7f5960cc 100644 --- a/vllm_fl/ops/fused_moe/layer.py +++ b/vllm_fl/ops/fused_moe/layer.py @@ -69,6 +69,7 @@ def forward_oot( class FusedMoEFL(FusedMoE): + def forward_oot(self, hidden_states: torch.Tensor, router_logits: torch.Tensor, diff --git a/vllm_fl/platform.py b/vllm_fl/platform.py index dd23227e..f1436a49 100644 --- a/vllm_fl/platform.py +++ b/vllm_fl/platform.py @@ -47,7 +47,6 @@ def _get_backend( return [AttentionBackendEnum.TRITON_ATTN] #"vllm_fl.attention.attention.AttentionFLBackend" return [AttentionBackendEnum.FLASH_ATTN] - class PlatformFL(Platform): _enum = PlatformEnum.OOT diff --git a/vllm_fl/worker/model_runner.py b/vllm_fl/worker/model_runner.py index 46b654f0..febedc1d 100644 --- a/vllm_fl/worker/model_runner.py +++ b/vllm_fl/worker/model_runner.py @@ -23,7 +23,8 @@ import vllm.envs as envs from vllm.attention.backends.abstract import( - AttentionBackend, + AttentionBackend, + AttentionMetadata, AttentionType, MultipleOf) from vllm.attention.layer import Attention, MLAAttention @@ -45,7 +46,6 @@ get_dcp_group, get_pp_group, get_tp_group, - graph_capture, is_global_first_rank, prepare_communication_buffer_for_model, GraphCaptureContext @@ -213,7 +213,6 @@ def graph_capture(device: torch.device): from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm_fl.compilation.graph import GraphWrapper -from vllm_fl.attention.attention import AttentionMetadata logger = 
init_logger(__name__) @@ -1896,7 +1895,7 @@ def _compute_cascade_attn_prefix_len( use_alibi=self.use_alibi, use_sliding_window=use_sliding_window, use_local_attention=use_local_attention, - num_sms=self.num_sms, + num_sms=1, dcp_world_size=self.dcp_world_size, ) return common_prefix_len if use_cascade else 0 @@ -3723,6 +3722,8 @@ def load_model(self, eep_scale_up: bool = False) -> None: if self.eplb_state.is_async: self.eplb_state.start_async_loop(rank_mapping=rank_mapping) + # print(f"{self.vllm_config.compilation_config.mode=}") + if ( self.vllm_config.compilation_config.mode == CompilationMode.STOCK_TORCH_COMPILE @@ -3738,7 +3739,7 @@ def load_model(self, eep_scale_up: bool = False) -> None: # wrap the model with full cudagraph wrapper if needed. cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None - if self.compilation_config.cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo: + if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo: self.model = GraphWrapper( self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL ) From 80a74875c65dfb7f03acbbbee8605e18eb321e09 Mon Sep 17 00:00:00 2001 From: yxa Date: Fri, 9 Jan 2026 10:41:24 +0000 Subject: [PATCH 06/34] Add a dispatch mechanism. 
--- vllm_fl/dispatch/README.md | 354 ++++++++++++ vllm_fl/dispatch/__init__.py | 131 +++++ vllm_fl/dispatch/backends/__init__.py | 11 + vllm_fl/dispatch/backends/base.py | 51 ++ .../dispatch/backends/flaggems/__init__.py | 9 + .../dispatch/backends/flaggems/flaggems.py | 117 ++++ .../backends/flaggems/impl/__init__.py | 15 + .../backends/flaggems/impl/activation.py | 26 + .../backends/flaggems/impl/normalization.py | 34 ++ .../dispatch/backends/flaggems/impl/rotary.py | 46 ++ .../backends/flaggems/register_ops.py | 69 +++ .../dispatch/backends/reference/__init__.py | 9 + .../backends/reference/impl/__init__.py | 15 + .../backends/reference/impl/activation.py | 25 + .../backends/reference/impl/normalization.py | 42 ++ .../backends/reference/impl/rotary.py | 90 ++++ .../dispatch/backends/reference/reference.py | 119 ++++ .../backends/reference/register_ops.py | 69 +++ vllm_fl/dispatch/backends/vendor/__init__.py | 7 + vllm_fl/dispatch/builtin_ops.py | 58 ++ vllm_fl/dispatch/discovery.py | 241 +++++++++ vllm_fl/dispatch/logger_manager.py | 70 +++ vllm_fl/dispatch/manager.py | 507 ++++++++++++++++++ vllm_fl/dispatch/ops.py | 127 +++++ vllm_fl/dispatch/policy.py | 405 ++++++++++++++ vllm_fl/dispatch/registry.py | 120 +++++ vllm_fl/dispatch/types.py | 112 ++++ vllm_fl/ops/activation.py | 9 +- vllm_fl/ops/layernorm.py | 11 +- vllm_fl/ops/rotary_embedding.py | 29 +- 30 files changed, 2907 insertions(+), 21 deletions(-) create mode 100644 vllm_fl/dispatch/README.md create mode 100644 vllm_fl/dispatch/__init__.py create mode 100644 vllm_fl/dispatch/backends/__init__.py create mode 100644 vllm_fl/dispatch/backends/base.py create mode 100644 vllm_fl/dispatch/backends/flaggems/__init__.py create mode 100644 vllm_fl/dispatch/backends/flaggems/flaggems.py create mode 100644 vllm_fl/dispatch/backends/flaggems/impl/__init__.py create mode 100644 vllm_fl/dispatch/backends/flaggems/impl/activation.py create mode 100644 vllm_fl/dispatch/backends/flaggems/impl/normalization.py create 
mode 100644 vllm_fl/dispatch/backends/flaggems/impl/rotary.py create mode 100644 vllm_fl/dispatch/backends/flaggems/register_ops.py create mode 100644 vllm_fl/dispatch/backends/reference/__init__.py create mode 100644 vllm_fl/dispatch/backends/reference/impl/__init__.py create mode 100644 vllm_fl/dispatch/backends/reference/impl/activation.py create mode 100644 vllm_fl/dispatch/backends/reference/impl/normalization.py create mode 100644 vllm_fl/dispatch/backends/reference/impl/rotary.py create mode 100644 vllm_fl/dispatch/backends/reference/reference.py create mode 100644 vllm_fl/dispatch/backends/reference/register_ops.py create mode 100644 vllm_fl/dispatch/backends/vendor/__init__.py create mode 100644 vllm_fl/dispatch/builtin_ops.py create mode 100644 vllm_fl/dispatch/discovery.py create mode 100644 vllm_fl/dispatch/logger_manager.py create mode 100644 vllm_fl/dispatch/manager.py create mode 100644 vllm_fl/dispatch/ops.py create mode 100644 vllm_fl/dispatch/policy.py create mode 100644 vllm_fl/dispatch/registry.py create mode 100644 vllm_fl/dispatch/types.py diff --git a/vllm_fl/dispatch/README.md b/vllm_fl/dispatch/README.md new file mode 100644 index 00000000..32eb5849 --- /dev/null +++ b/vllm_fl/dispatch/README.md @@ -0,0 +1,354 @@ +# Dispatch Mechanism + +This directory implements the operator dispatch mechanism for vllm-plugin-FL, providing a flexible operator dispatch system that selects between different backend implementations (FlagGems, PyTorch, etc.) based on availability and policy configuration. 
+ +## Directory Structure + +``` +dispatch/ +├── __init__.py # Module entry point, exports public API +├── types.py # Core type definitions (OpImpl, BackendImplKind) +├── registry.py # Thread-safe operator registry +├── policy.py # Selection policy management +├── manager.py # Core dispatch manager +├── builtin_ops.py # Built-in operator registration (calls backend register_ops) +├── ops.py # Backend base interface (VLLMFLBackendBase) +├── discovery.py # Plugin discovery mechanism +├── logger_manager.py # Centralized logging configuration +└── backends/ # Backend implementations + ├── __init__.py + ├── base.py # Backend abstract base class + ├── flaggems/ # FlagGems backend + │ ├── __init__.py + │ ├── flaggems.py # Backend class + │ ├── register_ops.py # Operator registration + │ └── impl/ # Operator implementations + │ ├── __init__.py + │ ├── activation.py + │ ├── normalization.py + │ └── rotary.py + ├── reference/ # Reference backend (PyTorch) + │ ├── __init__.py + │ ├── reference.py # Backend class + │ ├── register_ops.py # Operator registration + │ └── impl/ # Operator implementations + │ ├── __init__.py + │ ├── activation.py + │ ├── normalization.py + │ └── rotary.py + └── vendor/ # Vendor-specific backends + └── __init__.py # (Add CUDA, etc. as needed) +``` + +## Core Concepts + +### 1. Backend Implementation Kind (BackendImplKind) + +- **DEFAULT**: Default implementation (FlagGems), priority 150 +- **REFERENCE**: Reference implementation (PyTorch native), priority 50 +- **VENDOR**: Vendor-specific implementation (e.g., CUDA), requires vendor name + +### 2. 
Operator Implementation (OpImpl) + +Each operator implementation contains the following attributes: +- `op_name`: Operator name (e.g., "silu_and_mul", "rmsnorm") +- `impl_id`: Unique implementation identifier (e.g., "default.flaggems") +- `kind`: Implementation type +- `fn`: Actual implementation function +- `vendor`: Vendor name (required for VENDOR type) +- `priority`: Selection priority (higher value = preferred) + +### 3. Selection Policy (SelectionPolicy) + +Policy controls operator implementation selection behavior: +- `prefer`: Preferred implementation type +- `strict`: Strict mode, whether to raise error when primary implementation fails +- `per_op_order`: Custom selection order for each operator +- `deny_vendors`: List of denied vendors +- `allow_vendors`: Whitelist of allowed vendors + +## Quick Start + +### Basic Usage + +```python +from vllm_fl.dispatch import call_op, resolve_op + +# Method 1: Call operator directly +result = call_op("silu_and_mul", x) + +# Method 2: Resolve first, then call +fn = resolve_op("rmsnorm") +result = fn(x, residual, weight, epsilon) +``` + +### Using the Manager + +```python +from vllm_fl.dispatch import get_default_manager + +manager = get_default_manager() + +# Resolve operator +fn = manager.resolve("rotary_embedding") +result = fn(query, key, cos, sin, position_ids) + +# Or call directly +result = manager.call("silu_and_mul", x) +``` + +## Environment Variables + +| Variable | Description | Example | +|----------|-------------|---------| +| `VLLM_FL_PREFER` | Preferred backend | `flaggems`, `vendor`, `reference` | +| `VLLM_FL_STRICT` | Strict mode | `1` or `0` | +| `VLLM_FL_DENY_VENDORS` | Denied vendors list | `vendor1,vendor2` | +| `VLLM_FL_ALLOW_VENDORS` | Allowed vendors whitelist | `vendor1,vendor2` | +| `VLLM_FL_PER_OP` | Per-operator selection order | `op1=a\|b\|c;op2=x\|y` | +| `VLLM_FL_PLUGIN_MODULES` | Plugin modules to load | `my_plugin,another_plugin` | +| `VLLM_FL_LOG_LEVEL` | Log level | `DEBUG`, `INFO`, 
`WARNING`, `ERROR` | + +### Examples + +```bash +# Prefer FlagGems implementation +export VLLM_FL_PREFER=flaggems + +# Enable strict mode (auto-fallback on failure) +export VLLM_FL_STRICT=1 + +# Deny specific vendors +export VLLM_FL_DENY_VENDORS=vendor_a,vendor_b + +# Specify selection order for specific operator +export VLLM_FL_PER_OP="rmsnorm=vendor|flaggems|reference" + +# Load external plugins +export VLLM_FL_PLUGIN_MODULES=my_custom_backend + +# Set log level +export VLLM_FL_LOG_LEVEL=DEBUG +``` + +## Policy Context Management + +Supports temporary policy override in code: + +```python +from vllm_fl.dispatch import ( + policy_context, + with_strict_mode, + with_preference, + with_allowed_vendors, + with_denied_vendors, + SelectionPolicy, +) + +# Temporarily enable strict mode +with with_strict_mode(): + result = call_op("silu_and_mul", x) + +# Temporarily switch preferred backend +with with_preference("reference"): + result = call_op("rmsnorm", x, residual, weight, epsilon) + +# Temporarily restrict allowed vendors +with with_allowed_vendors("vendor_a"): + result = call_op("rotary_embedding", query, key, cos, sin, position_ids) + +# Use custom policy +custom_policy = SelectionPolicy.from_dict( + prefer="flaggems", + strict=True, + deny_vendors={"vendor_x"}, +) +with policy_context(custom_policy): + result = call_op("silu_and_mul", x) +``` + +## Supported Operators + +Currently supported operators: + +| Operator | Description | FlagGems | Reference | +|----------|-------------|----------|-----------| +| `silu_and_mul` | SiLU activation + element-wise multiplication | ✓ | ✓ | +| `rmsnorm` | RMS normalization | ✓ | ✓ | +| `rotary_embedding` | Rotary position embedding | ✓ | ✓ | + +## Selection Process + +1. **Cache Check**: Check if dispatch cache hits +2. **Get Implementations**: Retrieve all registered implementations from registry +3. **Vendor Filtering**: Filter by policy's allow/deny lists +4. 
**Availability Check**: Call `is_available()` to check if implementation is available +5. **Priority Sorting**: Select best implementation based on per-op order or default order +6. **Cache Result**: Cache selection result to speed up subsequent calls + +## Fallback Mechanism + +When `VLLM_FL_STRICT=1`, if the primary implementation fails, the system automatically tries other available implementations: + +``` +Op 'rmsnorm' using 'default.flaggems' (kind=flaggems, vendor=None) +[WARNING] Implementation 'default.flaggems' failed for op 'rmsnorm': ... +Op 'rmsnorm' fallback to 'reference.torch' (kind=reference, vendor=None) +``` + +## Extending with New Operators + +When adding a new operator (e.g., `layernorm`), modify the following files: + +| File | Changes | +|------|---------| +| `backends/flaggems/impl/normalization.py` | Add FlagGems implementation | +| `backends/flaggems/flaggems.py` | Add method to backend class | +| `backends/flaggems/register_ops.py` | Register OpImpl | +| `backends/reference/impl/normalization.py` | Add PyTorch implementation | +| `backends/reference/reference.py` | Add method to backend class | +| `backends/reference/register_ops.py` | Register OpImpl | +| `ops.py` | Add abstract method declaration | + +## Extending with New Backends + +### 1. Create Backend Directory Structure + +``` +backends/my_backend/ +├── __init__.py +├── my_backend.py # Backend class +├── register_ops.py # Operator registration +└── impl/ # Operator implementations + ├── __init__.py + ├── activation.py + └── ... +``` + +### 2. Implement Backend Class + +```python +# backends/my_backend/my_backend.py +from ..base import Backend + +class MyBackend(Backend): + @property + def name(self) -> str: + return "my_backend" + + def is_available(self) -> bool: + try: + import my_library + return True + except ImportError: + return False + + def silu_and_mul(self, x): + from .impl.activation import silu_and_mul_my_backend + return silu_and_mul_my_backend(x) +``` + +### 3. 
Create Registration Module + +```python +# backends/my_backend/register_ops.py +from ...types import OpImpl, BackendImplKind + +def register_builtins(registry) -> None: + from .my_backend import MyBackend + + backend = MyBackend() + is_avail = backend.is_available + + impls = [ + OpImpl( + op_name="silu_and_mul", + impl_id="default.my_backend", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.silu_and_mul, is_avail), + priority=100, + ), + ] + + registry.register_many(impls) +``` + +### 4. Update builtin_ops.py + +```python +# In builtin_ops.py, add: +try: + from .backends.my_backend.register_ops import register_builtins as register_my_backend + register_my_backend(registry) +except Exception as e: + logger.warning(f"Failed to register MyBackend operators: {e}") +``` + +## Plugin Discovery + +External plugins can register operators via: + +### 1. Entry Points (Recommended) + +```python +# In your plugin's setup.py or pyproject.toml +[project.entry-points."vllm_fl.plugin"] +my_plugin = "my_plugin_package:register" +``` + +```python +# my_plugin_package/__init__.py +def register(registry): + # Register your operators + registry.register_impl(OpImpl(...)) +``` + +### 2. 
Environment Variable + +```bash +export VLLM_FL_PLUGIN_MODULES=my_plugin_module +``` + +```python +# my_plugin_module.py +def vllm_fl_register(registry): + # Register your operators + pass +``` + +## Multi-Process Safety + +OpManager supports multi-process environments: +- Uses `os.register_at_fork()` to automatically reset state after fork +- PID detection ensures independent initialization per process +- Thread-safe registry and cache operations + +## API Reference + +### Convenience Functions + +- `call_op(op_name, *args, **kwargs)`: Call an operator +- `resolve_op(op_name)`: Resolve operator implementation + +### Policy Management + +- `get_policy()`: Get current policy +- `set_global_policy(policy)`: Set global policy +- `reset_global_policy()`: Reset to environment variable defaults +- `policy_context(policy)`: Temporary policy context + +### Manager + +- `get_default_manager()`: Get default manager instance +- `reset_default_manager()`: Reset default manager + +### Plugin Discovery + +- `discover_plugins(registry)`: Discover and load plugins +- `get_discovered_plugins()`: Get list of discovered plugins +- `clear_discovered_plugins()`: Clear discovered plugins list + +### Logging + +- `get_logger(name)`: Get logger instance +- `set_log_level(level, name)`: Set log level diff --git a/vllm_fl/dispatch/__init__.py b/vllm_fl/dispatch/__init__.py new file mode 100644 index 00000000..5724a90a --- /dev/null +++ b/vllm_fl/dispatch/__init__.py @@ -0,0 +1,131 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Dispatch mechanism for vllm-plugin-FL. + +This module provides a flexible operator dispatch system that allows +selecting between different backend implementations (FlagGems, PyTorch, etc.) +based on availability and policy configuration. 
+ +Usage: + from vllm_fl.dispatch import get_default_manager, call_op + + # Call an operator through the dispatch system + result = call_op("silu_and_mul", x) + + # Or use the manager directly + manager = get_default_manager() + fn = manager.resolve("rmsnorm") + result = fn(x, residual, weight, epsilon) + +Environment Variables: + VLLM_FL_PREFER: Preferred backend ("flaggems", "vendor", "reference") + VLLM_FL_STRICT: Enable strict mode ("1" or "0") + VLLM_FL_DENY_VENDORS: Comma-separated list of denied vendors + VLLM_FL_ALLOW_VENDORS: Comma-separated list of allowed vendors + VLLM_FL_PER_OP: Per-operator order (format: op1=a|b|c;op2=x|y) + VLLM_FL_PLUGIN_MODULES: Comma-separated list of plugin modules to load + VLLM_FL_LOG_LEVEL: Log level for dispatch module (DEBUG, INFO, WARNING, ERROR) + VLLM_FL_DISPATCH_DEBUG: Enable debug printing ("1" or "0", default: "0") + When enabled, prints: + - Detailed list of registered operators and implementations at initialization + - Selected backend for each operator call +""" + +from .types import OpImpl, BackendImplKind, match_token +from .registry import OpRegistry, OpRegistrySnapshot +from .policy import ( + SelectionPolicy, + PolicyManager, + get_policy, + set_global_policy, + reset_global_policy, + policy_context, + with_strict_mode, + with_preference, + with_allowed_vendors, + with_denied_vendors, + PREFER_DEFAULT, + PREFER_VENDOR, + PREFER_REFERENCE, +) +from .manager import OpManager, get_default_manager, reset_default_manager +from .ops import VLLMFLBackendBase +from .discovery import ( + discover_plugins, + get_discovered_plugins, + clear_discovered_plugins, + PLUGIN_GROUP, + PLUGIN_MODULES_ENV, +) +from .logger_manager import get_logger, set_log_level + + +def call_op(op_name: str, *args, **kwargs): + """ + Convenience function to call an operator through the default manager. 
+ + Args: + op_name: Name of the operator + *args, **kwargs: Arguments passed to the operator + + Returns: + Result from the operator implementation + """ + return get_default_manager().call(op_name, *args, **kwargs) + + +def resolve_op(op_name: str): + """ + Convenience function to resolve an operator through the default manager. + + Args: + op_name: Name of the operator + + Returns: + Callable implementation function + """ + return get_default_manager().resolve(op_name) + + +__all__ = [ + # Types + "OpImpl", + "BackendImplKind", + "match_token", + # Registry + "OpRegistry", + "OpRegistrySnapshot", + # Policy + "SelectionPolicy", + "PolicyManager", + "get_policy", + "set_global_policy", + "reset_global_policy", + "policy_context", + "with_strict_mode", + "with_preference", + "with_allowed_vendors", + "with_denied_vendors", + "PREFER_DEFAULT", + "PREFER_VENDOR", + "PREFER_REFERENCE", + # Manager + "OpManager", + "get_default_manager", + "reset_default_manager", + # Backend base + "VLLMFLBackendBase", + # Plugin discovery + "discover_plugins", + "get_discovered_plugins", + "clear_discovered_plugins", + "PLUGIN_GROUP", + "PLUGIN_MODULES_ENV", + # Logging + "get_logger", + "set_log_level", + # Convenience functions + "call_op", + "resolve_op", +] diff --git a/vllm_fl/dispatch/backends/__init__.py b/vllm_fl/dispatch/backends/__init__.py new file mode 100644 index 00000000..14a7ea05 --- /dev/null +++ b/vllm_fl/dispatch/backends/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Backend implementations for vllm-plugin-FL dispatch. +""" + +from .base import Backend +from .flaggems import FlagGemsBackend +from .reference import ReferenceBackend + +__all__ = ["Backend", "FlagGemsBackend", "ReferenceBackend"] diff --git a/vllm_fl/dispatch/backends/base.py b/vllm_fl/dispatch/backends/base.py new file mode 100644 index 00000000..0136d6a6 --- /dev/null +++ b/vllm_fl/dispatch/backends/base.py @@ -0,0 +1,51 @@ +# Copyright (c) 2025 BAAI. 
All rights reserved. + +""" +Base backend class for operator implementations. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Optional + + +class Backend(ABC): + """ + Abstract base class for operator backends. + + Each backend provides implementations for a set of operators. + Backends should implement is_available() to indicate whether + the backend can be used in the current environment. + """ + + @abstractmethod + def is_available(self) -> bool: + """ + Check if this backend is available in the current environment. + + Returns: + True if the backend can be used, False otherwise. + """ + pass + + @property + @abstractmethod + def name(self) -> str: + """ + Get the name of this backend. + + Returns: + Backend name string. + """ + pass + + @property + def vendor(self) -> Optional[str]: + """ + Get the vendor name for this backend (if applicable). + + Returns: + Vendor name string, or None for non-vendor backends. + """ + return None diff --git a/vllm_fl/dispatch/backends/flaggems/__init__.py b/vllm_fl/dispatch/backends/flaggems/__init__.py new file mode 100644 index 00000000..f5bd4506 --- /dev/null +++ b/vllm_fl/dispatch/backends/flaggems/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +FlagGems backend for vllm-plugin-FL dispatch. +""" + +from .flaggems import FlagGemsBackend + +__all__ = ["FlagGemsBackend"] diff --git a/vllm_fl/dispatch/backends/flaggems/flaggems.py b/vllm_fl/dispatch/backends/flaggems/flaggems.py new file mode 100644 index 00000000..05c3f2c0 --- /dev/null +++ b/vllm_fl/dispatch/backends/flaggems/flaggems.py @@ -0,0 +1,117 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +FlagGems backend implementation. + +This backend provides operator implementations using the FlagGems library. 
+""" + +from __future__ import annotations + +from typing import Optional, Union + +import torch + +from ..base import Backend + + +class FlagGemsBackend(Backend): + """ + FlagGems backend for operator implementations. + + This backend uses the flag_gems library to provide high-performance + operator implementations. + """ + + _available: Optional[bool] = None + + @property + def name(self) -> str: + return "flaggems" + + def is_available(self) -> bool: + """Check if FlagGems is available.""" + if FlagGemsBackend._available is None: + try: + import flag_gems + + FlagGemsBackend._available = True + except ImportError: + FlagGemsBackend._available = False + return FlagGemsBackend._available + + # ==================== Operator Implementations ==================== + + def silu_and_mul(self, x: torch.Tensor) -> torch.Tensor: + """ + SiLU activation followed by element-wise multiplication. + + Args: + x: Input tensor of shape [..., 2*d] + + Returns: + Output tensor of shape [..., d] + """ + from .impl.activation import silu_and_mul_flaggems + + return silu_and_mul_flaggems(x) + + def rmsnorm( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor], + weight: torch.Tensor, + epsilon: float, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """ + RMS normalization. + + Args: + x: Input tensor + residual: Optional residual tensor + weight: Normalization weight + epsilon: Small constant for numerical stability + + Returns: + Normalized tensor, or tuple of (normalized, residual) if residual is provided + """ + from .impl.normalization import rmsnorm_flaggems + + return rmsnorm_flaggems(x, residual, weight, epsilon) + + def rotary_embedding( + self, + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor, + rotary_interleaved: bool = False, + inplace: bool = True, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embedding. 
+ + Args: + query: Query tensor + key: Key tensor + cos: Cosine cache + sin: Sine cache + position_ids: Position indices + rotary_interleaved: Whether to use interleaved rotary + inplace: Whether to modify tensors in-place + + Returns: + Tuple of (embedded_query, embedded_key) + """ + from .impl.rotary import rotary_embedding_flaggems + + return rotary_embedding_flaggems( + query, + key, + cos, + sin, + position_ids, + rotary_interleaved=rotary_interleaved, + inplace=inplace, + ) diff --git a/vllm_fl/dispatch/backends/flaggems/impl/__init__.py b/vllm_fl/dispatch/backends/flaggems/impl/__init__.py new file mode 100644 index 00000000..96b1c704 --- /dev/null +++ b/vllm_fl/dispatch/backends/flaggems/impl/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +FlagGems operator implementations. +""" + +from .activation import silu_and_mul_flaggems +from .normalization import rmsnorm_flaggems +from .rotary import rotary_embedding_flaggems + +__all__ = [ + "silu_and_mul_flaggems", + "rmsnorm_flaggems", + "rotary_embedding_flaggems", +] diff --git a/vllm_fl/dispatch/backends/flaggems/impl/activation.py b/vllm_fl/dispatch/backends/flaggems/impl/activation.py new file mode 100644 index 00000000..a08e1933 --- /dev/null +++ b/vllm_fl/dispatch/backends/flaggems/impl/activation.py @@ -0,0 +1,26 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +FlagGems activation operator implementations. +""" + +from __future__ import annotations + +import torch + + +def silu_and_mul_flaggems(x: torch.Tensor) -> torch.Tensor: + """ + SiLU activation followed by element-wise multiplication using FlagGems. 
+ + Args: + x: Input tensor of shape [..., 2*d] + + Returns: + Output tensor of shape [..., d] + """ + from flag_gems.modules.activation import gems_silu_and_mul + + d = x.shape[-1] // 2 + x1, x2 = x[..., :d], x[..., d:] + return gems_silu_and_mul(x1, x2) diff --git a/vllm_fl/dispatch/backends/flaggems/impl/normalization.py b/vllm_fl/dispatch/backends/flaggems/impl/normalization.py new file mode 100644 index 00000000..afc32c14 --- /dev/null +++ b/vllm_fl/dispatch/backends/flaggems/impl/normalization.py @@ -0,0 +1,34 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +FlagGems normalization operator implementations. +""" + +from __future__ import annotations + +from typing import Optional, Union + +import torch + + +def rmsnorm_flaggems( + x: torch.Tensor, + residual: Optional[torch.Tensor], + weight: torch.Tensor, + epsilon: float, +) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """ + RMS normalization using FlagGems. + + Args: + x: Input tensor + residual: Optional residual tensor + weight: Normalization weight + epsilon: Small constant for numerical stability + + Returns: + Normalized tensor, or tuple of (normalized, residual) if residual is provided + """ + from flag_gems.modules.normalization import gems_rms_forward + + return gems_rms_forward(x, residual, weight, epsilon) diff --git a/vllm_fl/dispatch/backends/flaggems/impl/rotary.py b/vllm_fl/dispatch/backends/flaggems/impl/rotary.py new file mode 100644 index 00000000..574b5227 --- /dev/null +++ b/vllm_fl/dispatch/backends/flaggems/impl/rotary.py @@ -0,0 +1,46 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +FlagGems rotary embedding operator implementations. 
+""" + +from __future__ import annotations + +import torch + + +def rotary_embedding_flaggems( + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor, + rotary_interleaved: bool = False, + inplace: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embedding using FlagGems. + + Args: + query: Query tensor + key: Key tensor + cos: Cosine cache + sin: Sine cache + position_ids: Position indices + rotary_interleaved: Whether to use interleaved rotary + inplace: Whether to modify tensors in-place + + Returns: + Tuple of (embedded_query, embedded_key) + """ + from flag_gems.modules.rotary_embedding import gems_rope_forward + + return gems_rope_forward( + query, + key, + cos, + sin, + position_ids=position_ids, + rotary_interleaved=rotary_interleaved, + inplace=inplace, + ) diff --git a/vllm_fl/dispatch/backends/flaggems/register_ops.py b/vllm_fl/dispatch/backends/flaggems/register_ops.py new file mode 100644 index 00000000..2c524a05 --- /dev/null +++ b/vllm_fl/dispatch/backends/flaggems/register_ops.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +FlagGems backend operator registrations. + +This module registers all DEFAULT (FlagGems) implementations. +""" + +from __future__ import annotations + +import functools + +from ...types import OpImpl, BackendImplKind + + +def _bind_is_available(fn, is_available_fn): + """Wrap a function and bind _is_available attribute for OpImpl.is_available() check.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + wrapper._is_available = is_available_fn + return wrapper + + +def register_builtins(registry) -> None: + """ + Register all FlagGems (DEFAULT) operator implementations. 
+ + Args: + registry: Registry to register into + """ + from .flaggems import FlagGemsBackend + + backend = FlagGemsBackend() + is_avail = backend.is_available + + impls = [ + # Activation + OpImpl( + op_name="silu_and_mul", + impl_id="default.flaggems", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.silu_and_mul, is_avail), + vendor=None, + priority=150, + ), + # Normalization + OpImpl( + op_name="rmsnorm", + impl_id="default.flaggems", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.rmsnorm, is_avail), + vendor=None, + priority=150, + ), + # Rotary Embedding + OpImpl( + op_name="rotary_embedding", + impl_id="default.flaggems", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.rotary_embedding, is_avail), + vendor=None, + priority=150, + ), + ] + + registry.register_many(impls) diff --git a/vllm_fl/dispatch/backends/reference/__init__.py b/vllm_fl/dispatch/backends/reference/__init__.py new file mode 100644 index 00000000..cd9d0870 --- /dev/null +++ b/vllm_fl/dispatch/backends/reference/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Reference backend for vllm-plugin-FL dispatch. +""" + +from .reference import ReferenceBackend + +__all__ = ["ReferenceBackend"] diff --git a/vllm_fl/dispatch/backends/reference/impl/__init__.py b/vllm_fl/dispatch/backends/reference/impl/__init__.py new file mode 100644 index 00000000..3a23ba04 --- /dev/null +++ b/vllm_fl/dispatch/backends/reference/impl/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Reference operator implementations using PyTorch. 
+""" + +from .activation import silu_and_mul_torch +from .normalization import rmsnorm_torch +from .rotary import rotary_embedding_torch + +__all__ = [ + "silu_and_mul_torch", + "rmsnorm_torch", + "rotary_embedding_torch", +] diff --git a/vllm_fl/dispatch/backends/reference/impl/activation.py b/vllm_fl/dispatch/backends/reference/impl/activation.py new file mode 100644 index 00000000..e7d41c43 --- /dev/null +++ b/vllm_fl/dispatch/backends/reference/impl/activation.py @@ -0,0 +1,25 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Reference activation operator implementations using PyTorch. +""" + +from __future__ import annotations + +import torch +import torch.nn.functional as F + + +def silu_and_mul_torch(x: torch.Tensor) -> torch.Tensor: + """ + SiLU activation followed by element-wise multiplication using PyTorch. + + Args: + x: Input tensor of shape [..., 2*d] + + Returns: + Output tensor of shape [..., d] + """ + d = x.shape[-1] // 2 + x1, x2 = x[..., :d], x[..., d:] + return F.silu(x1) * x2 diff --git a/vllm_fl/dispatch/backends/reference/impl/normalization.py b/vllm_fl/dispatch/backends/reference/impl/normalization.py new file mode 100644 index 00000000..26c423e5 --- /dev/null +++ b/vllm_fl/dispatch/backends/reference/impl/normalization.py @@ -0,0 +1,42 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Reference normalization operator implementations using PyTorch. +""" + +from __future__ import annotations + +from typing import Optional, Union + +import torch + + +def rmsnorm_torch( + x: torch.Tensor, + residual: Optional[torch.Tensor], + weight: torch.Tensor, + epsilon: float, +) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """ + RMS normalization using PyTorch. 
+ + Args: + x: Input tensor + residual: Optional residual tensor + weight: Normalization weight + epsilon: Small constant for numerical stability + + Returns: + Normalized tensor, or tuple of (normalized, residual) if residual is provided + """ + if residual is not None: + x = x + residual + residual = x + + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + epsilon) + output = weight * x + + if residual is not None: + return output, residual + return output diff --git a/vllm_fl/dispatch/backends/reference/impl/rotary.py b/vllm_fl/dispatch/backends/reference/impl/rotary.py new file mode 100644 index 00000000..aeb5c6d2 --- /dev/null +++ b/vllm_fl/dispatch/backends/reference/impl/rotary.py @@ -0,0 +1,90 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Reference rotary embedding operator implementations using PyTorch. +""" + +from __future__ import annotations + +import torch + + +def rotary_embedding_torch( + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor, + rotary_interleaved: bool = False, + inplace: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embedding using PyTorch. 
+ + Args: + query: Query tensor [batch, num_heads, seq_len, head_dim] or [seq_len, num_heads, head_dim] + key: Key tensor [batch, num_heads, seq_len, head_dim] or [seq_len, num_heads, head_dim] + cos: Cosine cache [max_seq_len, rotary_dim] where rotary_dim = head_dim or head_dim // 2 + sin: Sine cache [max_seq_len, rotary_dim] where rotary_dim = head_dim or head_dim // 2 + position_ids: Position indices [batch, seq_len] or [seq_len] + rotary_interleaved: Whether to use interleaved rotary + inplace: Whether to modify tensors in-place (ignored in reference impl) + + Returns: + Tuple of (embedded_query, embedded_key) + """ + # Get cos/sin for the positions + # position_ids can be [batch, seq_len] or [seq_len] + if position_ids.dim() == 1: + # [seq_len] -> [seq_len, rotary_dim] + cos_selected = cos[position_ids] + sin_selected = sin[position_ids] + else: + # [batch, seq_len] -> [batch, seq_len, rotary_dim] + cos_selected = cos[position_ids] + sin_selected = sin[position_ids] + + # Expand dimensions to match query/key shape + # query/key: [batch, num_heads, seq_len, head_dim] or [seq_len, num_heads, head_dim] + if query.dim() == 4: + # [batch, num_heads, seq_len, head_dim] + # cos_selected: [batch, seq_len, rotary_dim] -> [batch, 1, seq_len, rotary_dim] + cos_selected = cos_selected.unsqueeze(1) + sin_selected = sin_selected.unsqueeze(1) + elif query.dim() == 3: + # [seq_len, num_heads, head_dim] + # cos_selected: [seq_len, rotary_dim] -> [seq_len, 1, rotary_dim] + cos_selected = cos_selected.unsqueeze(1) + sin_selected = sin_selected.unsqueeze(1) + + # Check if we need to repeat cos/sin to match head_dim + rotary_dim = cos_selected.shape[-1] + head_dim = query.shape[-1] + + if rotary_dim != head_dim: + # cos/sin only covers half of head_dim, need to repeat + # This handles the case where rotary is only applied to part of the dimensions + cos_selected = torch.cat([cos_selected, cos_selected], dim=-1) + sin_selected = torch.cat([sin_selected, sin_selected], dim=-1) + + 
def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + if rotary_interleaved: + # Interleaved rotary + def rotate_interleaved(x): + x1 = x[..., ::2] + x2 = x[..., 1::2] + return torch.stack((-x2, x1), dim=-1).flatten(-2) + + q_embed = (query * cos_selected) + (rotate_interleaved(query) * sin_selected) + k_embed = (key * cos_selected) + (rotate_interleaved(key) * sin_selected) + else: + # Standard rotary (neox style) + q_embed = (query * cos_selected) + (rotate_half(query) * sin_selected) + k_embed = (key * cos_selected) + (rotate_half(key) * sin_selected) + + return q_embed, k_embed diff --git a/vllm_fl/dispatch/backends/reference/reference.py b/vllm_fl/dispatch/backends/reference/reference.py new file mode 100644 index 00000000..e32455f2 --- /dev/null +++ b/vllm_fl/dispatch/backends/reference/reference.py @@ -0,0 +1,119 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Reference backend implementation using PyTorch. + +This backend provides reference operator implementations using native PyTorch +operations. These implementations are always available when PyTorch is installed +and serve as fallback implementations. +""" + +from __future__ import annotations + +from typing import Optional, Union + +import torch + +from ..base import Backend + + +class ReferenceBackend(Backend): + """ + Reference backend for operator implementations. + + This backend uses native PyTorch operations to provide reference + implementations that are always available as fallbacks. 
+ """ + + _available: Optional[bool] = None + + @property + def name(self) -> str: + return "reference" + + def is_available(self) -> bool: + """Check if PyTorch is available.""" + if ReferenceBackend._available is None: + try: + import torch + + ReferenceBackend._available = True + except ImportError: + ReferenceBackend._available = False + return ReferenceBackend._available + + # ==================== Operator Implementations ==================== + + def silu_and_mul(self, x: torch.Tensor) -> torch.Tensor: + """ + SiLU activation followed by element-wise multiplication. + + Args: + x: Input tensor of shape [..., 2*d] + + Returns: + Output tensor of shape [..., d] + """ + from .impl.activation import silu_and_mul_torch + + return silu_and_mul_torch(x) + + def rmsnorm( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor], + weight: torch.Tensor, + epsilon: float, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """ + RMS normalization. + + Args: + x: Input tensor + residual: Optional residual tensor + weight: Normalization weight + epsilon: Small constant for numerical stability + + Returns: + Normalized tensor, or tuple of (normalized, residual) if residual is provided + """ + from .impl.normalization import rmsnorm_torch + + return rmsnorm_torch(x, residual, weight, epsilon) + + def rotary_embedding( + self, + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor, + rotary_interleaved: bool = False, + inplace: bool = True, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embedding. 
+ + Args: + query: Query tensor + key: Key tensor + cos: Cosine cache + sin: Sine cache + position_ids: Position indices + rotary_interleaved: Whether to use interleaved rotary + inplace: Whether to modify tensors in-place (ignored in reference impl) + + Returns: + Tuple of (embedded_query, embedded_key) + """ + from .impl.rotary import rotary_embedding_torch + + return rotary_embedding_torch( + query, + key, + cos, + sin, + position_ids, + rotary_interleaved=rotary_interleaved, + inplace=inplace, + ) diff --git a/vllm_fl/dispatch/backends/reference/register_ops.py b/vllm_fl/dispatch/backends/reference/register_ops.py new file mode 100644 index 00000000..59a482f9 --- /dev/null +++ b/vllm_fl/dispatch/backends/reference/register_ops.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Reference backend operator registrations. + +This module registers all REFERENCE (PyTorch) implementations. +""" + +from __future__ import annotations + +import functools + +from ...types import OpImpl, BackendImplKind + + +def _bind_is_available(fn, is_available_fn): + """Wrap a function and bind _is_available attribute for OpImpl.is_available() check.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + wrapper._is_available = is_available_fn + return wrapper + + +def register_builtins(registry) -> None: + """ + Register all PyTorch (REFERENCE) operator implementations. 
+ + Args: + registry: Registry to register into + """ + from .reference import ReferenceBackend + + backend = ReferenceBackend() + is_avail = backend.is_available + + impls = [ + # Activation + OpImpl( + op_name="silu_and_mul", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.silu_and_mul, is_avail), + vendor=None, + priority=50, + ), + # Normalization + OpImpl( + op_name="rmsnorm", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.rmsnorm, is_avail), + vendor=None, + priority=50, + ), + # Rotary Embedding + OpImpl( + op_name="rotary_embedding", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.rotary_embedding, is_avail), + vendor=None, + priority=50, + ), + ] + + registry.register_many(impls) diff --git a/vllm_fl/dispatch/backends/vendor/__init__.py b/vllm_fl/dispatch/backends/vendor/__init__.py new file mode 100644 index 00000000..8169d3ec --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Vendor backends for vllm-plugin-FL dispatch. + +This package contains vendor-specific backend implementations. +""" diff --git a/vllm_fl/dispatch/builtin_ops.py b/vllm_fl/dispatch/builtin_ops.py new file mode 100644 index 00000000..69fe810e --- /dev/null +++ b/vllm_fl/dispatch/builtin_ops.py @@ -0,0 +1,58 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Built-in operator implementations registration. + +This module registers DEFAULT (FlagGems) and REFERENCE (PyTorch) implementations +for all supported operators by calling register_builtins from each backend. +""" + +from __future__ import annotations + +from .registry import OpRegistry +from .logger_manager import get_logger + +logger = get_logger() + + +def register_builtins(registry: OpRegistry) -> None: + """ + Register all built-in operator implementations. 
+ + This function registers: + - DEFAULT implementations (FlagGems) + - REFERENCE implementations (PyTorch) + - VENDOR implementations (if available) + + Args: + registry: Registry to register into + """ + # Register FlagGems (DEFAULT) implementations + try: + from .backends.flaggems.register_ops import register_builtins as register_flaggems + + register_flaggems(registry) + logger.debug("Registered FlagGems operators") + except Exception as e: + logger.warning(f"Failed to register FlagGems operators: {e}") + + # Register PyTorch (REFERENCE) implementations + try: + from .backends.reference.register_ops import register_builtins as register_reference + + register_reference(registry) + logger.debug("Registered Reference operators") + except Exception as e: + logger.warning(f"Failed to register Reference operators: {e}") + + # Register VENDOR implementations (if available) + # Add vendor backends here as they become available + # Example: + # try: + # from .backends.vendor.cuda.register_ops import register_builtins as register_cuda + # register_cuda(registry) + # logger.debug("Registered CUDA operators") + # except Exception as e: + # # CUDA may not be available, this is expected + # logger.debug(f"CUDA operators not available: {e}") + # pass diff --git a/vllm_fl/dispatch/discovery.py b/vllm_fl/dispatch/discovery.py new file mode 100644 index 00000000..1155f92d --- /dev/null +++ b/vllm_fl/dispatch/discovery.py @@ -0,0 +1,241 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Plugin discovery mechanism for vllm-plugin-FL dispatch. + +This module provides functionality to discover and load external plugins +that can register additional operator implementations. 
+""" + +from __future__ import annotations + +import importlib +import os +from typing import Any, Callable, List, Optional, Tuple + +from .logger_manager import get_logger + +# Entry point group name for plugin discovery +PLUGIN_GROUP = "vllm_fl.plugin" + +# Environment variable for specifying plugin modules +PLUGIN_MODULES_ENV = "VLLM_FL_PLUGIN_MODULES" + +logger = get_logger() + +# Track discovered plugins: (name, source, success) +_discovered_plugins: List[Tuple[str, str, bool]] = [] + + +def _get_entry_points(): + """ + Get entry points for the plugin group. + + Returns: + List of entry points for the vllm_fl.plugin group + """ + try: + from importlib.metadata import entry_points + except ImportError: + try: + from importlib_metadata import entry_points + except ImportError: + logger.debug( + "importlib.metadata not available, skipping entry points discovery" + ) + return [] + + try: + eps = entry_points() + + # Python 3.10+ style + if hasattr(eps, "select"): + return list(eps.select(group=PLUGIN_GROUP)) + + # Python 3.9 style (dict-like) + if isinstance(eps, dict): + return eps.get(PLUGIN_GROUP, []) + + # Fallback for older versions + if hasattr(eps, "get"): + return eps.get(PLUGIN_GROUP, []) + + return [] + + except Exception as e: + logger.warning(f"Error accessing entry points: {e}") + return [] + + +def _call_register_function( + obj: Any, + registry: Any, + source_name: str, +) -> bool: + """ + Call the register function on a plugin object. 
+ + Args: + obj: Plugin object (module or callable) + registry: OpRegistry instance to register into + source_name: Name of the plugin source for logging + + Returns: + True if registration was successful, False otherwise + """ + # If obj is directly callable (not a class), call it + if callable(obj) and not isinstance(obj, type): + try: + obj(registry) + logger.info(f"Registered plugin from {source_name} (direct callable)") + return True + except Exception as e: + logger.error(f"Error calling plugin {source_name}: {e}") + return False + + # Look for register function + register_fn = getattr(obj, "vllm_fl_register", None) or getattr( + obj, "register", None + ) + + if callable(register_fn): + try: + register_fn(registry) + logger.info(f"Registered plugin from {source_name}") + return True + except Exception as e: + logger.error(f"Error calling register function in {source_name}: {e}") + return False + + logger.debug(f"No register function found in {source_name}") + return False + + +def discover_from_entry_points(registry: Any) -> int: + """ + Discover and load plugins from entry points. 
+ + Args: + registry: OpRegistry instance to register into + + Returns: + Number of successfully loaded plugins + """ + loaded = 0 + entry_points_list = _get_entry_points() + + if not entry_points_list: + logger.debug(f"No entry points found for group: {PLUGIN_GROUP}") + return 0 + + logger.debug(f"Found {len(entry_points_list)} entry points") + + for ep in entry_points_list: + ep_name = getattr(ep, "name", str(ep)) + try: + logger.debug(f"Loading entry point: {ep_name}") + obj = ep.load() + + if _call_register_function(obj, registry, f"entry_point:{ep_name}"): + _discovered_plugins.append((ep_name, "entry_point", True)) + loaded += 1 + else: + _discovered_plugins.append((ep_name, "entry_point", False)) + + except Exception as e: + logger.error(f"Failed to load entry point {ep_name}: {e}") + _discovered_plugins.append((ep_name, "entry_point", False)) + + return loaded + + +def discover_from_env_modules(registry: Any) -> int: + """ + Discover and load plugins from environment variable. + + The VLLM_FL_PLUGIN_MODULES environment variable should contain + a comma-separated list of module names to import. 
+ + Args: + registry: OpRegistry instance to register into + + Returns: + Number of successfully loaded plugins + """ + modules_str = os.environ.get(PLUGIN_MODULES_ENV, "").strip() + + if not modules_str: + return 0 + + loaded = 0 + module_names = [m.strip() for m in modules_str.split(",") if m.strip()] + + logger.debug(f"Loading plugins from env var: {module_names}") + + for mod_name in module_names: + try: + logger.debug(f"Importing module: {mod_name}") + mod = importlib.import_module(mod_name) + + if _call_register_function(mod, registry, f"env_module:{mod_name}"): + _discovered_plugins.append((mod_name, "env_module", True)) + loaded += 1 + else: + _discovered_plugins.append((mod_name, "env_module", False)) + + except ImportError as e: + logger.error(f"Failed to import plugin module {mod_name}: {e}") + _discovered_plugins.append((mod_name, "env_module", False)) + except Exception as e: + logger.error(f"Error loading plugin module {mod_name}: {e}") + _discovered_plugins.append((mod_name, "env_module", False)) + + return loaded + + +def discover_plugins(registry: Any) -> int: + """ + Main plugin discovery function. + + Discovers and registers plugins from: + 1. Entry points (group: 'vllm_fl.plugin') + 2. Environment variable modules (VLLM_FL_PLUGIN_MODULES) + + Args: + registry: OpRegistry instance to register plugins to + + Returns: + Number of successfully loaded plugins + """ + if registry is None: + logger.warning("Registry is None, skipping plugin discovery") + return 0 + + logger.debug("Starting plugin discovery...") + + total = 0 + + # Discover from entry points + total += discover_from_entry_points(registry) + + # Discover from environment variable + total += discover_from_env_modules(registry) + + logger.debug(f"Plugin discovery complete. Loaded {total} plugins.") + + return total + + +def get_discovered_plugins() -> List[Tuple[str, str, bool]]: + """ + Get list of discovered plugins. 
+ + Returns: + List of tuples (name, source, success) + """ + return _discovered_plugins.copy() + + +def clear_discovered_plugins() -> None: + """Clear the discovered plugins list (for testing).""" + _discovered_plugins.clear() diff --git a/vllm_fl/dispatch/logger_manager.py b/vllm_fl/dispatch/logger_manager.py new file mode 100644 index 00000000..0147b6ba --- /dev/null +++ b/vllm_fl/dispatch/logger_manager.py @@ -0,0 +1,70 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Logger manager for vllm-plugin-FL dispatch. + +Provides centralized logging configuration for the dispatch module. +""" + +from __future__ import annotations + +import logging +import os +from typing import Optional + +# Default log level from environment variable +_DEFAULT_LOG_LEVEL = os.environ.get("VLLM_FL_LOG_LEVEL", "INFO").upper() + +# Module-level logger cache +_loggers: dict[str, logging.Logger] = {} + + +def get_logger(name: str = "vllm_fl.dispatch") -> logging.Logger: + """ + Get a logger instance for the dispatch module. + + Args: + name: Logger name, defaults to "vllm_fl.dispatch" + + Returns: + Configured logger instance + """ + if name in _loggers: + return _loggers[name] + + logger = logging.getLogger(name) + + # Only configure if no handlers exist + if not logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter( + "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + handler.setFormatter(formatter) + logger.addHandler(handler) + + # Set log level from environment + level = getattr(logging, _DEFAULT_LOG_LEVEL, logging.INFO) + logger.setLevel(level) + + _loggers[name] = logger + return logger + + +def set_log_level(level: str, name: Optional[str] = None) -> None: + """ + Set the log level for dispatch loggers. 
+ + Args: + level: Log level string (DEBUG, INFO, WARNING, ERROR, CRITICAL) + name: Optional logger name, if None sets for all cached loggers + """ + log_level = getattr(logging, level.upper(), logging.INFO) + + if name is not None: + if name in _loggers: + _loggers[name].setLevel(log_level) + else: + for logger in _loggers.values(): + logger.setLevel(log_level) diff --git a/vllm_fl/dispatch/manager.py b/vllm_fl/dispatch/manager.py new file mode 100644 index 00000000..1e858654 --- /dev/null +++ b/vllm_fl/dispatch/manager.py @@ -0,0 +1,507 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Core operator dispatch manager. +""" + +from __future__ import annotations + +import logging +import os +import threading +from dataclasses import dataclass +from typing import Callable, Dict, Optional, Tuple + +from .registry import OpRegistry +from .policy import SelectionPolicy, get_policy +from .types import OpImpl, BackendImplKind, match_token + + +logger = logging.getLogger(__name__) + +# Debug printing control +_DISPATCH_DEBUG = os.getenv("VLLM_FL_DISPATCH_DEBUG", "0") == "1" + + +@dataclass +class _OpManagerState: + """Internal state for OpManager.""" + init_pid: int = -1 + initialized: bool = False + policy_epoch: int = 0 + + +class OpManager: + """ + Main manager for operator dispatching and selection. 
+ + Responsibilities: + - Lazy initialization and plugin discovery + - Multi-process safety (PID detection + at_fork) + - Policy-based operator selection + - Dispatch caching with invalidation + """ + + def __init__(self, registry: Optional[OpRegistry] = None) -> None: + self._lock = threading.RLock() + self._registry = registry or OpRegistry() + self._state = _OpManagerState() + self._dispatch_cache: Dict[Tuple[str, str, int], Callable] = {} + self._called_ops: Dict[str, str] = {} # Map op_name -> last_used_impl_id + + # Register at_fork handler for multi-process safety + try: + os.register_at_fork(after_in_child=self._reset_after_fork) + except AttributeError: + # os.register_at_fork not available (Windows) + pass + + @property + def registry(self) -> OpRegistry: + """Get the underlying operator registry.""" + return self._registry + + def _reset_after_fork(self) -> None: + """Reset state after process fork.""" + with self._lock: + self._state.initialized = False + self._state.init_pid = -1 + self._state.policy_epoch += 1 + self._dispatch_cache.clear() + self._called_ops.clear() + logger.debug("OpManager reset after fork") + + def bump_policy_epoch(self) -> None: + """ + Increment policy epoch to invalidate dispatch cache. + + Call this when policy changes at runtime. + """ + with self._lock: + self._state.policy_epoch += 1 + self._dispatch_cache.clear() + logger.debug(f"Policy epoch bumped to {self._state.policy_epoch}") + + def ensure_initialized(self) -> None: + """ + Ensure the manager is initialized in the current process. + + Performs: + 1. PID check (multi-process safety) + 2. 
Register built-in operator implementations + """ + with self._lock: + pid = os.getpid() + + # Check if already initialized in this process + if self._state.initialized and self._state.init_pid == pid: + return + + logger.debug(f"Initializing OpManager in PID {pid}") + + # Mark as initialized + self._state.initialized = True + self._state.init_pid = pid + + # Register built-in operators + from . import builtin_ops + builtin_ops.register_builtins(self._registry) + + # Invalidate cache + self._state.policy_epoch += 1 + self._dispatch_cache.clear() + + # Print initialization summary + snap = self._registry.snapshot() + total_ops = len(snap.impls_by_op) + total_impls = sum(len(impls) for impls in snap.impls_by_op.values()) + + logger.info(f"OpManager initialized: {total_ops} ops with {total_impls} implementations") + + # Group implementations by kind for summary + vendor_count = sum(1 for impls in snap.impls_by_op.values() + for impl in impls if impl.kind == BackendImplKind.VENDOR) + reference_count = sum(1 for impls in snap.impls_by_op.values() + for impl in impls if impl.kind == BackendImplKind.REFERENCE) + default_count = sum(1 for impls in snap.impls_by_op.values() + for impl in impls if impl.kind == BackendImplKind.DEFAULT) + + logger.debug(f" Vendor: {vendor_count}, Default: {default_count}, Reference: {reference_count}") + + # Print detailed operator list if debug is enabled + if _DISPATCH_DEBUG: + self._print_registered_operators() + + def _print_registered_operators(self) -> None: + """Print detailed list of registered operators and their implementations.""" + snap = self._registry.snapshot() + + print("\n" + "="*80) + print("VLLM-FL Dispatch: Registered Operators") + print("="*80) + + # Sort operators by name for consistent output + sorted_ops = sorted(snap.impls_by_op.items()) + + for op_name, impls in sorted_ops: + print(f"\n[Operator: {op_name}]") + # Sort implementations by priority (highest first) + sorted_impls = sorted(impls, key=lambda x: (x.priority, 
x.impl_id), reverse=True) + + for impl in sorted_impls: + available = "✓" if impl.is_available() else "✗" + vendor_info = f", vendor={impl.vendor}" if impl.vendor else "" + print(f" {available} {impl.impl_id} (kind={impl.kind.value}, priority={impl.priority}{vendor_info})") + + print("\n" + "="*80 + "\n") + + def _matches_vendor_filters(self, impl: OpImpl, policy: SelectionPolicy) -> bool: + """Check if implementation matches policy vendor filters.""" + if impl.kind != BackendImplKind.VENDOR: + return True + + if impl.vendor is None: + return False + + # Check deny list + if impl.vendor in policy.deny_vendors: + return False + + # Check allow list (if specified) + if policy.allow_vendors is not None and impl.vendor not in policy.allow_vendors: + return False + + return True + + def _default_order(self, policy: SelectionPolicy) -> list[str]: + """Get default selection order based on policy.""" + return policy.get_default_order() + + def resolve(self, op_name: str) -> Callable: + """ + Resolve and return the best implementation for an operator. + + Selection process: + 1. Check dispatch cache + 2. Get all registered implementations + 3. Filter by policy (vendor allow/deny) + 4. Filter by availability (is_available()) + 5. Select best match using per-op order or default order + 6. 
Cache the result + + Args: + op_name: Name of the operator to resolve + + Returns: + Callable implementation function + + Raises: + RuntimeError: If no implementation found + """ + self.ensure_initialized() + + policy = get_policy() + policy_fp = policy.fingerprint() + epoch = self._state.policy_epoch + + # Check cache + cache_key = (op_name, policy_fp, epoch) + cached = self._dispatch_cache.get(cache_key) + if cached is not None: + return cached + + # Get all implementations for this operator + snap = self._registry.snapshot() + candidates = list(snap.impls_by_op.get(op_name, [])) + + # Filter by vendor policy + candidates = [c for c in candidates if self._matches_vendor_filters(c, policy)] + + # Filter by availability + available: list[OpImpl] = [] + for c in candidates: + try: + if c.is_available(): + available.append(c) + else: + logger.debug(f"Implementation {c.impl_id} not available for op={op_name}") + except Exception as e: + logger.warning(f"Error checking availability of {c.impl_id}: {e}") + continue + + candidates = available + + if not candidates: + raise RuntimeError( + f"No available implementation for op='{op_name}'. " + f"Registered: {[impl.impl_id for impl in snap.impls_by_op.get(op_name, [])]}" + ) + + # Get selection order (per-op or default) + order = policy.per_op_order_dict.get(op_name) or self._default_order(policy) + + # Select best implementation + chosen: Optional[OpImpl] = None + for token in order: + matches = [c for c in candidates if match_token(c, token)] + if not matches: + continue + + # Sort by priority (higher first), then by impl_id for stability + matches.sort(key=lambda x: (x.priority, x.impl_id), reverse=True) + chosen = matches[0] + break + + if chosen is None: + if policy.strict: + raise RuntimeError( + f"No implementation available for op='{op_name}' under strict policy. " + f"Candidates: {[c.impl_id for c in candidates]}" + ) + raise RuntimeError( + f"No implementation selected for op='{op_name}'. 
" + f"Candidates: {[c.impl_id for c in candidates]}, Order: {order}" + ) + + # Cache the result + self._dispatch_cache[cache_key] = chosen.fn + + # Print selected backend if debug is enabled + if _DISPATCH_DEBUG: + vendor_info = f", vendor={chosen.vendor}" if chosen.vendor else "" + print(f"[DISPATCH] Op '{op_name}' -> '{chosen.impl_id}' (kind={chosen.kind.value}{vendor_info})") + + return chosen.fn + + def resolve_candidates(self, op_name: str) -> list[OpImpl]: + """ + Resolve and return all available implementations for an operator, + sorted by priority (highest first). + + This is similar to resolve() but returns all viable candidates + instead of just the best one. Useful for fallback mechanisms. + + Args: + op_name: Name of the operator to resolve + + Returns: + List of OpImpl sorted by priority (highest first) + + Raises: + RuntimeError: If no implementation found + """ + self.ensure_initialized() + + policy = get_policy() + + # Get all implementations for this operator + snap = self._registry.snapshot() + candidates = list(snap.impls_by_op.get(op_name, [])) + + # Filter by vendor policy + candidates = [c for c in candidates if self._matches_vendor_filters(c, policy)] + + # Filter by availability + available: list[OpImpl] = [] + for c in candidates: + try: + if c.is_available(): + available.append(c) + else: + logger.debug(f"Implementation {c.impl_id} not available for op={op_name}") + except Exception as e: + logger.warning(f"Error checking availability of {c.impl_id}: {e}") + continue + + candidates = available + + if not candidates: + raise RuntimeError( + f"No available implementation for op='{op_name}'. 
" + f"Registered: {[impl.impl_id for impl in snap.impls_by_op.get(op_name, [])]}" + ) + + # Get selection order (per-op or default) + order = policy.per_op_order_dict.get(op_name) or self._default_order(policy) + + # Sort candidates by order tokens, then by priority + sorted_candidates: list[OpImpl] = [] + for token in order: + matches = [c for c in candidates if match_token(c, token)] + if matches: + # Sort by priority (higher first), then by impl_id for stability + matches.sort(key=lambda x: (x.priority, x.impl_id), reverse=True) + sorted_candidates.extend(matches) + + # Remove duplicates while preserving order + seen = set() + unique_candidates = [] + for c in sorted_candidates: + if c.impl_id not in seen: + seen.add(c.impl_id) + unique_candidates.append(c) + + if not unique_candidates: + raise RuntimeError( + f"No implementation selected for op='{op_name}'. " + f"Candidates: {[c.impl_id for c in candidates]}, Order: {order}" + ) + + return unique_candidates + + def call(self, op_name: str, *args, **kwargs): + """ + Resolve and call an operator implementation with optional fallback support. + + When VLLM_FL_STRICT=1, this method will try alternative implementations + if the primary one fails. Otherwise, it behaves like the original implementation. + + Logs on first call or when the implementation changes (e.g., backend switch). 
+ + Args: + op_name: Name of the operator + *args, **kwargs: Arguments passed to the implementation + + Returns: + Result from the implementation + + Raises: + RuntimeError: If all implementations fail (when fallback enabled) or + if the primary implementation fails (when fallback disabled) + """ + enable_fallback = os.getenv("VLLM_FL_STRICT", "1") != "0" + + if not enable_fallback: + # Original behavior: use cached resolve() and fast-fail + fn = self.resolve(op_name) + + # Get current impl_id to check if it changed + impl_id = self.get_selected_impl_id(op_name) + last_impl_id = self._called_ops.get(op_name) + + # Log if first call or implementation changed + if last_impl_id != impl_id: + with self._lock: + # Double-check after acquiring lock + if self._called_ops.get(op_name) != impl_id: + snap = self._registry.snapshot() + for impl in snap.impls_by_op.get(op_name, []): + if impl.impl_id == impl_id: + if last_impl_id is None: + logger.info( + f"Op '{op_name}' using '{impl_id}' " + f"(kind={impl.kind.value}, vendor={impl.vendor})" + ) + else: + logger.info( + f"Op '{op_name}' switched from '{last_impl_id}' to '{impl_id}' " + f"(kind={impl.kind.value}, vendor={impl.vendor})" + ) + break + self._called_ops[op_name] = impl_id + + return fn(*args, **kwargs) + + # Fallback mode: try candidates in priority order + candidates = self.resolve_candidates(op_name) + last_error = None + + for idx, impl in enumerate(candidates): + try: + # Log primary implementation or fallback attempts + if idx == 0: + # Primary implementation + last_impl_id = self._called_ops.get(op_name) + if last_impl_id != impl.impl_id: + with self._lock: + if self._called_ops.get(op_name) != impl.impl_id: + if last_impl_id is None: + logger.info( + f"Op '{op_name}' using '{impl.impl_id}' " + f"(kind={impl.kind.value}, vendor={impl.vendor})" + ) + else: + logger.info( + f"Op '{op_name}' switched from '{last_impl_id}' to '{impl.impl_id}' " + f"(kind={impl.kind.value}, vendor={impl.vendor})" + ) + 
self._called_ops[op_name] = impl.impl_id + else: + # Always log fallback attempts (these are important runtime events) + logger.info( + f"Op '{op_name}' fallback to '{impl.impl_id}' " + f"(kind={impl.kind.value}, vendor={impl.vendor})" + ) + + result = impl.fn(*args, **kwargs) + + # Update tracked impl_id on success (for fallback case) + if idx > 0: + with self._lock: + self._called_ops[op_name] = impl.impl_id + + return result + + except Exception as e: + last_error = e + if idx < len(candidates) - 1: + # Not the last candidate, log warning and try next + logger.warning( + f"Implementation '{impl.impl_id}' failed for op '{op_name}': {e}" + ) + else: + # Last candidate failed, log error + logger.error( + f"Last implementation '{impl.impl_id}' failed for op '{op_name}': {e}" + ) + + # All implementations failed + raise RuntimeError( + f"All {len(candidates)} implementation(s) failed for op='{op_name}'. " + f"Last error: {last_error}" + ) from last_error + + def get_selected_impl_id(self, op_name: str) -> str: + """ + Get the impl_id of the currently selected implementation. 
+ + Args: + op_name: Name of the operator + + Returns: + Implementation ID string + """ + fn = self.resolve(op_name) + + # Try to find the impl by function identity + snap = self._registry.snapshot() + for impl in snap.impls_by_op.get(op_name, []): + if impl.fn is fn: + return impl.impl_id + + return "unknown" + + +# Global default instance +_default_manager: Optional[OpManager] = None +_manager_lock = threading.RLock() + + +def get_default_manager() -> OpManager: + """Get or create the global default OpManager instance.""" + global _default_manager + + if _default_manager is None: + with _manager_lock: + if _default_manager is None: + _default_manager = OpManager() + + return _default_manager + + +def reset_default_manager() -> None: + """Reset the global default OpManager (useful for testing).""" + global _default_manager + + with _manager_lock: + _default_manager = None diff --git a/vllm_fl/dispatch/ops.py b/vllm_fl/dispatch/ops.py new file mode 100644 index 00000000..b2b87389 --- /dev/null +++ b/vllm_fl/dispatch/ops.py @@ -0,0 +1,127 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Backend base interface definitions for vllm-plugin-FL dispatch. + +This module defines the abstract base class that all backends must implement. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Optional, Union + +import torch + + +class VLLMFLBackendBase(ABC): + """ + Abstract base class for vllm-plugin-FL operator backends. + + Each backend provides implementations for a set of operators. + Backends should implement is_available() to indicate whether + the backend can be used in the current environment. + + All operator methods should be implemented by concrete backend classes. + Methods that are not supported should raise NotImplementedError. + """ + + @abstractmethod + def is_available(self) -> bool: + """ + Check if this backend is available in the current environment. 
+ + Returns: + True if the backend can be used, False otherwise. + """ + pass + + @property + @abstractmethod + def name(self) -> str: + """ + Get the name of this backend. + + Returns: + Backend name string. + """ + pass + + @property + def vendor(self) -> Optional[str]: + """ + Get the vendor name for this backend (if applicable). + + Returns: + Vendor name string, or None for non-vendor backends. + """ + return None + + # ==================== Activation Operators ==================== + + @abstractmethod + def silu_and_mul(self, x: torch.Tensor) -> torch.Tensor: + """ + SiLU activation followed by element-wise multiplication. + + Args: + x: Input tensor of shape [..., 2*d] + + Returns: + Output tensor of shape [..., d] + """ + pass + + # ==================== Normalization Operators ==================== + + @abstractmethod + def rmsnorm( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor], + weight: torch.Tensor, + epsilon: float, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """ + RMS normalization. + + Args: + x: Input tensor + residual: Optional residual tensor + weight: Normalization weight + epsilon: Small constant for numerical stability + + Returns: + Normalized tensor, or tuple of (normalized, residual) if residual is provided + """ + pass + + # ==================== Position Embedding Operators ==================== + + @abstractmethod + def rotary_embedding( + self, + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor, + rotary_interleaved: bool = False, + inplace: bool = True, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embedding. 
+ + Args: + query: Query tensor + key: Key tensor + cos: Cosine cache + sin: Sine cache + position_ids: Position indices + rotary_interleaved: Whether to use interleaved rotary + inplace: Whether to modify tensors in-place + + Returns: + Tuple of (embedded_query, embedded_key) + """ + pass diff --git a/vllm_fl/dispatch/policy.py b/vllm_fl/dispatch/policy.py new file mode 100644 index 00000000..b93b8bd7 --- /dev/null +++ b/vllm_fl/dispatch/policy.py @@ -0,0 +1,405 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Selection policy management for operator dispatch. +""" + +from __future__ import annotations + +import contextvars +import os +import threading +from dataclasses import dataclass, field +from typing import Dict, FrozenSet, List, Optional, Set, Tuple + + +# Valid preference values for VLLM_FL_PREFER +PREFER_DEFAULT = "flaggems" +PREFER_VENDOR = "vendor" +PREFER_REFERENCE = "reference" + +VALID_PREFER_VALUES = frozenset({PREFER_DEFAULT, PREFER_VENDOR, PREFER_REFERENCE}) + + +@dataclass(frozen=True) +class SelectionPolicy: + """ + Policy for selecting operator implementations. + + Attributes: + prefer: Which implementation kind to prefer. One of: + - "flaggems": Prefer DEFAULT (FlagGems) implementations + - "vendor": Prefer VENDOR (CUDA) implementations + - "reference": Prefer REFERENCE (PyTorch) implementations + strict: If True, raise error when primary implementation fails + per_op_order: Per-operator custom selection order + deny_vendors: Set of vendor names to deny + allow_vendors: Set of vendor names to allow (whitelist) + """ + prefer: str = PREFER_DEFAULT + strict: bool = False + per_op_order: Tuple[Tuple[str, Tuple[str, ...]], ...] = field(default_factory=tuple) + deny_vendors: FrozenSet[str] = field(default_factory=frozenset) + allow_vendors: Optional[FrozenSet[str]] = None + + def __post_init__(self): + if self.prefer not in VALID_PREFER_VALUES: + raise ValueError( + f"Invalid prefer value: '{self.prefer}'. 
" + f"Must be one of: {', '.join(sorted(VALID_PREFER_VALUES))}" + ) + + @classmethod + def from_dict( + cls, + prefer: str = PREFER_DEFAULT, + strict: bool = False, + per_op_order: Optional[Dict[str, List[str]]] = None, + deny_vendors: Optional[Set[str]] = None, + allow_vendors: Optional[Set[str]] = None, + ) -> "SelectionPolicy": + """Create a SelectionPolicy from dictionary-like arguments.""" + per_op_tuple = tuple() + if per_op_order: + per_op_tuple = tuple( + (k, tuple(v)) for k, v in sorted(per_op_order.items()) + ) + + return cls( + prefer=prefer.lower(), + strict=strict, + per_op_order=per_op_tuple, + deny_vendors=frozenset(deny_vendors) if deny_vendors else frozenset(), + allow_vendors=frozenset(allow_vendors) if allow_vendors else None, + ) + + @property + def per_op_order_dict(self) -> Dict[str, List[str]]: + """Get per_op_order as a mutable dict for easier access.""" + return {k: list(v) for k, v in self.per_op_order} + + def get_per_op_order(self, op_name: str) -> Optional[List[str]]: + """Get order for a specific operator.""" + for name, order in self.per_op_order: + if name == op_name: + return list(order) + return None + + def get_default_order(self) -> List[str]: + """Get the default selection order based on preference setting.""" + if self.prefer == PREFER_REFERENCE: + return ["reference", "flaggems", "vendor"] + elif self.prefer == PREFER_VENDOR: + return ["vendor", "flaggems", "reference"] + else: # PREFER_DEFAULT + return ["flaggems", "vendor", "reference"] + + def is_vendor_allowed(self, vendor_name: str) -> bool: + """Check if a vendor is allowed by this policy.""" + if vendor_name in self.deny_vendors: + return False + if self.allow_vendors is not None and vendor_name not in self.allow_vendors: + return False + return True + + def fingerprint(self) -> str: + """Generate a unique fingerprint for this policy (used for caching).""" + parts = [ + f"prefer={self.prefer}", + f"st={int(self.strict)}", + ] + + if self.allow_vendors: + 
parts.append(f"allow={','.join(sorted(self.allow_vendors))}") + + if self.deny_vendors: + parts.append(f"deny={','.join(sorted(self.deny_vendors))}") + + if self.per_op_order: + per_op_str = ";".join( + f"{k}={'|'.join(v)}" for k, v in self.per_op_order + ) + parts.append(f"per={per_op_str}") + + return ";".join(parts) + + def __hash__(self) -> int: + return hash(( + self.prefer, + self.strict, + self.per_op_order, + self.deny_vendors, + self.allow_vendors, + )) + + +class PolicyManager: + """ + Singleton manager for selection policies. + + Supports: + - Global policy (from environment or set programmatically) + - Context-local policy (using context managers) + - Policy epoch tracking for cache invalidation + """ + _instance = None + _lock = threading.Lock() + + def __init__(self): + if hasattr(self, '_policy_epoch'): + return + + self._policy_epoch = 0 + self._policy_epoch_lock = threading.Lock() + self._global_policy = None + self._global_policy_lock = threading.Lock() + + self._policy_var = contextvars.ContextVar( + "vllm_fl_selection_policy", + default=None, + ) + + @classmethod + def get_instance(cls): + """Get the singleton instance.""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.__init__() + return cls._instance + + def get_policy_epoch(self) -> int: + """Get the current policy epoch.""" + return self._policy_epoch + + def bump_policy_epoch(self) -> int: + """Bump the policy epoch and return the new value.""" + with self._policy_epoch_lock: + self._policy_epoch += 1 + return self._policy_epoch + + def get_policy(self) -> SelectionPolicy: + """Get the current effective policy (context or global).""" + ctx_policy = self._policy_var.get() + if ctx_policy is not None: + return ctx_policy + + if self._global_policy is None: + with self._global_policy_lock: + if self._global_policy is None: + self._global_policy = self._policy_from_env() + return self._global_policy + + def 
set_global_policy(self, policy: SelectionPolicy) -> SelectionPolicy: + """Set the global policy and return the old policy.""" + with self._global_policy_lock: + old_policy = self._global_policy + self._global_policy = policy + self.bump_policy_epoch() + return old_policy if old_policy else self._policy_from_env() + + def reset_global_policy(self) -> None: + """Reset the global policy to environment defaults.""" + with self._global_policy_lock: + self._global_policy = None + self.bump_policy_epoch() + + def create_policy_context(self, policy: SelectionPolicy): + """Create a context manager for temporary policy override.""" + return _PolicyContext(self, policy) + + def _get_policy_var(self): + return self._policy_var + + @staticmethod + def _parse_csv_set(value: str) -> Set[str]: + """Parse a comma-separated string into a set.""" + if not value: + return set() + return {x.strip() for x in value.split(",") if x.strip()} + + @staticmethod + def _parse_per_op(value: str) -> Dict[str, List[str]]: + """Parse per-op order string (format: op1=a|b|c;op2=x|y).""" + if not value: + return {} + + result: Dict[str, List[str]] = {} + parts = [p.strip() for p in value.split(";") if p.strip()] + + for part in parts: + if "=" not in part: + continue + op_name, order_str = part.split("=", 1) + op_name = op_name.strip() + order = [x.strip() for x in order_str.split("|") if x.strip()] + if op_name and order: + result[op_name] = order + + return result + + def _policy_from_env(self) -> SelectionPolicy: + """ + Create a SelectionPolicy from environment variables. 
+ + Environment variables: + - VLLM_FL_PREFER: Preference (flaggems, vendor, reference) + - VLLM_FL_STRICT: Enable strict mode (1 or 0) + - VLLM_FL_DENY_VENDORS: Comma-separated list of denied vendors + - VLLM_FL_ALLOW_VENDORS: Comma-separated list of allowed vendors + - VLLM_FL_PER_OP: Per-op order (format: op1=a|b|c;op2=x|y) + """ + prefer_str = os.environ.get("VLLM_FL_PREFER", "").strip().lower() + if prefer_str and prefer_str in VALID_PREFER_VALUES: + pass + else: + prefer_str = PREFER_DEFAULT + + strict = os.environ.get("VLLM_FL_STRICT", "0").strip() == "1" + + deny_str = os.environ.get("VLLM_FL_DENY_VENDORS", "").strip() + deny_vendors = self._parse_csv_set(deny_str) if deny_str else None + + allow_str = os.environ.get("VLLM_FL_ALLOW_VENDORS", "").strip() + allow_vendors = self._parse_csv_set(allow_str) if allow_str else None + + per_op_str = os.environ.get("VLLM_FL_PER_OP", "").strip() + per_op_order = self._parse_per_op(per_op_str) if per_op_str else None + + return SelectionPolicy.from_dict( + prefer=prefer_str, + strict=strict, + per_op_order=per_op_order, + deny_vendors=deny_vendors, + allow_vendors=allow_vendors, + ) + + +class _PolicyContext: + """Context manager for temporary policy override.""" + + def __init__(self, manager: PolicyManager, policy: SelectionPolicy): + self._manager = manager + self._policy = policy + self._token: Optional[contextvars.Token] = None + + def __enter__(self) -> "_PolicyContext": + policy_var = self._manager._get_policy_var() + self._token = policy_var.set(self._policy) + self._manager.bump_policy_epoch() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + if self._token is not None: + policy_var = self._manager._get_policy_var() + policy_var.reset(self._token) + self._manager.bump_policy_epoch() + + +# Convenience functions for easier access +def get_policy_epoch() -> int: + """Get the current policy epoch.""" + return PolicyManager.get_instance().get_policy_epoch() + + +def bump_policy_epoch() -> 
int: + """Bump the policy epoch and return the new value.""" + return PolicyManager.get_instance().bump_policy_epoch() + + +def get_policy() -> SelectionPolicy: + """Get the current effective policy (context or global).""" + return PolicyManager.get_instance().get_policy() + + +def set_global_policy(policy: SelectionPolicy) -> SelectionPolicy: + """Set the global policy and return the old policy.""" + return PolicyManager.get_instance().set_global_policy(policy) + + +def reset_global_policy() -> None: + """Reset the global policy to environment defaults.""" + PolicyManager.get_instance().reset_global_policy() + + +def policy_from_env() -> SelectionPolicy: + """Create a SelectionPolicy from environment variables.""" + return PolicyManager.get_instance()._policy_from_env() + + +def policy_context(policy: SelectionPolicy) -> _PolicyContext: + """ + Create a context manager to temporarily override the policy. + + Example: + >>> with policy_context(my_policy): + ... # Use my_policy in this context + ... result = manager.resolve("op_name") + """ + return _PolicyContext(PolicyManager.get_instance(), policy) + + +# Convenience context managers +def with_strict_mode() -> _PolicyContext: + """Context manager to enable strict mode.""" + current = get_policy() + strict_policy = SelectionPolicy.from_dict( + prefer=current.prefer, + strict=True, + per_op_order={k: list(v) for k, v in current.per_op_order}, + deny_vendors=set(current.deny_vendors), + allow_vendors=set(current.allow_vendors) if current.allow_vendors else None, + ) + return policy_context(strict_policy) + + +def with_preference(prefer: str) -> _PolicyContext: + """ + Context manager to set implementation preference. + + Args: + prefer: One of "flaggems", "vendor", or "reference" + + Example: + >>> with with_preference("vendor"): + ... # Prefer vendor implementations in this context + ... 
result = manager.resolve("op_name") + """ + current = get_policy() + policy = SelectionPolicy.from_dict( + prefer=prefer, + strict=current.strict, + per_op_order={k: list(v) for k, v in current.per_op_order}, + deny_vendors=set(current.deny_vendors), + allow_vendors=set(current.allow_vendors) if current.allow_vendors else None, + ) + return policy_context(policy) + + +def with_allowed_vendors(*vendors: str) -> _PolicyContext: + """Context manager to set allowed vendors whitelist.""" + current = get_policy() + policy = SelectionPolicy.from_dict( + prefer=current.prefer, + strict=current.strict, + per_op_order={k: list(v) for k, v in current.per_op_order}, + deny_vendors=set(current.deny_vendors), + allow_vendors=set(vendors), + ) + return policy_context(policy) + + +def with_denied_vendors(*vendors: str) -> _PolicyContext: + """Context manager to add denied vendors to blacklist.""" + current = get_policy() + denied = set(current.deny_vendors) + denied.update(vendors) + policy = SelectionPolicy.from_dict( + prefer=current.prefer, + strict=current.strict, + per_op_order={k: list(v) for k, v in current.per_op_order}, + deny_vendors=denied, + allow_vendors=set(current.allow_vendors) if current.allow_vendors else None, + ) + return policy_context(policy) diff --git a/vllm_fl/dispatch/registry.py b/vllm_fl/dispatch/registry.py new file mode 100644 index 00000000..8d3ee71a --- /dev/null +++ b/vllm_fl/dispatch/registry.py @@ -0,0 +1,120 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Thread-safe registry for operator implementations. +""" + +from __future__ import annotations + +import threading +from dataclasses import dataclass +from typing import Dict, List, Sequence, Optional + +from .types import OpImpl + + +@dataclass +class OpRegistrySnapshot: + """Immutable snapshot of operator registry state.""" + impls_by_op: Dict[str, List[OpImpl]] + + +class OpRegistry: + """ + Thread-safe registry for operator implementations. 
+ + This registry stores operator implementations indexed by op_name and impl_id. + Each operator can have multiple implementations from different backends/vendors. + """ + + def __init__(self) -> None: + self._lock = threading.RLock() + # Structure: {op_name: {impl_id: OpImpl}} + self._impls_by_op: Dict[str, Dict[str, OpImpl]] = {} + + def register_impl(self, impl: OpImpl) -> None: + """ + Register a single operator implementation. + + Args: + impl: OpImpl instance to register + + Raises: + ValueError: If impl_id is already registered for this op_name + """ + with self._lock: + by_id = self._impls_by_op.setdefault(impl.op_name, {}) + if impl.impl_id in by_id: + raise ValueError( + f"Duplicate impl_id '{impl.impl_id}' for op='{impl.op_name}'. " + f"Existing: {by_id[impl.impl_id]}, New: {impl}" + ) + by_id[impl.impl_id] = impl + + def register_many(self, impls: Sequence[OpImpl]) -> None: + """ + Register multiple operator implementations. + + Args: + impls: Sequence of OpImpl instances to register + """ + for impl in impls: + self.register_impl(impl) + + def snapshot(self) -> OpRegistrySnapshot: + """ + Create an immutable snapshot of current registry state. + + Returns: + OpRegistrySnapshot with all registered implementations + """ + with self._lock: + impls_by_op = { + op: list(by_id.values()) + for op, by_id in self._impls_by_op.items() + } + return OpRegistrySnapshot(impls_by_op=impls_by_op) + + def get_implementations(self, op_name: str) -> List[OpImpl]: + """ + Get all implementations for a specific operator. + + Args: + op_name: Name of the operator + + Returns: + List of OpImpl for the operator (empty if not found) + """ + with self._lock: + by_id = self._impls_by_op.get(op_name, {}) + return list(by_id.values()) + + def get_implementation(self, op_name: str, impl_id: str) -> Optional[OpImpl]: + """ + Get a specific implementation by op_name and impl_id. 
+ + Args: + op_name: Name of the operator + impl_id: Implementation ID + + Returns: + OpImpl if found, None otherwise + """ + with self._lock: + by_id = self._impls_by_op.get(op_name, {}) + return by_id.get(impl_id) + + def list_operators(self) -> List[str]: + """ + List all registered operator names. + + Returns: + List of operator names + """ + with self._lock: + return list(self._impls_by_op.keys()) + + def clear(self) -> None: + """Clear all registered implementations.""" + with self._lock: + self._impls_by_op.clear() diff --git a/vllm_fl/dispatch/types.py b/vllm_fl/dispatch/types.py new file mode 100644 index 00000000..1afe792f --- /dev/null +++ b/vllm_fl/dispatch/types.py @@ -0,0 +1,112 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Core type definitions for the dispatch mechanism. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, Optional, Set + + +class BackendImplKind(str, Enum): + """ + Kind of backend implementation. + + - DEFAULT: Default implementation (FlagGems) + - REFERENCE: Reference implementation (PyTorch native) + - VENDOR: Vendor-specific implementation (CUDA, etc.) + """ + DEFAULT = "flaggems" + REFERENCE = "reference" + VENDOR = "vendor" + + def __str__(self) -> str: + return self.value + + +@dataclass(frozen=True) +class OpImpl: + """ + Operator implementation descriptor. 
+ + Attributes: + op_name: Name of the operator (e.g., "silu_and_mul", "rmsnorm") + impl_id: Unique identifier for this implementation (e.g., "default.flaggems") + kind: Type of implementation (DEFAULT, REFERENCE, VENDOR) + fn: The actual implementation function + vendor: Vendor name (required if kind is VENDOR) + priority: Priority for selection (higher = preferred) + supported_dtypes: Set of supported data types (optional) + min_arch: Minimum architecture requirement (optional) + """ + op_name: str + impl_id: str + kind: BackendImplKind + fn: Callable[..., Any] + vendor: Optional[str] = None + priority: int = 0 + supported_dtypes: Optional[Set[str]] = None + min_arch: Optional[str] = None + + def __post_init__(self): + if self.kind == BackendImplKind.VENDOR and not self.vendor: + raise ValueError( + f"OpImpl with kind=VENDOR must specify vendor name: {self.impl_id}" + ) + + def is_available(self) -> bool: + """ + Check if this implementation is available. + + Looks for a _is_available attribute on the function. + """ + avail_fn = getattr(self.fn, "_is_available", None) + if callable(avail_fn): + try: + return bool(avail_fn()) + except Exception: + return False + return True + + +# Token patterns for matching implementations +TOKEN_PATTERNS = { + "flaggems": lambda impl: impl.kind == BackendImplKind.DEFAULT, + "reference": lambda impl: impl.kind == BackendImplKind.REFERENCE, + "vendor": lambda impl: impl.kind == BackendImplKind.VENDOR, +} + + +def match_token(impl: OpImpl, token: str) -> bool: + """ + Check if an implementation matches a selection token. 
+ + Supported token formats: + - "flaggems": Match DEFAULT implementations + - "reference": Match REFERENCE implementations + - "vendor": Match any VENDOR implementation + - "vendor:CUDA": Match VENDOR with specific vendor name + - "impl:default.flaggems": Match specific impl_id + + Args: + impl: Implementation to check + token: Selection token + + Returns: + True if implementation matches the token + """ + if token in TOKEN_PATTERNS: + return TOKEN_PATTERNS[token](impl) + + if token.startswith("vendor:"): + vendor_name = token.split(":", 1)[1] + return impl.kind == BackendImplKind.VENDOR and impl.vendor == vendor_name + + if token.startswith("impl:"): + impl_id = token.split(":", 1)[1] + return impl.impl_id == impl_id + + return False diff --git a/vllm_fl/ops/activation.py b/vllm_fl/ops/activation.py index c1094d97..03378048 100644 --- a/vllm_fl/ops/activation.py +++ b/vllm_fl/ops/activation.py @@ -2,16 +2,15 @@ import torch from vllm.model_executor.layers.activation import SiluAndMul -from flag_gems.modules.activation import gems_silu_and_mul +from vllm_fl.dispatch import call_op + class SiluAndMulFL(SiluAndMul): def __init__(self): super().__init__() def forward_oot(self, x: torch.Tensor) -> torch.Tensor: - d = x.shape[-1] // 2 - x1, x2 = x[..., :d], x[..., d:] - return gems_silu_and_mul(x1, x2) - + return call_op("silu_and_mul", x) + __all__ = ["SiluAndMulFL"] diff --git a/vllm_fl/ops/layernorm.py b/vllm_fl/ops/layernorm.py index 6ae1dd3b..53820d32 100644 --- a/vllm_fl/ops/layernorm.py +++ b/vllm_fl/ops/layernorm.py @@ -3,10 +3,12 @@ from typing import Optional, Union import torch from vllm.model_executor.layers.layernorm import RMSNorm -from flag_gems.modules.normalization import gems_rms_forward +from vllm_fl.dispatch import call_op + class RMSNormFL(RMSNorm): - def __init__(self, + def __init__( + self, hidden_size: int, eps: float = 1e-6, var_hidden_size: Optional[int] = None, @@ -20,7 +22,8 @@ def forward_oot( x: torch.Tensor, residual: 
Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - return gems_rms_forward(x, residual, self.weight, self.variance_epsilon) - + return call_op("rmsnorm", x, residual, self.weight, self.variance_epsilon) + + __all__ = ["RMSNormFL"] diff --git a/vllm_fl/ops/rotary_embedding.py b/vllm_fl/ops/rotary_embedding.py index bba8ffc9..5fa19cd6 100644 --- a/vllm_fl/ops/rotary_embedding.py +++ b/vllm_fl/ops/rotary_embedding.py @@ -1,9 +1,10 @@ # Copyright (c) 2025 BAAI. All rights reserved. -from typing import Optional, Union +from typing import Optional import torch from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding -from flag_gems.modules.rotary_embedding import gems_rope_forward +from vllm_fl.dispatch import call_op + class RotaryEmbeddingFL(RotaryEmbedding): def __init__( @@ -15,9 +16,11 @@ def __init__( is_neox_style: bool, dtype: torch.dtype, ) -> None: - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) - + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype + ) + def forward_oot( self, positions: torch.Tensor, @@ -36,19 +39,20 @@ def forward_oot( query_rot = query[..., : self.rotary_dim] key_rot = key[..., : self.rotary_dim] if self.rotary_dim < self.head_size: - query_pass = query[..., self.rotary_dim :] - key_pass = key[..., self.rotary_dim :] + query_pass = query[..., self.rotary_dim:] + key_pass = key[..., self.rotary_dim:] cos, sin = self.cos_sin_cache.chunk(2, dim=-1) - q_embed, k_embed = gems_rope_forward( + q_embed, k_embed = call_op( + "rotary_embedding", query_rot, key_rot, cos, sin, - position_ids=positions, - rotary_interleaved=not self.is_neox_style, - inplace=True, # set inplace to True for vLLM compatibility + positions, + not self.is_neox_style, # rotary_interleaved + True, # inplace ) if self.rotary_dim < self.head_size: @@ -59,5 +63,6 @@ def forward_oot( key = k_embed.reshape(key_shape) return query, 
key - + + __all__ = ["RotaryEmbeddingFL"] From 4a749f076ad6c3eb816b272986c583b94f8670e8 Mon Sep 17 00:00:00 2001 From: yxa Date: Mon, 12 Jan 2026 06:57:27 +0000 Subject: [PATCH 07/34] Adjusting the Vendor multi-backend structure --- vllm_fl/dispatch/README.md | 244 ++++++++++-------- vllm_fl/dispatch/__init__.py | 3 +- vllm_fl/dispatch/backends/__init__.py | 14 + .../backends/flaggems/register_ops.py | 8 +- .../backends/reference/register_ops.py | 8 +- vllm_fl/dispatch/backends/vendor/__init__.py | 33 +++ .../backends/vendor/ascend/__init__.py | 9 + .../dispatch/backends/vendor/ascend/ascend.py | 126 +++++++++ .../backends/vendor/ascend/impl/__init__.py | 15 ++ .../backends/vendor/ascend/impl/activation.py | 38 +++ .../vendor/ascend/impl/normalization.py | 51 ++++ .../backends/vendor/ascend/impl/rotary.py | 91 +++++++ .../backends/vendor/ascend/register_ops.py | 69 +++++ vllm_fl/dispatch/builtin_ops.py | 70 ++++- vllm_fl/dispatch/types.py | 11 + 15 files changed, 658 insertions(+), 132 deletions(-) create mode 100644 vllm_fl/dispatch/backends/vendor/ascend/__init__.py create mode 100644 vllm_fl/dispatch/backends/vendor/ascend/ascend.py create mode 100644 vllm_fl/dispatch/backends/vendor/ascend/impl/__init__.py create mode 100644 vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py create mode 100644 vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py create mode 100644 vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py create mode 100644 vllm_fl/dispatch/backends/vendor/ascend/register_ops.py diff --git a/vllm_fl/dispatch/README.md b/vllm_fl/dispatch/README.md index 32eb5849..f2874793 100644 --- a/vllm_fl/dispatch/README.md +++ b/vllm_fl/dispatch/README.md @@ -1,6 +1,6 @@ # Dispatch Mechanism -This directory implements the operator dispatch mechanism for vllm-plugin-FL, providing a flexible operator dispatch system that selects between different backend implementations (FlagGems, PyTorch, etc.) based on availability and policy configuration. 
+This directory implements the operator dispatch mechanism for vllm-plugin-FL, providing a flexible operator dispatch system that selects between different backend implementations (FlagGems, PyTorch, vendor-specific) based on availability and policy configuration. ## Directory Structure @@ -11,46 +11,29 @@ dispatch/ ├── registry.py # Thread-safe operator registry ├── policy.py # Selection policy management ├── manager.py # Core dispatch manager -├── builtin_ops.py # Built-in operator registration (calls backend register_ops) -├── ops.py # Backend base interface (VLLMFLBackendBase) +├── builtin_ops.py # Built-in operator registration +├── ops.py # Backend base interface ├── discovery.py # Plugin discovery mechanism ├── logger_manager.py # Centralized logging configuration └── backends/ # Backend implementations - ├── __init__.py ├── base.py # Backend abstract base class - ├── flaggems/ # FlagGems backend - │ ├── __init__.py - │ ├── flaggems.py # Backend class - │ ├── register_ops.py # Operator registration - │ └── impl/ # Operator implementations - │ ├── __init__.py - │ ├── activation.py - │ ├── normalization.py - │ └── rotary.py - ├── reference/ # Reference backend (PyTorch) - │ ├── __init__.py - │ ├── reference.py # Backend class - │ ├── register_ops.py # Operator registration - │ └── impl/ # Operator implementations - │ ├── __init__.py - │ ├── activation.py - │ ├── normalization.py - │ └── rotary.py - └── vendor/ # Vendor-specific backends - └── __init__.py # (Add CUDA, etc. as needed) + ├── flaggems/ # FlagGems backend (DEFAULT, priority 150) + ├── reference/ # Reference backend (PyTorch, priority 50) + └── vendor/ # Vendor-specific backends (priority 100) + └── ascend/ # Example: Huawei Ascend backend ``` ## Core Concepts -### 1. Backend Implementation Kind (BackendImplKind) +### 1. 
Backend Implementation Kind - **DEFAULT**: Default implementation (FlagGems), priority 150 +- **VENDOR**: Vendor-specific implementation (e.g., Ascend), priority 100 - **REFERENCE**: Reference implementation (PyTorch native), priority 50 -- **VENDOR**: Vendor-specific implementation (e.g., CUDA), requires vendor name ### 2. Operator Implementation (OpImpl) -Each operator implementation contains the following attributes: +Each operator implementation contains: - `op_name`: Operator name (e.g., "silu_and_mul", "rmsnorm") - `impl_id`: Unique implementation identifier (e.g., "default.flaggems") - `kind`: Implementation type @@ -58,9 +41,9 @@ Each operator implementation contains the following attributes: - `vendor`: Vendor name (required for VENDOR type) - `priority`: Selection priority (higher value = preferred) -### 3. Selection Policy (SelectionPolicy) +### 3. Selection Policy -Policy controls operator implementation selection behavior: +Policy controls operator implementation selection: - `prefer`: Preferred implementation type - `strict`: Strict mode, whether to raise error when primary implementation fails - `per_op_order`: Custom selection order for each operator @@ -137,12 +120,10 @@ Supports temporary policy override in code: ```python from vllm_fl.dispatch import ( - policy_context, with_strict_mode, with_preference, with_allowed_vendors, with_denied_vendors, - SelectionPolicy, ) # Temporarily enable strict mode @@ -156,26 +137,17 @@ with with_preference("reference"): # Temporarily restrict allowed vendors with with_allowed_vendors("vendor_a"): result = call_op("rotary_embedding", query, key, cos, sin, position_ids) - -# Use custom policy -custom_policy = SelectionPolicy.from_dict( - prefer="flaggems", - strict=True, - deny_vendors={"vendor_x"}, -) -with policy_context(custom_policy): - result = call_op("silu_and_mul", x) ``` ## Supported Operators Currently supported operators: -| Operator | Description | FlagGems | Reference | 
-|----------|-------------|----------|-----------| -| `silu_and_mul` | SiLU activation + element-wise multiplication | ✓ | ✓ | -| `rmsnorm` | RMS normalization | ✓ | ✓ | -| `rotary_embedding` | Rotary position embedding | ✓ | ✓ | +| Operator | Description | FlagGems | Reference | Vendor | +|----------|-------------|----------|-----------|--------| +| `silu_and_mul` | SiLU activation + element-wise multiplication | ✓ | ✓ | ✓ | +| `rmsnorm` | RMS normalization | ✓ | ✓ | ✓ | +| `rotary_embedding` | Rotary position embedding | ✓ | ✓ | ✓ | ## Selection Process @@ -196,126 +168,176 @@ Op 'rmsnorm' using 'default.flaggems' (kind=flaggems, vendor=None) Op 'rmsnorm' fallback to 'reference.torch' (kind=reference, vendor=None) ``` -## Extending with New Operators +## Extending the System -When adding a new operator (e.g., `layernorm`), modify the following files: +### Adding New Operators -| File | Changes | -|------|---------| -| `backends/flaggems/impl/normalization.py` | Add FlagGems implementation | -| `backends/flaggems/flaggems.py` | Add method to backend class | -| `backends/flaggems/register_ops.py` | Register OpImpl | -| `backends/reference/impl/normalization.py` | Add PyTorch implementation | -| `backends/reference/reference.py` | Add method to backend class | -| `backends/reference/register_ops.py` | Register OpImpl | -| `ops.py` | Add abstract method declaration | +When adding a new operator, modify these files: +- `backends/flaggems/impl/*.py` - Add FlagGems implementation +- `backends/flaggems/flaggems.py` - Add method to backend class +- `backends/flaggems/register_ops.py` - Register OpImpl +- `backends/reference/impl/*.py` - Add PyTorch implementation +- `backends/reference/reference.py` - Add method to backend class +- `backends/reference/register_ops.py` - Register OpImpl +- `ops.py` - Add abstract method declaration -## Extending with New Backends +### Adding Vendor Backends -### 1. 
Create Backend Directory Structure +The dispatch system supports three ways to integrate vendor backends: +1. **Built-in vendor backends** - Located in `backends/vendor/` (recommended for core vendors) +2. **External plugin packages** - Distributed as separate Python packages +3. **Environment-based plugins** - Loaded via `VLLM_FL_PLUGIN_MODULES` + +#### Option 1: Built-in Vendor Backend + +Directory structure: ``` -backends/my_backend/ +backends/vendor// ├── __init__.py -├── my_backend.py # Backend class -├── register_ops.py # Operator registration -└── impl/ # Operator implementations +├── .py # Backend class +├── register_ops.py # Registration function +└── impl/ # Operator implementations ├── __init__.py ├── activation.py - └── ... + ├── normalization.py + └── rotary.py ``` -### 2. Implement Backend Class +**Step 1: Create Backend Class** (`.py`): ```python -# backends/my_backend/my_backend.py -from ..base import Backend +from ...base import Backend + +class Backend(Backend): + _available = None -class MyBackend(Backend): @property def name(self) -> str: - return "my_backend" + return "" + + @property + def vendor(self) -> str: + return "" # Required for vendor backends def is_available(self) -> bool: - try: - import my_library - return True - except ImportError: - return False + if Backend._available is None: + try: + import + Backend._available = True + except ImportError: + Backend._available = False + return Backend._available def silu_and_mul(self, x): - from .impl.activation import silu_and_mul_my_backend - return silu_and_mul_my_backend(x) + from .impl.activation import silu_and_mul_ + return silu_and_mul_(x) ``` -### 3. 
Create Registration Module +**Step 2: Create Registration Module** (`register_ops.py`): ```python -# backends/my_backend/register_ops.py -from ...types import OpImpl, BackendImplKind - -def register_builtins(registry) -> None: - from .my_backend import MyBackend +from ....types import OpImpl, BackendImplKind, BackendPriority - backend = MyBackend() - is_avail = backend.is_available +def register_builtins(registry): + from . import Backend + backend = Backend() impls = [ OpImpl( op_name="silu_and_mul", - impl_id="default.my_backend", - kind=BackendImplKind.DEFAULT, - fn=_bind_is_available(backend.silu_and_mul, is_avail), - priority=100, + impl_id="vendor.", + kind=BackendImplKind.VENDOR, + fn=backend.silu_and_mul, + vendor="", + priority=BackendPriority.VENDOR, # 100 ), ] - registry.register_many(impls) ``` -### 4. Update builtin_ops.py +**Step 3: Register in builtin_ops.py**: ```python -# In builtin_ops.py, add: try: - from .backends.my_backend.register_ops import register_builtins as register_my_backend - register_my_backend(registry) + from .backends.vendor..register_ops import register_builtins as register_ + register_(registry) except Exception as e: - logger.warning(f"Failed to register MyBackend operators: {e}") + logger.debug(f" operators not available: {e}") ``` -## Plugin Discovery - -External plugins can register operators via: +#### Option 2: External Plugin Package -### 1. 
Entry Points (Recommended) +Create a separate package with entry points: ```python -# In your plugin's setup.py or pyproject.toml -[project.entry-points."vllm_fl.plugin"] -my_plugin = "my_plugin_package:register" +# setup.py +setup( + name="vllm-plugin-", + entry_points={ + "vllm_fl.plugin": [ + " = vllm_fl_.register_ops:register_builtins", + ], + }, +) ``` -```python -# my_plugin_package/__init__.py -def register(registry): - # Register your operators - registry.register_impl(OpImpl(...)) +Install and use: +```bash +pip install vllm-plugin- +# Plugin auto-discovered via entry points ``` -### 2. Environment Variable +#### Option 3: Environment-based Plugin ```bash -export VLLM_FL_PLUGIN_MODULES=my_plugin_module +export VLLM_FL_PLUGIN_MODULES=my_custom_backend.register_ops ``` +The module should provide a `register_builtins(registry)` function. + +#### Priority Levels + +Use constants from `types.py`: +- `BackendPriority.DEFAULT` (150) - FlagGems +- `BackendPriority.VENDOR` (100) - Vendor backends +- `BackendPriority.REFERENCE` (50) - PyTorch + +#### Testing Your Backend + ```python -# my_plugin_module.py -def vllm_fl_register(registry): - # Register your operators - pass +from vllm_fl.dispatch import get_default_manager + +manager = get_default_manager() +manager.ensure_initialized() + +# Check registration +snap = manager.registry.snapshot() +for op_name, impls in snap.impls_by_op.items(): + for impl in impls: + if impl.vendor == "": + print(f"{op_name}: {impl.impl_id}, available={impl.is_available()}") ``` +Enable debug output: +```bash +export VLLM_FL_LOG_LEVEL=DEBUG +``` + +#### Vendor Backend Checklist + +- [ ] Backend class inherits from `Backend` +- [ ] `vendor` property returns vendor name (not None) +- [ ] `is_available()` checks hardware/library availability +- [ ] `register_ops.py` uses `BackendImplKind.VENDOR` +- [ ] `impl_id` follows format: `vendor.` +- [ ] Priority set to `BackendPriority.VENDOR` (100) +- [ ] Error handling for missing dependencies + 
+#### Current Vendor Backends + +- **Ascend** (Huawei) - Example implementation in `backends/vendor/ascend/` + ## Multi-Process Safety OpManager supports multi-process environments: diff --git a/vllm_fl/dispatch/__init__.py b/vllm_fl/dispatch/__init__.py index 5724a90a..0afe1e21 100644 --- a/vllm_fl/dispatch/__init__.py +++ b/vllm_fl/dispatch/__init__.py @@ -32,7 +32,7 @@ - Selected backend for each operator call """ -from .types import OpImpl, BackendImplKind, match_token +from .types import OpImpl, BackendImplKind, BackendPriority, match_token from .registry import OpRegistry, OpRegistrySnapshot from .policy import ( SelectionPolicy, @@ -92,6 +92,7 @@ def resolve_op(op_name: str): # Types "OpImpl", "BackendImplKind", + "BackendPriority", "match_token", # Registry "OpRegistry", diff --git a/vllm_fl/dispatch/backends/__init__.py b/vllm_fl/dispatch/backends/__init__.py index 14a7ea05..84bc334c 100644 --- a/vllm_fl/dispatch/backends/__init__.py +++ b/vllm_fl/dispatch/backends/__init__.py @@ -9,3 +9,17 @@ from .reference import ReferenceBackend __all__ = ["Backend", "FlagGemsBackend", "ReferenceBackend"] + +# Try to import vendor backends +try: + from .vendor.ascend import AscendBackend + __all__.append("AscendBackend") +except ImportError: + AscendBackend = None + +# Add more vendor backends here as they become available +# try: +# from .vendor.cuda import CudaBackend +# __all__.append("CudaBackend") +# except ImportError: +# CudaBackend = None diff --git a/vllm_fl/dispatch/backends/flaggems/register_ops.py b/vllm_fl/dispatch/backends/flaggems/register_ops.py index 2c524a05..3b4f3ed3 100644 --- a/vllm_fl/dispatch/backends/flaggems/register_ops.py +++ b/vllm_fl/dispatch/backends/flaggems/register_ops.py @@ -10,7 +10,7 @@ import functools -from ...types import OpImpl, BackendImplKind +from ...types import OpImpl, BackendImplKind, BackendPriority def _bind_is_available(fn, is_available_fn): @@ -44,7 +44,7 @@ def register_builtins(registry) -> None: 
kind=BackendImplKind.DEFAULT, fn=_bind_is_available(backend.silu_and_mul, is_avail), vendor=None, - priority=150, + priority=BackendPriority.DEFAULT, ), # Normalization OpImpl( @@ -53,7 +53,7 @@ def register_builtins(registry) -> None: kind=BackendImplKind.DEFAULT, fn=_bind_is_available(backend.rmsnorm, is_avail), vendor=None, - priority=150, + priority=BackendPriority.DEFAULT, ), # Rotary Embedding OpImpl( @@ -62,7 +62,7 @@ def register_builtins(registry) -> None: kind=BackendImplKind.DEFAULT, fn=_bind_is_available(backend.rotary_embedding, is_avail), vendor=None, - priority=150, + priority=BackendPriority.DEFAULT, ), ] diff --git a/vllm_fl/dispatch/backends/reference/register_ops.py b/vllm_fl/dispatch/backends/reference/register_ops.py index 59a482f9..e2af7993 100644 --- a/vllm_fl/dispatch/backends/reference/register_ops.py +++ b/vllm_fl/dispatch/backends/reference/register_ops.py @@ -10,7 +10,7 @@ import functools -from ...types import OpImpl, BackendImplKind +from ...types import OpImpl, BackendImplKind, BackendPriority def _bind_is_available(fn, is_available_fn): @@ -44,7 +44,7 @@ def register_builtins(registry) -> None: kind=BackendImplKind.REFERENCE, fn=_bind_is_available(backend.silu_and_mul, is_avail), vendor=None, - priority=50, + priority=BackendPriority.REFERENCE, ), # Normalization OpImpl( @@ -53,7 +53,7 @@ def register_builtins(registry) -> None: kind=BackendImplKind.REFERENCE, fn=_bind_is_available(backend.rmsnorm, is_avail), vendor=None, - priority=50, + priority=BackendPriority.REFERENCE, ), # Rotary Embedding OpImpl( @@ -62,7 +62,7 @@ def register_builtins(registry) -> None: kind=BackendImplKind.REFERENCE, fn=_bind_is_available(backend.rotary_embedding, is_avail), vendor=None, - priority=50, + priority=BackendPriority.REFERENCE, ), ] diff --git a/vllm_fl/dispatch/backends/vendor/__init__.py b/vllm_fl/dispatch/backends/vendor/__init__.py index 8169d3ec..38f6f1af 100644 --- a/vllm_fl/dispatch/backends/vendor/__init__.py +++ 
b/vllm_fl/dispatch/backends/vendor/__init__.py @@ -4,4 +4,37 @@ Vendor backends for vllm-plugin-FL dispatch. This package contains vendor-specific backend implementations. + +Available vendor backends: +- ascend: Huawei Ascend NPU backend + +To add a new vendor backend: +1. Create a subdirectory: vendor// +2. Implement the backend class inheriting from Backend +3. Create register_ops.py with registration function +4. The backend will be auto-discovered by builtin_ops.py + +See the "Adding Vendor Backends" section in dispatch/README.md for detailed instructions. """ + +__all__ = [] + +# Import Ascend backend +try: + from .ascend import AscendBackend + __all__.append("AscendBackend") +except ImportError: + pass + +# Add more vendor backends here as they become available: +# try: +# from .cuda import CudaBackend +# __all__.append("CudaBackend") +# except ImportError: +# pass +# +# try: +# from .rocm import RocmBackend +# __all__.append("RocmBackend") +# except ImportError: +# pass diff --git a/vllm_fl/dispatch/backends/vendor/ascend/__init__.py b/vllm_fl/dispatch/backends/vendor/ascend/__init__.py new file mode 100644 index 00000000..855c45fc --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/ascend/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Ascend (Huawei) backend for vllm-plugin-FL dispatch. +""" + +from .ascend import AscendBackend + +__all__ = ["AscendBackend"] diff --git a/vllm_fl/dispatch/backends/vendor/ascend/ascend.py b/vllm_fl/dispatch/backends/vendor/ascend/ascend.py new file mode 100644 index 00000000..aab90136 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/ascend/ascend.py @@ -0,0 +1,126 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Ascend backend implementation. + +This backend provides operator implementations for Huawei Ascend NPUs. 
+""" + +from __future__ import annotations + +from typing import Optional, Union + +import torch + +from ...base import Backend + + +class AscendBackend(Backend): + """ + Ascend backend for operator implementations. + + This backend uses Ascend CANN libraries to provide high-performance + operator implementations for Huawei Ascend NPUs. + """ + + _available: Optional[bool] = None + + @property + def name(self) -> str: + return "ascend" + + @property + def vendor(self) -> Optional[str]: + return "ascend" + + def is_available(self) -> bool: + """Check if Ascend hardware and libraries are available.""" + if AscendBackend._available is None: + try: + # Check for torch_npu (Ascend PyTorch extension) + import torch_npu + + # Check if NPU device is available + if torch.npu.is_available() and torch.npu.device_count() > 0: + AscendBackend._available = True + else: + AscendBackend._available = False + except (ImportError, AttributeError): + AscendBackend._available = False + return AscendBackend._available + + # ==================== Operator Implementations ==================== + + def silu_and_mul(self, x: torch.Tensor) -> torch.Tensor: + """ + SiLU activation followed by element-wise multiplication. + + Args: + x: Input tensor of shape [..., 2*d] + + Returns: + Output tensor of shape [..., d] + """ + from .impl.activation import silu_and_mul_ascend + + return silu_and_mul_ascend(x) + + def rmsnorm( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor], + weight: torch.Tensor, + epsilon: float, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """ + RMS normalization. 
+ + Args: + x: Input tensor + residual: Optional residual tensor + weight: Normalization weight + epsilon: Small constant for numerical stability + + Returns: + Normalized tensor, or tuple of (normalized, residual) if residual is provided + """ + from .impl.normalization import rmsnorm_ascend + + return rmsnorm_ascend(x, residual, weight, epsilon) + + def rotary_embedding( + self, + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor, + rotary_interleaved: bool = False, + inplace: bool = True, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embedding. + + Args: + query: Query tensor + key: Key tensor + cos: Cosine cache + sin: Sine cache + position_ids: Position indices + rotary_interleaved: Whether to use interleaved rotary + inplace: Whether to modify tensors in-place + + Returns: + Tuple of (embedded_query, embedded_key) + """ + from .impl.rotary import rotary_embedding_ascend + + return rotary_embedding_ascend( + query, + key, + cos, + sin, + position_ids, + rotary_interleaved=rotary_interleaved, + inplace=inplace, + ) diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/__init__.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/__init__.py new file mode 100644 index 00000000..6a475468 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Ascend operator implementations. 
+""" + +from .activation import silu_and_mul_ascend +from .normalization import rmsnorm_ascend +from .rotary import rotary_embedding_ascend + +__all__ = [ + "silu_and_mul_ascend", + "rmsnorm_ascend", + "rotary_embedding_ascend", +] diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py new file mode 100644 index 00000000..4434a7fc --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py @@ -0,0 +1,38 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Ascend activation operator implementations. +""" + +from __future__ import annotations + +import torch + + +def silu_and_mul_ascend(x: torch.Tensor) -> torch.Tensor: + """ + SiLU activation followed by element-wise multiplication using Ascend NPU. + + Args: + x: Input tensor of shape [..., 2*d] + + Returns: + Output tensor of shape [..., d] + """ + # Split input into two halves + d = x.shape[-1] // 2 + x1 = x[..., :d] + x2 = x[..., d:] + + # Apply SiLU (x * sigmoid(x)) to first half and multiply with second half + # Use Ascend-optimized operations if available + try: + import torch_npu + + # Use NPU-optimized SiLU if available + silu_out = torch.nn.functional.silu(x1) + return silu_out * x2 + except (ImportError, AttributeError): + # Fallback to standard PyTorch + silu_out = torch.nn.functional.silu(x1) + return silu_out * x2 diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py new file mode 100644 index 00000000..b4dc7815 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py @@ -0,0 +1,51 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Ascend normalization operator implementations. 
+""" + +from __future__ import annotations + +from typing import Optional, Union + +import torch + + +def rmsnorm_ascend( + x: torch.Tensor, + residual: Optional[torch.Tensor], + weight: torch.Tensor, + epsilon: float, +) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """ + RMS normalization using Ascend NPU. + + Args: + x: Input tensor + residual: Optional residual tensor to add before normalization + weight: Normalization weight + epsilon: Small constant for numerical stability + + Returns: + Normalized tensor, or tuple of (normalized, residual) if residual is provided + """ + orig_dtype = x.dtype + x = x.float() + + if residual is not None: + residual = residual.float() + x = x + residual + + # Compute RMS normalization + # RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + epsilon) + + # Apply weight + x = x * weight + + x = x.to(orig_dtype) + + if residual is not None: + return x, residual.to(orig_dtype) + return x diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py new file mode 100644 index 00000000..9866bdc7 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py @@ -0,0 +1,91 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Ascend rotary embedding operator implementations. +""" + +from __future__ import annotations + +import torch + + +def rotary_embedding_ascend( + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor, + rotary_interleaved: bool = False, + inplace: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embedding using Ascend NPU. 
def rotary_embedding_ascend(
    query: torch.Tensor,
    key: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: torch.Tensor,
    rotary_interleaved: bool = False,
    inplace: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embedding using Ascend NPU.

    Args:
        query: Query tensor
        key: Key tensor
        cos: Cosine cache
        sin: Sine cache
        position_ids: Position indices
        rotary_interleaved: Whether to use interleaved (GPT-J style) rotary
        inplace: Accepted for API compatibility but not honored — new
            tensors are always returned. TODO(review): confirm callers
            do not rely on in-place mutation.

    Returns:
        Tuple of (embedded_query, embedded_key)
    """
    # NOTE(review): removed the dead `import torch_npu` try-block (both
    # paths were no-ops). TODO: use torch_npu.npu_rotary_mul when available.

    # Advanced indexing handles both 1-D ([num_tokens]) and batched
    # ([batch, seq]) position ids identically, so no dim() branching
    # is needed here (the previous if/else branches were identical).
    cos_sel = cos[position_ids]
    sin_sel = sin[position_ids]

    # Broadcast over the head dimension. The previous dim()==4 and
    # dim()==3 branches performed the same unsqueeze(1); merged.
    if query.dim() in (3, 4):
        cos_sel = cos_sel.unsqueeze(1)
        sin_sel = sin_sel.unsqueeze(1)

    # The cache may store only half the head dim; duplicate to full width.
    if cos_sel.shape[-1] != query.shape[-1]:
        cos_sel = torch.cat([cos_sel, cos_sel], dim=-1)
        sin_sel = torch.cat([sin_sel, sin_sel], dim=-1)

    def rotate_half(t: torch.Tensor) -> torch.Tensor:
        """Rotate the two halves of the last dim: (x1, x2) -> (-x2, x1)."""
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    def rotate_interleaved(t: torch.Tensor) -> torch.Tensor:
        """GPT-J style rotation over interleaved (even, odd) pairs."""
        return torch.stack((-t[..., 1::2], t[..., ::2]), dim=-1).flatten(-2)

    rotate = rotate_interleaved if rotary_interleaved else rotate_half
    q_embed = (query * cos_sel) + (rotate(query) * sin_sel)
    k_embed = (key * cos_sel) + (rotate(key) * sin_sel)

    return q_embed, k_embed
cos_selected) + (rotate_half(key) * sin_selected) + + return q_embed, k_embed diff --git a/vllm_fl/dispatch/backends/vendor/ascend/register_ops.py b/vllm_fl/dispatch/backends/vendor/ascend/register_ops.py new file mode 100644 index 00000000..b22a941a --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/ascend/register_ops.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Ascend backend operator registrations. + +This module registers all VENDOR (Ascend) implementations. +""" + +from __future__ import annotations + +import functools + +from ....types import OpImpl, BackendImplKind, BackendPriority + + +def _bind_is_available(fn, is_available_fn): + """Wrap a function and bind _is_available attribute for OpImpl.is_available() check.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + wrapper._is_available = is_available_fn + return wrapper + + +def register_builtins(registry) -> None: + """ + Register all Ascend (VENDOR) operator implementations. 
def register_builtins(registry) -> None:
    """
    Register all Ascend (VENDOR) operator implementations.

    Args:
        registry: Registry to register into
    """
    from .ascend import AscendBackend

    backend = AscendBackend()
    availability_check = backend.is_available

    # Operator name -> bound backend method; one OpImpl per entry.
    vendor_ops = (
        ("silu_and_mul", backend.silu_and_mul),      # Activation
        ("rmsnorm", backend.rmsnorm),                # Normalization
        ("rotary_embedding", backend.rotary_embedding),  # Rotary embedding
    )

    registry.register_many([
        OpImpl(
            op_name=op_name,
            impl_id="vendor.ascend",
            kind=BackendImplKind.VENDOR,
            fn=_bind_is_available(op_fn, availability_check),
            vendor="ascend",
            priority=BackendPriority.VENDOR,
        )
        for op_name, op_fn in vendor_ops
    ])
def _register_vendor_backends(registry: OpRegistry) -> None:
    """
    Auto-discover and register all vendor backends.

    Scans the vendor directory for subdirectories containing register_ops.py
    and calls their register_builtins function.

    Args:
        registry: Registry to register into
    """
    if not os.path.isdir(_VENDOR_BACKENDS_DIR):
        # Lazy %-formatting: the message is only built if debug is enabled.
        logger.debug("Vendor backends directory not found: %s", _VENDOR_BACKENDS_DIR)
        return

    # sorted() makes discovery (and thus registration) order deterministic
    # across filesystems; os.listdir order is otherwise unspecified.
    for vendor_name in sorted(os.listdir(_VENDOR_BACKENDS_DIR)):
        vendor_path = os.path.join(_VENDOR_BACKENDS_DIR, vendor_name)

        # Skip non-directories and private/dunder entries (e.g. __pycache__).
        if not os.path.isdir(vendor_path) or vendor_name.startswith("_"):
            continue

        # A backend is only discoverable if it ships a register_ops.py.
        if not os.path.isfile(os.path.join(vendor_path, "register_ops.py")):
            continue

        module_name = f".backends.vendor.{vendor_name}.register_ops"
        try:
            mod = importlib.import_module(module_name, package="vllm_fl.dispatch")
        except Exception as e:
            # Missing hardware/libraries is expected; keep it at debug level.
            logger.debug("%s operators not available: %s", vendor_name, e)
            continue

        register_fn = getattr(mod, "register_builtins", None)
        if register_fn is None:
            logger.debug("No register_builtins function in %s", module_name)
            continue

        try:
            register_fn(registry)
            logger.debug("Registered %s operators", vendor_name)
        except Exception as e:
            logger.debug("%s operators not available: %s", vendor_name, e)
class BackendPriority:
    """
    Standard priority values for different backend types.

    Higher priority implementations are selected first when available:
    DEFAULT (FlagGems) > VENDOR > REFERENCE (PyTorch fallback).

    These are plain class attributes (not an Enum) so they can be passed
    directly wherever an ``int`` priority is expected.
    """

    DEFAULT: int = 150  # Default implementations (FlagGems)
    VENDOR: int = 100  # Vendor-specific implementations
    REFERENCE: int = 50  # Reference implementations (PyTorch, lowest)
def register():
    """Register the FL platform.

    Also defaults ``VLLM_WORKER_MULTIPROC_METHOD`` to ``spawn`` while
    honoring any value the user has already set — presumably because the
    device runtime is not fork-safe; TODO(review): confirm.

    Returns:
        Dotted path of the platform class for vLLM's plugin loader.
    """
    # setdefault replaces the previous get-then-set dance and keeps a
    # user-provided value untouched.
    os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
    return "vllm_fl.platform.PlatformFL"
+ +This module provides attention backend implementations for different hardware platforms. +The dispatch mechanism automatically selects the appropriate backend based on the +available hardware and configuration. + +Available backends: +- ascend: Native Ascend NPU attention using torch_npu operators + - Uses torch_npu.npu_fused_infer_attention_score for prefill + - Uses torch_npu._npu_paged_attention for decode + - No dependency on vllm-ascend package +""" + +from vllm_fl.attention.backends.ascend import ( + AscendAttentionBackend, + AscendAttentionBackendImpl, + AscendAttentionMetadataBuilder, + AscendMetadata, + AscendAttentionState, + AscendMLABackend, + AttentionMaskBuilder, + get_attention_mask_builder, + is_torch_npu_available, +) + +__all__ = [ + # Ascend backend + "AscendAttentionBackend", + "AscendAttentionBackendImpl", + "AscendAttentionMetadataBuilder", + "AscendMetadata", + "AscendAttentionState", + "AscendMLABackend", + # Utilities + "AttentionMaskBuilder", + "get_attention_mask_builder", + "is_torch_npu_available", +] diff --git a/vllm_fl/attention/backends/ascend/__init__.py b/vllm_fl/attention/backends/ascend/__init__.py new file mode 100644 index 00000000..41931826 --- /dev/null +++ b/vllm_fl/attention/backends/ascend/__init__.py @@ -0,0 +1,40 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Ascend NPU attention backend for vllm-plugin-FL. + +This package provides native Ascend NPU attention implementation using torch_npu +operators directly, without depending on vllm-ascend package. + +Modules: +- attention: Core attention backend classes (AscendAttentionBackend, etc.) 
+- attention_mask: Attention mask builder and utilities +""" + +from vllm_fl.attention.backends.ascend.attention import ( + AscendAttentionBackend, + AscendAttentionBackendImpl, + AscendAttentionMetadataBuilder, + AscendMetadata, + AscendAttentionState, + AscendMLABackend, + is_torch_npu_available, +) +from vllm_fl.attention.backends.ascend.attention_mask import ( + AttentionMaskBuilder, + get_attention_mask_builder, +) + +__all__ = [ + # Attention backend classes + "AscendAttentionBackend", + "AscendAttentionBackendImpl", + "AscendAttentionMetadataBuilder", + "AscendMetadata", + "AscendAttentionState", + "AscendMLABackend", + # Utilities + "AttentionMaskBuilder", + "get_attention_mask_builder", + "is_torch_npu_available", +] diff --git a/vllm_fl/attention/backends/ascend/attention.py b/vllm_fl/attention/backends/ascend/attention.py new file mode 100644 index 00000000..e74701f9 --- /dev/null +++ b/vllm_fl/attention/backends/ascend/attention.py @@ -0,0 +1,710 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Ascend NPU native attention backend for vllm-plugin-FL. + +This module provides native Ascend NPU attention implementation using torch_npu +operators directly, without depending on vllm-ascend package. + +Core operators used: +- torch_npu.npu_fused_infer_attention_score: For prefill/chunked-prefill +- torch_npu._npu_paged_attention: For decode +- torch_npu._npu_reshape_and_cache: For KV cache update + +These are optimized operators for Huawei Ascend NPUs that provide better +performance than generic implementations. 
+""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type + +import torch +import torch.nn as nn + +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionLayer, + AttentionType, +) +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.utils.math_utils import cdiv +from vllm.v1.attention.backends.utils import AttentionCGSupport + +from vllm_fl.attention.backends.ascend.attention_mask import ( + AttentionMaskBuilder, + get_attention_mask_builder, +) + +logger = logging.getLogger(__name__) + +# Check torch_npu availability and setup NPU compatibility +_TORCH_NPU_AVAILABLE = False +try: + import torch_npu + _TORCH_NPU_AVAILABLE = True + + # NPU compatibility: Replace torch.Event and torch.cuda.Stream with NPU versions + # This is similar to vllm-ascend's _torch_cuda_wrapper approach + if hasattr(torch, 'npu') and torch.npu.is_available(): + torch.Event = torch.npu.Event + torch.cuda.Event = torch.npu.Event + torch.cuda.Stream = torch.npu.Stream + logger.info("NPU compatibility enabled: torch.Event -> torch.npu.Event") +except ImportError: + torch_npu = None + logger.warning("torch_npu not available, Ascend attention backend will not work") + + +def is_torch_npu_available() -> bool: + """Check if torch_npu is available.""" + return _TORCH_NPU_AVAILABLE + + +# Ascend platform specific configurations +ASCEND_SAMPLED_TOKEN_IDS_DTYPE = torch.int32 # NPU uses int32, CUDA uses int64 + + +class AscendAttentionState(Enum): + """Attention state for Ascend backend.""" + PrefillNoCache = 0 + PrefillCacheHit = 1 + DecodeOnly = 2 + ChunkedPrefill = 3 + SpecDecoding = 4 + + +@dataclass +class AscendMetadata: + """Metadata for Ascend attention.""" + # Basic properties + attn_mask: Optional[torch.Tensor] = None + attn_state: 
AscendAttentionState = AscendAttentionState.ChunkedPrefill + + # Token counts + num_actual_tokens: int = 0 + num_decode_tokens: int = 0 + num_prefills: int = 0 + num_decodes: int = 0 + + # Sequence lengths + seq_lens: torch.Tensor = None + seq_lens_list: List[int] = None + actual_seq_lengths_q: List[int] = None + + query_start_loc: torch.Tensor = None + max_query_len: Optional[int] = None + + # KV Cache properties + block_tables: torch.Tensor = None + slot_mapping: torch.Tensor = None + + causal: bool = True + model_runner_type: str = "" + + +class AscendAttentionMetadataBuilder: + """Builder for Ascend attention metadata.""" + + # ACL graph support - ALWAYS means full graph capture is supported + aclgraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS + reorder_batch_threshold: ClassVar[int] = 1 + + # Class-level mask builder cache + _mask_builder: ClassVar[Optional[AttentionMaskBuilder]] = None + _mask_builder_device: ClassVar[Optional[torch.device]] = None + + def __init__( + self, + kv_cache_spec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device + self.max_num_blocks_per_req = cdiv( + self.model_config.max_model_len, + AscendAttentionBackend.get_supported_block_size()[0] + ) + + self.speculative_config = vllm_config.speculative_config + self.decode_threshold = 1 + if self.speculative_config: + spec_token_num = self.speculative_config.num_speculative_tokens + self.decode_threshold += spec_token_num + + scheduler_config = vllm_config.scheduler_config + self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill + + def _get_mask_builder(self) -> AttentionMaskBuilder: + """Get or create the attention mask builder (cached at class level).""" + cls = AscendAttentionMetadataBuilder + if cls._mask_builder is None or cls._mask_builder_device != self.device: + cls._mask_builder = 
    def _get_mask_builder(self) -> AttentionMaskBuilder:
        """Get or create the attention mask builder (cached at class level)."""
        cls = AscendAttentionMetadataBuilder
        # Rebuild only when no builder exists yet or the device changed.
        if cls._mask_builder is None or cls._mask_builder_device != self.device:
            cls._mask_builder = AttentionMaskBuilder(self.device)
            cls._mask_builder_device = self.device
        return cls._mask_builder

    def _make_attention_mask(
        self,
        attn_state: AscendAttentionState,
    ) -> Optional[torch.Tensor]:
        """
        Create attention mask based on attention state.

        Args:
            attn_state: Current attention state.

        Returns:
            Attention mask tensor, or None for decode-only.
        """
        # Decode-only doesn't need mask (uses paged attention)
        if attn_state == AscendAttentionState.DecodeOnly:
            return None

        mask_builder = self._get_mask_builder()

        # Pooling model uses general attention mask
        if self.model_config.runner_type == "pooling":
            return mask_builder.get_attn_mask(2048, torch.bool)

        # MLA attention
        if self.model_config.use_mla:
            # TODO: Add pcp_size check if needed
            return mask_builder.get_mla_mask(torch.float16)

        # Default: chunked prefill / split-fuse mask
        return mask_builder.get_splitfuse_attn_mask()

    def reorder_batch(self, input_batch, scheduler_output) -> bool:
        # No reordering is performed here.
        # NOTE(review): _split_decodes_and_prefills below assumes decode
        # requests precede prefill requests in the batch — confirm the
        # scheduler already guarantees that ordering.
        return False

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata,
        model: Optional[nn.Module] = None,
    ):
        """Build AscendMetadata from common attention metadata.

        Args:
            common_prefix_len: Shared-prefix length (currently unused here).
            common_attn_metadata: Backend-agnostic metadata from the runner.
            model: Optional model reference (unused).

        Returns:
            A populated AscendMetadata.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]

        # Split decodes and prefills
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            self._split_decodes_and_prefills(common_attn_metadata)

        block_table = common_attn_metadata.block_table_tensor
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
        slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]

        # Determine attention state
        attn_state = self._determine_attn_state(
            num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens
        )

        # Create attention mask based on state
        attn_mask = self._make_attention_mask(attn_state)

        # Pin host memory so the async H2D copy can overlap with compute.
        query_start_loc = query_start_loc_cpu.pin_memory().to(
            self.device, non_blocking=True)

        # NOTE(review): seq_lens is forwarded as a host (CPU) tensor —
        # confirm downstream NPU ops accept host tensors for context lengths.
        return AscendMetadata(
            num_actual_tokens=num_actual_tokens,
            num_decode_tokens=num_decode_tokens,
            block_tables=block_table,
            query_start_loc=query_start_loc,
            seq_lens=seq_lens,
            seq_lens_list=seq_lens.tolist() if hasattr(seq_lens, 'tolist') else list(seq_lens),
            max_query_len=common_attn_metadata.max_query_len,
            actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
            slot_mapping=slot_mapping,
            attn_mask=attn_mask,
            attn_state=attn_state,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            causal=getattr(common_attn_metadata, 'causal', True),
            model_runner_type=self.model_config.runner_type,
        )

    def _determine_attn_state(
        self,
        num_decodes: int,
        num_prefills: int,
        num_decode_tokens: int,
        num_prefill_tokens: int,
    ) -> AscendAttentionState:
        """Determine attention state based on batch composition."""
        if num_prefills == 0:
            return AscendAttentionState.DecodeOnly
        elif num_decodes == 0 and num_prefill_tokens > 0:
            # Pure prefill - check if cache hit or no cache
            # For simplicity, use ChunkedPrefill as default
            return AscendAttentionState.ChunkedPrefill
        else:
            # Mixed decode and prefill
            return AscendAttentionState.ChunkedPrefill

    def _split_decodes_and_prefills(self, common_attn_metadata):
        """Split batch into decode and prefill requests.

        Returns:
            (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
        """
        max_query_len = common_attn_metadata.max_query_len
        num_reqs = common_attn_metadata.num_reqs
        num_tokens = common_attn_metadata.num_actual_tokens
        query_start_loc = common_attn_metadata.query_start_loc_cpu

        # Fast path: every request's query fits under the decode threshold,
        # so the whole batch is decodes.
        if max_query_len <= self.decode_threshold:
            return num_reqs, 0, num_tokens, 0

        query_lens = query_start_loc[1:] - query_start_loc[:-1]
        is_prefill = query_lens > self.decode_threshold
        if not torch.any(is_prefill):
            return num_reqs, 0, num_tokens, 0

        # argmax of the boolean mask yields the index of the first prefill;
        # everything before it is treated as decodes (assumes decode-first
        # ordering of the batch).
        first_prefill = is_prefill.int().argmax(dim=-1).item()
        num_decodes = first_prefill
        num_prefills = num_reqs - num_decodes
        num_decode_tokens = query_start_loc[first_prefill].item()
        num_prefill_tokens = num_tokens - num_decode_tokens
        return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)

    def build_for_cudagraph_capture(
        self,
        common_attn_metadata,
        model: Optional[nn.Module] = None,
    ):
        """Build metadata for CUDA graph capture (ACL graph on Ascend)."""
        return self.build_for_graph_capture(
            common_attn_metadata,
            attn_state=AscendAttentionState.DecodeOnly,
            model=model,
        )

    def build_for_graph_capture(
        self,
        common_attn_metadata,
        attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
        model: Optional[nn.Module] = None,
    ):
        """Build metadata for graph capture.

        Raises:
            NotImplementedError: for any state other than DecodeOnly.
        """
        if attn_state == AscendAttentionState.DecodeOnly:
            attn_metadata = self.build(
                common_prefix_len=0,
                common_attn_metadata=common_attn_metadata,
            )
        else:
            raise NotImplementedError(
                "Currently only support building dummy metadata for DecodeOnly state"
            )

        attn_metadata.attn_state = attn_state
        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """
        Cascade attention is not supported for Ascend backend.

        Cascade attention is a CUDA-specific optimization that splits
        attention computation for shared prefixes. Ascend NPU uses
        different optimizations.
        """
        return False
+ """ + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "ASCEND_FL" + + @staticmethod + def get_impl_cls() -> Type["AscendAttentionBackendImpl"]: + return AscendAttentionBackendImpl + + @staticmethod + def get_builder_cls() -> Type["AscendAttentionMetadataBuilder"]: + return AscendAttentionMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: List[torch.Tensor], + dst_kv_cache: List[torch.Tensor], + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1] + dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1] + src_indices = src_to_dst[:, 0] + dst_indices = src_to_dst[:, 1] + + dst_key_cache[dst_indices] = src_key_cache[src_indices].to( + dst_key_cache.device) + dst_value_cache[dst_indices] = src_value_cache[src_indices].to( + dst_key_cache.device) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + src_indices = src_to_dists[:, 0] + dst_indices = src_to_dists[:, 1] + + for kv_cache in kv_caches: + key_caches = kv_cache[0] + value_caches = kv_cache[1] + key_caches[dst_indices] = key_caches[src_indices] + value_caches[dst_indices] = value_caches[src_indices] + + @staticmethod + def get_supported_block_size() -> list[int]: + return [128] + + +class AscendAttentionBackendImpl(AttentionImpl): + """ + Ascend attention implementation using native torch_npu operators. 
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        **kwargs,
    ) -> None:
        """Initialize the Ascend attention implementation.

        Args:
            num_heads: Number of query heads.
            head_size: Dimension of each head.
            scale: Softmax scale factor.
            num_kv_heads: Number of KV heads (None falls back to num_heads).
            alibi_slopes: Optional ALiBi slopes, moved to the NPU as float32.
            sliding_window: Optional sliding-window size.
            kv_cache_dtype: Declared KV cache dtype string.
            logits_soft_cap: Accepted for interface compatibility; not used
                in this implementation.
            attn_type: AttentionType value (decoder / encoder-only supported).
            kv_sharing_target_layer_name: Accepted for interface
                compatibility; not used in this implementation.

        Raises:
            RuntimeError: If torch_npu is not installed.
        """
        if not _TORCH_NPU_AVAILABLE:
            raise RuntimeError(
                "torch_npu is required for Ascend attention backend. "
                "Please install it with: pip install torch_npu"
            )

        self.vllm_config = get_current_vllm_config()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.hidden_size = self.num_heads * self.head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = sliding_window

        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(
                alibi_slopes,
                dtype=torch.float32,
                device="npu"
            )
        self.alibi_slopes = alibi_slopes
        self.attn_type = attn_type

        # GQA: query heads must evenly divide onto KV heads.
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        # Lazily-bound references to the layer's KV cache tensors; set on the
        # first forward call in reshape_and_cache().
        self.key_cache = None
        self.value_cache = None

    def _get_fia_params(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
    ):
        """Get parameters for fused_infer_attention.

        Selects, per attention state, whether K/V come from the fresh tensors
        or from the paged KV cache, plus the matching block table and KV
        sequence lengths.

        Returns:
            (key, value, block_size, block_table, actual_seq_lengths_kv)
        """
        # NOTE(review): assumes self.key_cache has shape
        # (num_blocks, block_size, num_kv_heads, head_size), matching
        # get_kv_cache_shape() — confirm against cache allocation.
        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            block_size = 128
            block_table = None
            actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
        elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
            batch_size = attn_metadata.seq_lens.shape[0]
            block_table = attn_metadata.block_tables[:batch_size, :]
            num_block, block_size, _, _ = self.key_cache.shape
            # Flatten (kv_heads, head_size) into one trailing dim as the
            # fused kernel expects.
            key = self.key_cache.view(num_block, block_size, -1)
            value = self.value_cache.view(num_block, block_size, -1)
            actual_seq_lengths_kv = attn_metadata.seq_lens_list
        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
            num_block, block_size, _, _ = self.key_cache.shape
            key = self.key_cache.view(num_block, block_size, -1)
            value = self.value_cache.view(num_block, block_size, -1)
            block_table = attn_metadata.block_tables
            actual_seq_lengths_kv = attn_metadata.seq_lens_list
        else:
            # ChunkedPrefill
            num_block, block_size, _, _ = self.key_cache.shape
            key = self.key_cache.view(num_block, block_size, -1)
            value = self.value_cache.view(num_block, block_size, -1)
            block_table = attn_metadata.block_tables
            actual_seq_lengths_kv = attn_metadata.seq_lens_list

        return key, value, block_size, block_table, actual_seq_lengths_kv

    def reshape_and_cache(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
    ):
        """Reshape and cache key/value tensors.

        Binds self.key_cache / self.value_cache on first use, then scatters
        the new K/V tokens into the paged cache via torch_npu.
        """
        if len(kv_cache) > 1:
            # Bind once; subsequent calls reuse the cached references.
            # NOTE(review): if the runner ever swaps the kv_cache tensors,
            # these references would go stale — confirm caches are stable.
            if self.key_cache is None:
                self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
            slots = attn_metadata.slot_mapping
            # torch_npu requires int32 for slot_indices
            # TODO(yxa): block_table.py: CUDA uses int64, NPU uses int32.
            if slots.dtype != torch.int32:
                slots = slots.to(torch.int32)
            # Use torch_npu reshape_and_cache
            torch_npu._npu_reshape_and_cache(
                key=key[:attn_metadata.num_actual_tokens],
                value=value[:attn_metadata.num_actual_tokens],
                key_cache=self.key_cache,
                value_cache=self.value_cache,
                slot_indices=slots[:attn_metadata.num_actual_tokens]
            )
        return key, value

    def forward_fused_infer_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass using fused_infer_attention_score.

        Writes the attention result into ``output`` in place and returns it.
        """
        key, value, block_size, block_table, actual_seq_lengths_kv = \
            self._get_fia_params(key, value, attn_metadata)

        # Last cumulative query length == total number of scheduled tokens.
        num_tokens = attn_metadata.actual_seq_lengths_q[-1]
        query = query[:num_tokens]

        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            key = key[:num_tokens]
            value = value[:num_tokens]

        # Determine sparse_mode based on mask availability
        # sparse_mode=3 requires attn_mask; sparse_mode=0 does not
        # sparse_mode = 3 if attn_metadata.attn_mask is not None else 0

        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
            query=query,
            key=key,
            value=value,
            atten_mask=attn_metadata.attn_mask,
            block_table=block_table,
            input_layout="TND",
            block_size=block_size,
            actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
            actual_seq_lengths_kv=actual_seq_lengths_kv,
            num_key_value_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale=self.scale,
            sparse_mode=3,
        )

        attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
        output[:num_tokens] = attn_output[:num_tokens]
        return output

    def forward_paged_attention(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass using paged attention for decode.

        K/V are read entirely from the bound paged caches; the result is
        written into ``output`` via the kernel's ``out`` argument.
        """
        torch_npu._npu_paged_attention(
            query=query,
            key_cache=self.key_cache,
            value_cache=self.value_cache,
            num_kv_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale_value=self.scale,
            block_table=attn_metadata.block_tables,
            context_lens=attn_metadata.seq_lens,
            out=output
        )
        return output

    def _forward_encoder_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass for encoder-only attention.

        Uses npu_fusion_attention with a causal compressed mask
        (sparse_mode=3) when metadata is causal, otherwise a dense pass.
        """
        assert attn_metadata is not None

        if attn_metadata.causal:
            # Use sparse_mode 3 in causal scenario
            return torch_npu.npu_fusion_attention(
                query=query,
                key=key,
                value=value,
                head_num=self.num_heads,
                input_layout="TND",
                scale=self.scale,
                sparse_mode=3,
                atten_mask=attn_metadata.attn_mask,
                actual_seq_qlen=attn_metadata.actual_seq_lengths_q,
                actual_seq_kvlen=attn_metadata.actual_seq_lengths_q,
            )[0]
        else:
            # Use default sparse_mode 0 in normal scenario
            return torch_npu.npu_fusion_attention(
                query=query,
                key=key,
                value=value,
                head_num=self.num_heads,
                input_layout="TND",
                scale=self.scale,
                actual_seq_qlen=attn_metadata.actual_seq_lengths_q,
                actual_seq_kvlen=attn_metadata.actual_seq_lengths_q,
            )[0]

    def forward_impl(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ):
        """Forward implementation dispatching to appropriate attention method."""
        num_tokens = query.shape[0]  # currently unused in this method

        # Use paged attention for decode-only state
        # (sliding-window decode still goes through the fused path).
        if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly
                and self.sliding_window is None):
            output = self.forward_paged_attention(query, attn_metadata, output)
        else:
            output = self.forward_fused_infer_attention(
                query, key, value, attn_metadata, output)

        return output

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass with Ascend attention.

        Args:
            layer: AttentionLayer containing scale factors
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention
            output: Pre-allocated output tensor
            output_scale: Optional output quantization scale
            output_block_scale: Optional output block quantization scale

        Returns:
            Output tensor of shape [num_tokens, num_heads * head_size]

        Raises:
            NotImplementedError: For fused output quantization or
                cross-attention, neither of which is supported here.
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "Fused output quantization is not yet supported "
                "for AscendAttentionBackendImpl"
            )

        # Quantized KV caches are not supported: scales must be identity.
        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0

        attn_type = self.attn_type
        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_ONLY]:
            raise NotImplementedError(
                "Encoder/Decoder cross-attention is not implemented for "
                "AscendAttentionBackendImpl"
            )

        num_tokens = query.shape[0]
        # No metadata means a profiling/dummy run: return zeros.
        if attn_metadata is None:
            return output.fill_(0)

        # Reshape and cache KV
        key, value = self.reshape_and_cache(key, value, kv_cache, attn_metadata)

        # Handle pooling model branch (encoder attention)
        if attn_metadata.model_runner_type == "pooling":
            attn_output = self._forward_encoder_attention(
                query, key, value, attn_metadata, output)
            output[:num_tokens] = attn_output[:num_tokens]
            return output

        # Standard forward
        output = self.forward_impl(
            query, key, value, kv_cache, attn_metadata, output)
        return output
class AscendMLABackend:
    """
    Ascend MLA (Multi-head Latent Attention) backend placeholder.

    This is a minimal implementation. Full MLA support would require
    additional implementation based on the specific MLA algorithm.
    """

    def __init__(self, *args, **kwargs):
        raise NotImplementedError(
            "Ascend MLA attention backend is not yet fully implemented. "
            "Please use standard attention backend by setting use_mla=False"
        )


class AttentionMaskBuilder:
    """
    Builder for creating and caching attention masks.

    This class manages attention mask creation with caching to avoid
    redundant tensor allocations. The masks are created lazily and
    cached at the class level for reuse across instances.

    Attributes:
        device: The device to create masks on.
    """

    # Class-level cache for masks (shared across all instances).
    # BUG FIX: each cached mask is now keyed by *both* the parameters that
    # affect its value and the device it lives on. Previously the MLA / PCP
    # masks were keyed only on dtype, so a builder on a different device
    # would receive a stale mask living on the wrong device.
    _chunked_prefill_mask: ClassVar[Optional[torch.Tensor]] = None
    _chunked_prefill_mask_device: ClassVar[Optional[torch.device]] = None
    _mla_mask: ClassVar[Optional[torch.Tensor]] = None
    _mla_mask_dtype: ClassVar[Optional[torch.dtype]] = None
    _mla_mask_device: ClassVar[Optional[torch.device]] = None
    _pcp_mla_mask: ClassVar[Optional[torch.Tensor]] = None
    _pcp_mla_mask_dtype: ClassVar[Optional[torch.dtype]] = None
    _pcp_mla_mask_device: ClassVar[Optional[torch.device]] = None

    def __init__(self, device: torch.device):
        """
        Initialize the attention mask builder.

        Args:
            device: The device to create masks on.
        """
        self.device = device

    def get_splitfuse_attn_mask(self) -> torch.Tensor:
        """
        Get attention mask for split-fuse (chunked prefill) attention.

        Creates a 2048x2048 upper triangular mask with int8 dtype.
        The mask is cached and reused for subsequent calls.

        Returns:
            Upper triangular attention mask tensor.
        """
        cls = AttentionMaskBuilder
        if (cls._chunked_prefill_mask is None or
                cls._chunked_prefill_mask_device != self.device):
            cls._chunked_prefill_mask = torch.triu(
                torch.ones(2048, 2048), diagonal=1
            ).to(torch.int8).to(self.device)
            cls._chunked_prefill_mask_device = self.device
        return cls._chunked_prefill_mask

    def get_mla_mask(self, dtype: torch.dtype) -> torch.Tensor:
        """
        Get attention mask for MLA (Multi-head Latent Attention).

        Creates a 512x512 upper triangular mask. For fp16, uses
        float32 min value as mask; otherwise uses 1.

        Args:
            dtype: The dtype for the mask tensor.

        Returns:
            MLA attention mask tensor.
        """
        cls = AttentionMaskBuilder
        if (cls._mla_mask is None or cls._mla_mask_dtype != dtype
                or cls._mla_mask_device != self.device):
            if dtype == torch.float16:
                mask_value = torch.finfo(torch.float32).min
            else:
                mask_value = 1
            prefill_mask = torch.triu(
                torch.ones(512, 512, device=self.device, dtype=dtype), 1
            )
            cls._mla_mask = torch.where(prefill_mask == 1, mask_value, 0).to(dtype)
            cls._mla_mask_dtype = dtype
            cls._mla_mask_device = self.device
        return cls._mla_mask

    def get_pcp_mla_mask(self, dtype: torch.dtype) -> torch.Tensor:
        """
        Get attention mask for PCP (Prefill Context Parallel) MLA.

        Creates a 512x512 upper triangular mask.

        Args:
            dtype: The dtype for the mask tensor.

        Returns:
            PCP MLA attention mask tensor.
        """
        cls = AttentionMaskBuilder
        if (cls._pcp_mla_mask is None or cls._pcp_mla_mask_dtype != dtype
                or cls._pcp_mla_mask_device != self.device):
            cls._pcp_mla_mask = torch.triu(
                torch.ones(512, 512, device=self.device, dtype=dtype), 1
            )
            cls._pcp_mla_mask_dtype = dtype
            cls._pcp_mla_mask_device = self.device
        return cls._pcp_mla_mask

    def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype) -> torch.Tensor:
        """
        Get a general attention mask for given sequence length.

        Creates a causal mask (lower triangular) for the given sequence length.
        This mask is NOT cached (max_seq_len varies per call).

        Args:
            max_seq_len: Maximum sequence length.
            dtype: The dtype for the mask tensor.

        Returns:
            Causal attention mask tensor.
        """
        # Create lower triangle matrix (True for valid positions)
        mask_flag = torch.ones(
            (max_seq_len, max_seq_len), dtype=torch.bool
        ).tril_()
        # Invert to get mask positions (True for masked positions)
        mask_flag = ~mask_flag
        # For fp16, use -inf; otherwise use 1
        mask_value = float('-inf') if dtype == torch.float16 else 1
        attn_mask = torch.zeros(
            size=(max_seq_len, max_seq_len), dtype=dtype
        ).masked_fill_(mask_flag, mask_value)
        return attn_mask.to(self.device)

    @classmethod
    def clear_cache(cls) -> None:
        """Clear all cached masks. Useful for testing or memory cleanup."""
        cls._chunked_prefill_mask = None
        cls._chunked_prefill_mask_device = None
        cls._mla_mask = None
        cls._mla_mask_dtype = None
        cls._mla_mask_device = None
        cls._pcp_mla_mask = None
        cls._pcp_mla_mask_dtype = None
        cls._pcp_mla_mask_device = None


# Global instance cache for convenience
_builder_instance: Optional[AttentionMaskBuilder] = None
_builder_device: Optional[torch.device] = None


def get_attention_mask_builder(device: torch.device) -> AttentionMaskBuilder:
    """
    Get or create a global AttentionMaskBuilder instance.

    This function provides a convenient way to access the mask builder
    without managing instance lifecycle.

    Args:
        device: The device for the mask builder.

    Returns:
        AttentionMaskBuilder instance.
    """
    global _builder_instance, _builder_device
    if _builder_instance is None or _builder_device != device:
        _builder_instance = AttentionMaskBuilder(device)
        _builder_device = device
    return _builder_instance
+ + Args: + use_mla: Whether to use Multi-head Latent Attention (MLA) + + Returns: + Fully qualified class path string + """ + if use_mla: + return "vllm_fl.attention.mla.MLAFLBackend" + return "vllm_fl.attention.attention.AttentionFLBackend" diff --git a/vllm_fl/dispatch/backends/flaggems/register_ops.py b/vllm_fl/dispatch/backends/flaggems/register_ops.py index 3b4f3ed3..0de0307a 100644 --- a/vllm_fl/dispatch/backends/flaggems/register_ops.py +++ b/vllm_fl/dispatch/backends/flaggems/register_ops.py @@ -64,6 +64,15 @@ def register_builtins(registry) -> None: vendor=None, priority=BackendPriority.DEFAULT, ), + # Attention Backend + OpImpl( + op_name="attention_backend", + impl_id="default.flaggems", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.attention_backend, is_avail), + vendor=None, + priority=BackendPriority.DEFAULT, + ), ] registry.register_many(impls) diff --git a/vllm_fl/dispatch/backends/reference/reference.py b/vllm_fl/dispatch/backends/reference/reference.py index e32455f2..b1327f8b 100644 --- a/vllm_fl/dispatch/backends/reference/reference.py +++ b/vllm_fl/dispatch/backends/reference/reference.py @@ -117,3 +117,23 @@ def rotary_embedding( rotary_interleaved=rotary_interleaved, inplace=inplace, ) + + def attention_backend(self, use_mla: bool = False) -> str: + """ + Get the attention backend class path for reference (vLLM native). + + This method returns the vLLM native flash attention backend path, + which serves as a fallback implementation. 
+ + Args: + use_mla: Whether to use Multi-head Latent Attention (MLA) + + Returns: + Fully qualified class path string (vLLM native backend) + """ + # Return vLLM's native flash attention backend as reference + from vllm.attention.backends.registry import AttentionBackendEnum + if use_mla: + # vLLM native MLA backend + return AttentionBackendEnum.MLA.get_path() + return AttentionBackendEnum.FLASH_ATTN.get_path() diff --git a/vllm_fl/dispatch/backends/reference/register_ops.py b/vllm_fl/dispatch/backends/reference/register_ops.py index e2af7993..fd675b99 100644 --- a/vllm_fl/dispatch/backends/reference/register_ops.py +++ b/vllm_fl/dispatch/backends/reference/register_ops.py @@ -64,6 +64,15 @@ def register_builtins(registry) -> None: vendor=None, priority=BackendPriority.REFERENCE, ), + # Attention Backend + OpImpl( + op_name="attention_backend", + impl_id="reference.torch", + kind=BackendImplKind.REFERENCE, + fn=_bind_is_available(backend.attention_backend, is_avail), + vendor=None, + priority=BackendPriority.REFERENCE, + ), ] registry.register_many(impls) diff --git a/vllm_fl/dispatch/backends/vendor/__init__.py b/vllm_fl/dispatch/backends/vendor/__init__.py index 38f6f1af..14260149 100644 --- a/vllm_fl/dispatch/backends/vendor/__init__.py +++ b/vllm_fl/dispatch/backends/vendor/__init__.py @@ -26,14 +26,15 @@ except ImportError: pass +# Import CUDA backend +try: + from .cuda import CudaBackend + __all__.append("CudaBackend") +except ImportError: + pass + # Add more vendor backends here as they become available: # try: -# from .cuda import CudaBackend -# __all__.append("CudaBackend") -# except ImportError: -# pass -# -# try: # from .rocm import RocmBackend # __all__.append("RocmBackend") # except ImportError: diff --git a/vllm_fl/dispatch/backends/vendor/ascend/ascend.py b/vllm_fl/dispatch/backends/vendor/ascend/ascend.py index aab90136..8801e8f1 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/ascend.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/ascend.py 
@@ -124,3 +124,24 @@ def rotary_embedding( rotary_interleaved=rotary_interleaved, inplace=inplace, ) + + def attention_backend(self, use_mla: bool = False) -> str: + """ + Get the attention backend class path for Ascend NPU. + + This method returns the native Ascend attention backend that uses + torch_npu operators (npu_fused_infer_attention_score, etc.) + instead of flag_gems operators. + + Uses vllm_fl's native Ascend implementation which directly calls + torch_npu operators without depending on vllm-ascend package. + + Args: + use_mla: Whether to use Multi-head Latent Attention (MLA) + + Returns: + Fully qualified class path string + """ + if use_mla: + return "vllm_fl.attention.backends.ascend.AscendMLABackend" + return "vllm_fl.attention.backends.ascend.AscendAttentionBackend" diff --git a/vllm_fl/dispatch/backends/vendor/ascend/register_ops.py b/vllm_fl/dispatch/backends/vendor/ascend/register_ops.py index b22a941a..04cad41d 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/register_ops.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/register_ops.py @@ -64,6 +64,15 @@ def register_builtins(registry) -> None: vendor="ascend", priority=BackendPriority.VENDOR, ), + # Attention Backend + OpImpl( + op_name="attention_backend", + impl_id="vendor.ascend", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.attention_backend, is_avail), + vendor="ascend", + priority=BackendPriority.VENDOR, + ), ] registry.register_many(impls) diff --git a/vllm_fl/dispatch/backends/vendor/cuda/__init__.py b/vllm_fl/dispatch/backends/vendor/cuda/__init__.py new file mode 100644 index 00000000..692a74d0 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/cuda/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +CUDA backend for vllm-plugin-FL dispatch. 
+""" + +from .cuda import CudaBackend + +__all__ = ["CudaBackend"] diff --git a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py new file mode 100644 index 00000000..fdc341a0 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py @@ -0,0 +1,125 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +CUDA backend implementation. + +This backend provides operator implementations for NVIDIA CUDA GPUs. +""" + +from __future__ import annotations + +import os +from typing import Optional, Union + +import torch + +from ...base import Backend + + +class CudaBackend(Backend): + """ + CUDA backend for operator implementations. + + This backend uses CUDA libraries to provide high-performance + operator implementations for NVIDIA GPUs. + """ + + _available: Optional[bool] = None + + @property + def name(self) -> str: + return "cuda" + + @property + def vendor(self) -> Optional[str]: + return "cuda" + + def is_available(self) -> bool: + """Check if CUDA hardware and libraries are available.""" + if CudaBackend._available is None: + try: + # Check if CUDA device is available + if torch.cuda.is_available() and torch.cuda.device_count() > 0: + CudaBackend._available = True + else: + CudaBackend._available = False + except Exception: + CudaBackend._available = False + return CudaBackend._available + + # ==================== Operator Implementations ==================== + + def silu_and_mul(self, x: torch.Tensor) -> torch.Tensor: + """ + SiLU activation followed by element-wise multiplication. + + Uses vLLM's native CUDA implementation. + """ + from .impl.activation import silu_and_mul_cuda + + return silu_and_mul_cuda(x) + + def rmsnorm( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor], + weight: torch.Tensor, + epsilon: float, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """ + RMS normalization using vLLM's CUDA implementation. 
+ """ + from .impl.normalization import rmsnorm_cuda + + return rmsnorm_cuda(x, residual, weight, epsilon) + + def rotary_embedding( + self, + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor, + rotary_interleaved: bool = False, + inplace: bool = True, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embedding using vLLM's CUDA implementation. + """ + from .impl.rotary import rotary_embedding_cuda + + return rotary_embedding_cuda( + query, + key, + cos, + sin, + position_ids, + rotary_interleaved=rotary_interleaved, + inplace=inplace, + ) + + def attention_backend(self, use_mla: bool = False) -> str: + """ + Get the attention backend class path for CUDA. + + Supports: + - FLASH_ATTN (default) + - TRITON_ATTN (when USE_FLAGGEMS=1) + + Args: + use_mla: Whether to use Multi-head Latent Attention (MLA) + + Returns: + Fully qualified class path string + """ + from vllm.attention.backends.registry import AttentionBackendEnum + + if use_mla: + return AttentionBackendEnum.MLA.get_path() + + # Check for TRITON_ATTN preference via environment variable + if os.environ.get("USE_FLAGGEMS", "0") == "1": + return AttentionBackendEnum.TRITON_ATTN.get_path() + + # Default to FLASH_ATTN + return AttentionBackendEnum.FLASH_ATTN.get_path() diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/__init__.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/__init__.py new file mode 100644 index 00000000..24e6f56d --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +CUDA operator implementations. 
+""" + +from .activation import silu_and_mul_cuda +from .normalization import rmsnorm_cuda +from .rotary import rotary_embedding_cuda + +__all__ = [ + "silu_and_mul_cuda", + "rmsnorm_cuda", + "rotary_embedding_cuda", +] diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py new file mode 100644 index 00000000..c6648bfa --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py @@ -0,0 +1,37 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +CUDA activation operator implementations. +""" + +from __future__ import annotations + +import torch + + +def silu_and_mul_cuda(x: torch.Tensor) -> torch.Tensor: + """ + SiLU activation followed by element-wise multiplication using CUDA. + + Uses vLLM's optimized CUDA kernel when available. + + Args: + x: Input tensor of shape [..., 2*d] + + Returns: + Output tensor of shape [..., d] + """ + d = x.shape[-1] // 2 + out = torch.empty(*x.shape[:-1], d, dtype=x.dtype, device=x.device) + + try: + from vllm._custom_ops import silu_and_mul as vllm_silu_and_mul + vllm_silu_and_mul(out, x) + except ImportError: + # Fallback to standard PyTorch + x1 = x[..., :d] + x2 = x[..., d:] + silu_out = torch.nn.functional.silu(x1) + out = silu_out * x2 + + return out diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py new file mode 100644 index 00000000..b4de3114 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py @@ -0,0 +1,62 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +CUDA normalization operator implementations. 
+""" + +from __future__ import annotations + +from typing import Optional, Union + +import torch + + +def rmsnorm_cuda( + x: torch.Tensor, + residual: Optional[torch.Tensor], + weight: torch.Tensor, + epsilon: float, +) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """ + RMS normalization using CUDA. + + Uses vLLM's optimized CUDA kernel when available. + + Args: + x: Input tensor + residual: Optional residual tensor to add before normalization + weight: Normalization weight + epsilon: Small constant for numerical stability + + Returns: + Normalized tensor, or tuple of (normalized, residual) if residual is provided + """ + try: + from vllm._custom_ops import rms_norm as vllm_rms_norm + from vllm._custom_ops import fused_add_rms_norm as vllm_fused_add_rms_norm + + if residual is not None: + vllm_fused_add_rms_norm(x, residual, weight, epsilon) + return x, residual + else: + out = torch.empty_like(x) + vllm_rms_norm(out, x, weight, epsilon) + return out + + except ImportError: + # Fallback to standard PyTorch + orig_dtype = x.dtype + x = x.float() + + if residual is not None: + residual = residual.float() + x = x + residual + + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + epsilon) + x = x * weight + x = x.to(orig_dtype) + + if residual is not None: + return x, residual.to(orig_dtype) + return x diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py new file mode 100644 index 00000000..a54c5529 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py @@ -0,0 +1,95 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +CUDA rotary embedding operator implementations. 
+""" + +from __future__ import annotations + +import torch + + +def rotary_embedding_cuda( + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor, + rotary_interleaved: bool = False, + inplace: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embedding using CUDA. + + Uses vLLM's optimized CUDA kernel when available. + + Args: + query: Query tensor + key: Key tensor + cos: Cosine cache + sin: Sine cache + position_ids: Position indices + rotary_interleaved: Whether to use interleaved rotary + inplace: Whether to modify tensors in-place + + Returns: + Tuple of (embedded_query, embedded_key) + """ + try: + from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding + + # vLLM's rotary_embedding modifies tensors in-place + vllm_rotary_embedding( + position_ids, + query, + key, + cos, + sin, + rotary_interleaved, + ) + return query, key + + except ImportError: + # Fallback to standard PyTorch implementation + if position_ids.dim() == 1: + cos_selected = cos[position_ids] + sin_selected = sin[position_ids] + else: + cos_selected = cos[position_ids] + sin_selected = sin[position_ids] + + # Expand dimensions to match query/key shape + if query.dim() == 4: + cos_selected = cos_selected.unsqueeze(1) + sin_selected = sin_selected.unsqueeze(1) + elif query.dim() == 3: + cos_selected = cos_selected.unsqueeze(1) + sin_selected = sin_selected.unsqueeze(1) + + # Check if we need to repeat cos/sin to match head_dim + rotary_dim = cos_selected.shape[-1] + head_dim = query.shape[-1] + + if rotary_dim != head_dim: + cos_selected = torch.cat([cos_selected, cos_selected], dim=-1) + sin_selected = torch.cat([sin_selected, sin_selected], dim=-1) + + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + if rotary_interleaved: + def rotate_interleaved(x): + x1 
= x[..., ::2] + x2 = x[..., 1::2] + return torch.stack((-x2, x1), dim=-1).flatten(-2) + + q_embed = (query * cos_selected) + (rotate_interleaved(query) * sin_selected) + k_embed = (key * cos_selected) + (rotate_interleaved(key) * sin_selected) + else: + q_embed = (query * cos_selected) + (rotate_half(query) * sin_selected) + k_embed = (key * cos_selected) + (rotate_half(key) * sin_selected) + + return q_embed, k_embed diff --git a/vllm_fl/dispatch/backends/vendor/cuda/register_ops.py b/vllm_fl/dispatch/backends/vendor/cuda/register_ops.py new file mode 100644 index 00000000..f4c2d269 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/cuda/register_ops.py @@ -0,0 +1,78 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +CUDA backend operator registrations. + +This module registers all VENDOR (CUDA) implementations. +""" + +from __future__ import annotations + +import functools + +from ....types import OpImpl, BackendImplKind, BackendPriority + + +def _bind_is_available(fn, is_available_fn): + """Wrap a function and bind _is_available attribute for OpImpl.is_available() check.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + wrapper._is_available = is_available_fn + return wrapper + + +def register_builtins(registry) -> None: + """ + Register all CUDA (VENDOR) operator implementations. 
+ + Args: + registry: Registry to register into + """ + from .cuda import CudaBackend + + backend = CudaBackend() + is_avail = backend.is_available + + impls = [ + # Activation + OpImpl( + op_name="silu_and_mul", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.silu_and_mul, is_avail), + vendor="cuda", + priority=BackendPriority.VENDOR, + ), + # Normalization + OpImpl( + op_name="rmsnorm", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rmsnorm, is_avail), + vendor="cuda", + priority=BackendPriority.VENDOR, + ), + # Rotary Embedding + OpImpl( + op_name="rotary_embedding", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rotary_embedding, is_avail), + vendor="cuda", + priority=BackendPriority.VENDOR, + ), + # Attention Backend + OpImpl( + op_name="attention_backend", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.attention_backend, is_avail), + vendor="cuda", + priority=BackendPriority.VENDOR, + ), + ] + + registry.register_many(impls) diff --git a/vllm_fl/dispatch/ops.py b/vllm_fl/dispatch/ops.py index b2b87389..eacab390 100644 --- a/vllm_fl/dispatch/ops.py +++ b/vllm_fl/dispatch/ops.py @@ -125,3 +125,23 @@ def rotary_embedding( Tuple of (embedded_query, embedded_key) """ pass + + # ==================== Attention Backend ==================== + + @abstractmethod + def attention_backend(self, use_mla: bool = False) -> str: + """ + Get the attention backend class path for this platform. + + This method returns the fully qualified class path of the attention + backend implementation suitable for the current hardware platform. 
+ + Args: + use_mla: Whether to use Multi-head Latent Attention (MLA) + + Returns: + Fully qualified class path string, e.g.: + - "vllm_fl.attention.backends.ascend.AscendAttentionBackend" + - "vllm_fl.attention.attention.AttentionFLBackend" + """ + pass diff --git a/vllm_fl/platform.py b/vllm_fl/platform.py index f1436a49..9d82c014 100644 --- a/vllm_fl/platform.py +++ b/vllm_fl/platform.py @@ -122,7 +122,13 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: cache_config = vllm_config.cache_config if cache_config and cache_config.block_size is None: - cache_config.block_size = 16 + # Ascend NPU requires block_size to be a multiple of 128 + # CUDA can use smaller block sizes like 16 + if cls.device_type == "npu": + cache_config.block_size = 128 + logger.info("Setting kv cache block size to 128 for Ascend NPU.") + else: + cache_config.block_size = 16 # TODO(lucas): handle this more gracefully # Note: model_config may be None during testing @@ -168,43 +174,70 @@ def get_attn_backend_cls( selected_backend: "AttentionBackendEnum", attn_selector_config: "AttentionSelectorConfig", ) -> list[str]: - # from vllm_fl.attention.custom_attention import register_attention - # register_attention() - device_capability = cls.get_device_capability() - - if selected_backend is None: - backend = _get_backend( - use_mla=False, - device_info=cls.device_info, - )[0] # get the highest priority backend - else: - backend = selected_backend - - backend_class = backend.get_class() - invalid_reasons = backend_class.validate_configuration( - device_capability=device_capability, - **attn_selector_config._asdict(), - ) - reasons_str = ( - "{" - + ", ".join( - f"{backend.name}: [{', '.join(invalid_reasons)}]" + """ + Get the attention backend class path using the dispatch mechanism. + + The dispatch mechanism automatically selects the appropriate backend + based on: + 1. Hardware availability (NPU, CUDA, etc.) + 2. 
User configuration via environment variables: + - VLLM_FL_PREFER: "vendor" | "reference" | "flaggems" + - VLLM_FL_ALLOW_VENDORS: Comma-separated list of allowed vendors + - VLLM_FL_DENY_VENDORS: Comma-separated list of denied vendors + + Backend selection order (default): + - flaggems: FlagGems attention (may have issues) + - vendor: Platform-specific implementation (e.g., Ascend NPU) + - reference: vLLM native attention (PyTorch-based) + + For NPU users, recommend setting VLLM_FL_PREFER=vendor to use + Ascend-optimized attention backend. + """ + from vllm_fl.dispatch import call_op + + use_mla = attn_selector_config.use_mla + + try: + # Use dispatch mechanism to select the appropriate attention backend + # The dispatch system will automatically choose based on: + # 1. Backend availability (is_available check) + # 2. Policy configuration (VLLM_FL_PREFER, etc.) + # 3. Priority ordering (vendor > reference when vendor is preferred) + backend_path = call_op("attention_backend", use_mla=use_mla) + + logger.info_once( + "Using attention backend via dispatch (use_mla=%s): %s", + use_mla, backend_path, + scope="local", + ) + return backend_path + + except RuntimeError as e: + # Fallback: if dispatch fails, use device-type based selection + logger.warning( + "Dispatch mechanism failed for attention_backend, " + "falling back to device-type based selection: %s", e + ) + + if cls.device_type == "npu": + if use_mla: + backend_path = "vllm_fl.attention.mla.MLAFLBackend" + else: + backend_path = "vllm_fl.attention.attention.AttentionFLBackend" + else: + # For CUDA and other devices, use vLLM native backend + from vllm.attention.backends.registry import AttentionBackendEnum + if use_mla: + backend_path = AttentionBackendEnum.MLA.get_path() + else: + backend_path = AttentionBackendEnum.FLASH_ATTN.get_path() + + logger.info_once( + "Using fallback attention backend (use_mla=%s): %s", + use_mla, backend_path, + scope="local", ) - + "}" - ) - config_str = 
attn_selector_config.__repr__() - logger.debug_once( - f"Some attention backends are not valid for {cls.device_name} with " - f"{config_str}. Reasons: {reasons_str}." - ) - - logger.info_once( - "Using %s attention backend out of potential backends: %s", - backend.name, - tuple(backend.name), - scope="local", - ) - return backend.get_path() + return backend_path @classmethod def get_vit_attn_backend( @@ -287,6 +320,10 @@ def use_custom_allreduce(cls) -> bool: @classmethod def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + # TODO(yxa): For NPU/Ascend devices, return None (no capability version like CUDA) + if cls.device_type == "npu": + return None + # For CUDA devices major, minor = torch.cuda.get_device_capability(device_id) return DeviceCapability(major=major, minor=minor) diff --git a/vllm_fl/worker/model_runner.py b/vllm_fl/worker/model_runner.py index febedc1d..7df9dd8b 100644 --- a/vllm_fl/worker/model_runner.py +++ b/vllm_fl/worker/model_runner.py @@ -217,6 +217,45 @@ def graph_capture(device: torch.device): logger = init_logger(__name__) +def _get_kv_cache_shape_compat( + attn_backend: type, + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", +) -> tuple[int, ...]: + """ + Compatibility wrapper for get_kv_cache_shape that handles different backend signatures. + + vLLM base class signature includes cache_dtype_str parameter, but some backends + (e.g., Ascend) don't support this parameter. This function tries the full signature + first, then falls back to the basic signature if TypeError is raised. 
+ """ + import inspect + + # Check if the backend's get_kv_cache_shape accepts cache_dtype_str + sig = inspect.signature(attn_backend.get_kv_cache_shape) + params = sig.parameters + + if 'cache_dtype_str' in params: + # Backend supports cache_dtype_str parameter + return attn_backend.get_kv_cache_shape( + num_blocks, + block_size, + num_kv_heads, + head_size, + cache_dtype_str=cache_dtype_str, + ) + else: + # Backend doesn't support cache_dtype_str, use basic signature + return attn_backend.get_kv_cache_shape( + num_blocks, + block_size, + num_kv_heads, + head_size, + ) + AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata] # list when ubatching is enabled PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict @@ -615,9 +654,11 @@ def __init__( # Cached outputs. self._draft_token_ids: list[list[int]] | torch.Tensor | None = None self.transfer_event = torch.Event() + # TODO(yxa): NPU uses int32, CUDA uses int64 for sampled token ids + sampled_ids_dtype = torch.int32 if current_platform.device_type == "npu" else torch.int64 self.sampled_token_ids_pinned_cpu = torch.empty( (self.max_num_reqs, 1), - dtype=torch.int64, + dtype=sampled_ids_dtype, device="cpu", pin_memory=self.pin_memory, ) @@ -3672,16 +3713,28 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.model.set_aux_hidden_state_layers(aux_layers) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory - except current_platform.torch_device_fn.OutOfMemoryError as e: - msg = ( - "Failed to load model - not enough GPU memory. " - "Try lowering --gpu-memory-utilization to free memory for weights, " - "increasing --tensor-parallel-size, or using --quantization. " - "See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ " - "for more tips." 
- ) - combined_msg = f"{msg} (original error: {e})" - logger.error(combined_msg) + # TODO(yxa): AttributeError: module 'torch_npu.npu' has no attribute 'OutOfMemoryError' + except Exception as e: + # Check if this is an OutOfMemoryError + # For CUDA: torch.cuda.OutOfMemoryError + # For NPU: torch_npu doesn't have OutOfMemoryError, but OOM raises RuntimeError + is_oom = False + if hasattr(current_platform.torch_device_fn, 'OutOfMemoryError'): + is_oom = isinstance(e, current_platform.torch_device_fn.OutOfMemoryError) + else: + # For NPU or other backends without OutOfMemoryError, check error message + is_oom = isinstance(e, RuntimeError) and 'out of memory' in str(e).lower() + + if is_oom: + msg = ( + "Failed to load model - not enough device memory. " + "Try lowering --gpu-memory-utilization to free memory for weights, " + "increasing --tensor-parallel-size, or using --quantization. " + "See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ " + "for more tips." + ) + combined_msg = f"{msg} (original error: {e})" + logger.error(combined_msg) raise e logger.info_once( "Model loading took %.4f GiB memory and %.6f seconds", @@ -4791,6 +4844,37 @@ def initialize_metadata_builders( # because some of them change the threshold at init time. self.calculate_reorder_batch_threshold() + def _get_graph_support( + self, + builder_cls: type, + vllm_config: VllmConfig, + kv_cache_spec: Any, + ) -> AttentionCGSupport: + """ + Get graph support (CUDA Graph or ACL Graph) from builder class. + + Different backends use different approaches: + - NVIDIA/CUDA: get_cudagraph_support() method + - Ascend: aclgraph_support class variable + + This helper abstracts the difference for multi-backend support. 
+ """ + # Try CUDA Graph support (NVIDIA/CUDA backends) + if hasattr(builder_cls, 'get_cudagraph_support'): + return builder_cls.get_cudagraph_support(vllm_config, kv_cache_spec) + + # Try ACL Graph support (Ascend backend) + if hasattr(builder_cls, 'aclgraph_support'): + return builder_cls.aclgraph_support + + # Default: no graph support + logger.warning( + f"Builder class {builder_cls.__name__} has no graph support " + "(neither get_cudagraph_support method nor aclgraph_support attribute). " + "Defaulting to AttentionCGSupport.NEVER." + ) + return AttentionCGSupport.NEVER + def _check_and_update_cudagraph_mode( self, attention_backends: list[set[type[AttentionBackend]]], @@ -4811,8 +4895,8 @@ def _check_and_update_cudagraph_mode( for attn_backend in attn_backend_set: builder_cls = attn_backend.get_builder_cls() - cg_support = builder_cls.get_cudagraph_support( - self.vllm_config, kv_cache_group.kv_cache_spec + cg_support = self._get_graph_support( + builder_cls, self.vllm_config, kv_cache_group.kv_cache_spec ) if cg_support.value < min_cg_support.value: min_cg_support = cg_support @@ -5204,12 +5288,13 @@ def _reshape_kv_cache_tensors( ) kernel_num_blocks = num_blocks * num_blocks_per_kv_block - kv_cache_shape = attn_backend.get_kv_cache_shape( + kv_cache_shape = _get_kv_cache_shape_compat( + attn_backend, kernel_num_blocks, kernel_block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, - cache_dtype_str=self.cache_config.cache_dtype, + self.cache_config.cache_dtype, ) dtype = kv_cache_spec.dtype try: diff --git a/vllm_fl/worker/worker.py b/vllm_fl/worker/worker.py index d3742c17..740c8388 100644 --- a/vllm_fl/worker/worker.py +++ b/vllm_fl/worker/worker.py @@ -7,9 +7,10 @@ import copy import gc import os -from contextlib import AbstractContextManager, nullcontext +from contextlib import AbstractContextManager, nullcontext, contextmanager from types import NoneType -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, 
Any, Optional, cast, Generator, Union +from dataclasses import dataclass import numpy as np import torch @@ -43,7 +44,7 @@ from vllm.sequence import IntermediateTensors from vllm.tasks import SupportedTask -from vllm.utils.mem_utils import GiB_bytes, MemorySnapshot, memory_profiling +from vllm.utils.mem_utils import GiB_bytes#, MemorySnapshot, memory_profiling from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec @@ -61,6 +62,113 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm_fl.worker.model_runner import ModelRunnerFL +@dataclass +class MemorySnapshot: + """Platform-agnostic memory snapshot for FL worker.""" + torch_peak: int = 0 + free_memory: int = 0 + total_memory: int = 0 + cuda_memory: int = 0 + torch_memory: int = 0 + non_torch_memory: int = 0 + timestamp: float = 0.0 + auto_measure: bool = True + + def __post_init__(self): + if self.auto_measure: + self.measure() + + def measure(self): + import time + torch_device_fn = current_platform.torch_device_fn + + # Get peak memory stats using platform-agnostic API + try: + self.torch_peak = torch_device_fn.memory_stats().get( + "allocated_bytes.all.peak", 0) + except (AttributeError, RuntimeError): + self.torch_peak = 0 + + # Get free and total memory using platform-agnostic API + self.free_memory, self.total_memory = torch_device_fn.mem_get_info() + self.cuda_memory = self.total_memory - self.free_memory + + # Get torch reserved memory + try: + self.torch_memory = torch_device_fn.memory_reserved() + except (AttributeError, RuntimeError): + self.torch_memory = 0 + + self.non_torch_memory = self.cuda_memory - self.torch_memory + self.timestamp = time.time() + + def __sub__(self, other: 'MemorySnapshot') -> 'MemorySnapshot': + result = MemorySnapshot(auto_measure=False) + result.torch_peak = self.torch_peak - 
other.torch_peak + result.free_memory = self.free_memory - other.free_memory + result.total_memory = self.total_memory + result.cuda_memory = self.cuda_memory - other.cuda_memory + result.torch_memory = self.torch_memory - other.torch_memory + result.non_torch_memory = self.non_torch_memory - other.non_torch_memory + result.timestamp = self.timestamp - other.timestamp + return result + + +@dataclass +class MemoryProfilingResult: + """Platform-agnostic memory profiling result.""" + before_create: MemorySnapshot = None + before_profile: MemorySnapshot = None + after_profile: MemorySnapshot = None + weights_memory: int = 0 + torch_peak_increase: int = 0 + non_torch_increase: int = 0 + non_kv_cache_memory: int = 0 + profile_time: float = 0.0 + + def __post_init__(self): + if self.before_profile is None: + self.before_profile = MemorySnapshot(auto_measure=False) + if self.after_profile is None: + self.after_profile = MemorySnapshot(auto_measure=False) + + +@contextmanager +def memory_profiling_fl( + baseline_snapshot: MemorySnapshot, + weights_memory: int) -> Generator[MemoryProfilingResult, None, None]: + """Platform-agnostic memory profiling context manager for FL worker.""" + gc.collect() + torch_device_fn = current_platform.torch_device_fn + torch_device_fn.empty_cache() + + # Reset peak memory stats - platform agnostic + try: + torch_device_fn.reset_peak_memory_stats() + except (AttributeError, RuntimeError): + pass # Some platforms may not support this + + result = MemoryProfilingResult() + result.before_create = baseline_snapshot + result.weights_memory = weights_memory + result.before_profile.measure() + + yield result + + gc.collect() + torch_device_fn.empty_cache() + + result.after_profile.measure() + + diff_profile = result.after_profile - result.before_profile + diff_from_create = result.after_profile - result.before_create + result.torch_peak_increase = diff_profile.torch_peak + result.non_torch_increase = diff_from_create.non_torch_memory + 
result.profile_time = diff_profile.timestamp + + non_torch_memory = result.non_torch_increase + peak_activation_memory = result.torch_peak_increase + result.non_kv_cache_memory = non_torch_memory + peak_activation_memory + result.weights_memory class WorkerFL(WorkerBase): @@ -190,8 +298,10 @@ def init_device(self): # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK self.local_rank += dp_local_rank * tp_pp_world_size - assert self.local_rank < torch.cuda.device_count(), ( + device_count = current_platform.torch_device_fn.device_count() if current_platform.torch_device_fn.is_available() else 0 + assert self.local_rank < device_count, ( f"DP adjusted local rank {self.local_rank} is out of bounds. " + f"Device count: {device_count}" ) visible_device_count = ( current_platform.torch_device_fn.device_count() if current_platform.torch_device_fn.is_available() else 0 @@ -306,7 +416,7 @@ def determine_available_memory(self) -> int: # Execute a forward pass with dummy inputs to profile the memory usage # of the model. - with memory_profiling( + with memory_profiling_fl( self.init_snapshot, weights_memory=int(self.model_runner.model_memory_usage), ) as profile_result: From 575621d01405c639efca523836c495ca47659813 Mon Sep 17 00:00:00 2001 From: yxa Date: Mon, 19 Jan 2026 10:31:23 +0000 Subject: [PATCH 09/34] Delete unnecessary files. 
--- .../E=512,N=128,device_name=cuda.json | 147 ------------------ 1 file changed, 147 deletions(-) delete mode 100644 examples/fuse_moe_tune/E=512,N=128,device_name=cuda.json diff --git a/examples/fuse_moe_tune/E=512,N=128,device_name=cuda.json b/examples/fuse_moe_tune/E=512,N=128,device_name=cuda.json deleted file mode 100644 index dc478aaa..00000000 --- a/examples/fuse_moe_tune/E=512,N=128,device_name=cuda.json +++ /dev/null @@ -1,147 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "2": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "24": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "96": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 2 - }, - "128": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "256": { - 
"BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "1024": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "1536": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 2 - }, - "2048": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 2 - }, - "3072": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 2 - }, - "4096": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - } -} - From 9e8553687bd33f599337d747ca1fdbcd6f059690 Mon Sep 17 00:00:00 2001 From: yxa Date: Mon, 19 Jan 2026 10:38:29 +0000 Subject: [PATCH 10/34] Modify copyright information --- vllm_fl/attention/backends/__init__.py | 2 +- vllm_fl/attention/backends/ascend/__init__.py | 2 +- vllm_fl/attention/backends/ascend/attention.py | 2 +- vllm_fl/attention/backends/ascend/attention_mask.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm_fl/attention/backends/__init__.py b/vllm_fl/attention/backends/__init__.py index d5b5aec6..4bd98399 100644 --- a/vllm_fl/attention/backends/__init__.py +++ b/vllm_fl/attention/backends/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 BAAI. All rights reserved. +# Copyright (c) 2026 BAAI. All rights reserved. """ Attention backends for vllm-plugin-FL. 
diff --git a/vllm_fl/attention/backends/ascend/__init__.py b/vllm_fl/attention/backends/ascend/__init__.py index 41931826..d94ea445 100644 --- a/vllm_fl/attention/backends/ascend/__init__.py +++ b/vllm_fl/attention/backends/ascend/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 BAAI. All rights reserved. +# Copyright (c) 2026 BAAI. All rights reserved. """ Ascend NPU attention backend for vllm-plugin-FL. diff --git a/vllm_fl/attention/backends/ascend/attention.py b/vllm_fl/attention/backends/ascend/attention.py index e74701f9..21b5baf0 100644 --- a/vllm_fl/attention/backends/ascend/attention.py +++ b/vllm_fl/attention/backends/ascend/attention.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 BAAI. All rights reserved. +# Copyright (c) 2026 BAAI. All rights reserved. """ Ascend NPU native attention backend for vllm-plugin-FL. diff --git a/vllm_fl/attention/backends/ascend/attention_mask.py b/vllm_fl/attention/backends/ascend/attention_mask.py index ed5884ed..0fc3d043 100644 --- a/vllm_fl/attention/backends/ascend/attention_mask.py +++ b/vllm_fl/attention/backends/ascend/attention_mask.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 BAAI. All rights reserved. +# Copyright (c) 2026 BAAI. All rights reserved. """ Attention mask builder for Ascend NPU backend. 
From 676cdd2153c7a733fcd9890863ab871df66caba3 Mon Sep 17 00:00:00 2001 From: yxa Date: Mon, 19 Jan 2026 11:25:16 +0000 Subject: [PATCH 11/34] Modify the directory name --- .../E=512,N=128,device_name=cuda.json | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/{fuse_moe_fuse => fuse_moe_tune}/E=512,N=128,device_name=cuda.json (100%) diff --git a/examples/fuse_moe_fuse/E=512,N=128,device_name=cuda.json b/examples/fuse_moe_tune/E=512,N=128,device_name=cuda.json similarity index 100% rename from examples/fuse_moe_fuse/E=512,N=128,device_name=cuda.json rename to examples/fuse_moe_tune/E=512,N=128,device_name=cuda.json From ab5f772e6ab4730b82cee59be2a9cef6870035d8 Mon Sep 17 00:00:00 2001 From: yxa Date: Tue, 20 Jan 2026 07:10:01 +0000 Subject: [PATCH 12/34] Make modifications and adjustments based on the PR (pull request) feedback. --- .../attention/backends/ascend/attention.py | 10 +++ .../backends/vendor/ascend/impl/activation.py | 20 +---- .../vendor/ascend/impl/normalization.py | 20 +---- .../backends/vendor/ascend/impl/rotary.py | 63 +++++-------- .../backends/vendor/cuda/impl/activation.py | 14 +-- .../vendor/cuda/impl/normalization.py | 39 +++----- .../backends/vendor/cuda/impl/rotary.py | 69 +++----------- vllm_fl/platform.py | 25 +----- vllm_fl/worker/model_runner.py | 90 ++----------------- 9 files changed, 71 insertions(+), 279 deletions(-) diff --git a/vllm_fl/attention/backends/ascend/attention.py b/vllm_fl/attention/backends/ascend/attention.py index 21b5baf0..6ae158c3 100644 --- a/vllm_fl/attention/backends/ascend/attention.py +++ b/vllm_fl/attention/backends/ascend/attention.py @@ -1,4 +1,8 @@ # Copyright (c) 2026 BAAI. All rights reserved. +# Adapted from https://github.com/vllm-project/vllm-ascend/blob/v0.13.0rc1/vllm_ascend/attention/attention_v1.py +# Below is the original copyright: +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright (c) 2025 Huawei Technologies Co., Ltd. 
""" Ascend NPU native attention backend for vllm-plugin-FL. @@ -115,6 +119,11 @@ class AscendAttentionMetadataBuilder: aclgraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS reorder_batch_threshold: ClassVar[int] = 1 + @staticmethod + def get_cudagraph_support(vllm_config, kv_cache_spec) -> AttentionCGSupport: + """Get CUDAGraph support level for Ascend backend.""" + return AttentionCGSupport.ALWAYS + # Class-level mask builder cache _mask_builder: ClassVar[Optional[AttentionMaskBuilder]] = None _mask_builder_device: ClassVar[Optional[torch.device]] = None @@ -343,6 +352,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> Tuple[int, ...]: return (2, num_blocks, block_size, num_kv_heads, head_size) diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py index 0975b093..320b04c1 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py @@ -19,20 +19,6 @@ def silu_and_mul_ascend(x: torch.Tensor) -> torch.Tensor: Returns: Output tensor of shape [..., d] """ - # Split input into two halves - d = x.shape[-1] // 2 - x1 = x[..., :d] - x2 = x[..., d:] - - # Apply SiLU (x * sigmoid(x)) to first half and multiply with second half - # Use Ascend-optimized operations if available - try: - import torch_npu - - # Use NPU-optimized SiLU if available - silu_out = torch.nn.functional.silu(x1) - return silu_out * x2 - except (ImportError, AttributeError): - # Fallback to standard PyTorch - silu_out = torch.nn.functional.silu(x1) - return silu_out * x2 + import torch_npu + + return torch_npu.npu_swiglu(x) diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py index 756faff4..022c4b47 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py +++ 
b/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py @@ -29,23 +29,11 @@ def rmsnorm_ascend( Returns: Normalized tensor, or tuple of (normalized, residual) if residual is provided """ - orig_dtype = x.dtype - x = x.float() + import torch_npu if residual is not None: - residual = residual.float() - x = x + residual + x, _, residual = torch_npu.npu_add_rms_norm(x, residual, weight, epsilon) + return x, residual - # Compute RMS normalization - # RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight - variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + epsilon) - - # Apply weight - x = x * weight - - x = x.to(orig_dtype) - - if residual is not None: - return x, residual.to(orig_dtype) + x, _ = torch_npu.npu_rms_norm(x, weight, epsilon) return x diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py index 9790ea7a..443dd562 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py @@ -33,17 +33,8 @@ def rotary_embedding_ascend( Returns: Tuple of (embedded_query, embedded_key) """ - try: - import torch_npu + import torch_npu - # Use Ascend-optimized rotary embedding if available - # For now, use standard PyTorch implementation - # TODO: Replace with torch_npu.npu_rotary_mul when available - pass - except ImportError: - pass - - # Standard implementation (can be optimized with Ascend kernels) # Get cos/sin for the positions if position_ids.dim() == 1: cos_selected = cos[position_ids] @@ -52,40 +43,34 @@ def rotary_embedding_ascend( cos_selected = cos[position_ids] sin_selected = sin[position_ids] - # Expand dimensions to match query/key shape - if query.dim() == 4: - cos_selected = cos_selected.unsqueeze(1) - sin_selected = sin_selected.unsqueeze(1) - elif query.dim() == 3: - cos_selected = cos_selected.unsqueeze(1) - sin_selected = sin_selected.unsqueeze(1) - - # Check if we need to repeat cos/sin to 
match head_dim - rotary_dim = cos_selected.shape[-1] + # Prepare cos/sin shape for npu_rotary_mul: [1, seq_len, 1, head_dim] head_dim = query.shape[-1] + rotary_dim = cos_selected.shape[-1] + # Duplicate cos/sin if needed to match head_dim if rotary_dim != head_dim: cos_selected = torch.cat([cos_selected, cos_selected], dim=-1) sin_selected = torch.cat([sin_selected, sin_selected], dim=-1) - def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - if rotary_interleaved: - # Interleaved rotary - def rotate_interleaved(x): - x1 = x[..., ::2] - x2 = x[..., 1::2] - return torch.stack((-x2, x1), dim=-1).flatten(-2) - - q_embed = (query * cos_selected) + (rotate_interleaved(query) * sin_selected) - k_embed = (key * cos_selected) + (rotate_interleaved(key) * sin_selected) - else: - # Standard rotary (neox style) - q_embed = (query * cos_selected) + (rotate_half(query) * sin_selected) - k_embed = (key * cos_selected) + (rotate_half(key) * sin_selected) + # Reshape cos/sin to [1, seq_len, 1, head_dim] + cos_selected = cos_selected.reshape(1, -1, 1, head_dim) + sin_selected = sin_selected.reshape(1, -1, 1, head_dim) + + # Reshape query/key to [1, seq_len, num_heads, head_dim] + query_shape = query.shape + key_shape = key.shape + + if query.dim() == 3: + query = query.unsqueeze(0) + if key.dim() == 3: + key = key.unsqueeze(0) + + # Apply rotary embedding using NPU kernel + q_embed = torch_npu.npu_rotary_mul(query, cos_selected, sin_selected) + k_embed = torch_npu.npu_rotary_mul(key, cos_selected, sin_selected) + + # Restore original shape + q_embed = q_embed.view(query_shape) + k_embed = k_embed.view(key_shape) return q_embed, k_embed diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py index c3d9d50c..dc89b4c3 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py 
+++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py @@ -21,17 +21,9 @@ def silu_and_mul_cuda(x: torch.Tensor) -> torch.Tensor: Returns: Output tensor of shape [..., d] """ + from vllm._custom_ops import silu_and_mul as vllm_silu_and_mul + d = x.shape[-1] // 2 out = torch.empty(*x.shape[:-1], d, dtype=x.dtype, device=x.device) - - try: - from vllm._custom_ops import silu_and_mul as vllm_silu_and_mul - vllm_silu_and_mul(out, x) - except ImportError: - # Fallback to standard PyTorch - x1 = x[..., :d] - x2 = x[..., d:] - silu_out = torch.nn.functional.silu(x1) - out = silu_out * x2 - + vllm_silu_and_mul(out, x) return out diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py index cb146476..33baa14c 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py @@ -31,32 +31,13 @@ def rmsnorm_cuda( Returns: Normalized tensor, or tuple of (normalized, residual) if residual is provided """ - try: - from vllm._custom_ops import rms_norm as vllm_rms_norm - from vllm._custom_ops import fused_add_rms_norm as vllm_fused_add_rms_norm - - if residual is not None: - vllm_fused_add_rms_norm(x, residual, weight, epsilon) - return x, residual - else: - out = torch.empty_like(x) - vllm_rms_norm(out, x, weight, epsilon) - return out - - except ImportError: - # Fallback to standard PyTorch - orig_dtype = x.dtype - x = x.float() - - if residual is not None: - residual = residual.float() - x = x + residual - - variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + epsilon) - x = x * weight - x = x.to(orig_dtype) - - if residual is not None: - return x, residual.to(orig_dtype) - return x + from vllm._custom_ops import rms_norm as vllm_rms_norm + from vllm._custom_ops import fused_add_rms_norm as vllm_fused_add_rms_norm + + if residual is not None: + vllm_fused_add_rms_norm(x, residual, weight, 
epsilon) + return x, residual + else: + out = torch.empty_like(x) + vllm_rms_norm(out, x, weight, epsilon) + return out diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py index 1bb4bdd7..97711b62 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py @@ -35,61 +35,14 @@ def rotary_embedding_cuda( Returns: Tuple of (embedded_query, embedded_key) """ - try: - from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding - - # vLLM's rotary_embedding modifies tensors in-place - vllm_rotary_embedding( - position_ids, - query, - key, - cos, - sin, - rotary_interleaved, - ) - return query, key - - except ImportError: - # Fallback to standard PyTorch implementation - if position_ids.dim() == 1: - cos_selected = cos[position_ids] - sin_selected = sin[position_ids] - else: - cos_selected = cos[position_ids] - sin_selected = sin[position_ids] - - # Expand dimensions to match query/key shape - if query.dim() == 4: - cos_selected = cos_selected.unsqueeze(1) - sin_selected = sin_selected.unsqueeze(1) - elif query.dim() == 3: - cos_selected = cos_selected.unsqueeze(1) - sin_selected = sin_selected.unsqueeze(1) - - # Check if we need to repeat cos/sin to match head_dim - rotary_dim = cos_selected.shape[-1] - head_dim = query.shape[-1] - - if rotary_dim != head_dim: - cos_selected = torch.cat([cos_selected, cos_selected], dim=-1) - sin_selected = torch.cat([sin_selected, sin_selected], dim=-1) - - def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - if rotary_interleaved: - def rotate_interleaved(x): - x1 = x[..., ::2] - x2 = x[..., 1::2] - return torch.stack((-x2, x1), dim=-1).flatten(-2) - - q_embed = (query * cos_selected) + (rotate_interleaved(query) * sin_selected) - k_embed = (key * cos_selected) + 
(rotate_interleaved(key) * sin_selected) - else: - q_embed = (query * cos_selected) + (rotate_half(query) * sin_selected) - k_embed = (key * cos_selected) + (rotate_half(key) * sin_selected) - - return q_embed, k_embed + from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding + + vllm_rotary_embedding( + position_ids, + query, + key, + cos, + sin, + rotary_interleaved, + ) + return query, key diff --git a/vllm_fl/platform.py b/vllm_fl/platform.py index a9d0d29a..29e11a9a 100644 --- a/vllm_fl/platform.py +++ b/vllm_fl/platform.py @@ -170,35 +170,12 @@ def get_attn_backend_cls( selected_backend: "AttentionBackendEnum", attn_selector_config: "AttentionSelectorConfig", ) -> list[str]: - """ - Get the attention backend class path using the dispatch mechanism. - - The dispatch mechanism automatically selects the appropriate backend - based on: - 1. Hardware availability (NPU, CUDA, etc.) - 2. User configuration via environment variables: - - VLLM_FL_PREFER: "vendor" | "reference" | "flaggems" - - VLLM_FL_ALLOW_VENDORS: Comma-separated list of allowed vendors - - VLLM_FL_DENY_VENDORS: Comma-separated list of denied vendors - - Backend selection order (default): - - flaggems: FlagGems attention (may have issues) - - vendor: Platform-specific implementation (e.g., Ascend NPU) - - reference: vLLM native attention (PyTorch-based) - - For NPU users, recommend setting VLLM_FL_PREFER=vendor to use - Ascend-optimized attention backend. - """ + """Get the attention backend class path using the dispatch mechanism.""" from vllm_fl.dispatch import call_op use_mla = attn_selector_config.use_mla try: - # Use dispatch mechanism to select the appropriate attention backend - # The dispatch system will automatically choose based on: - # 1. Backend availability (is_available check) - # 2. Policy configuration (VLLM_FL_PREFER, etc.) - # 3. 
Priority ordering (vendor > reference when vendor is preferred) backend_path = call_op("attention_backend", use_mla=use_mla) logger.info_once( diff --git a/vllm_fl/worker/model_runner.py b/vllm_fl/worker/model_runner.py index 7df9dd8b..ae2ee0ef 100644 --- a/vllm_fl/worker/model_runner.py +++ b/vllm_fl/worker/model_runner.py @@ -217,45 +217,6 @@ def graph_capture(device: torch.device): logger = init_logger(__name__) -def _get_kv_cache_shape_compat( - attn_backend: type, - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", -) -> tuple[int, ...]: - """ - Compatibility wrapper for get_kv_cache_shape that handles different backend signatures. - - vLLM base class signature includes cache_dtype_str parameter, but some backends - (e.g., Ascend) don't support this parameter. This function tries the full signature - first, then falls back to the basic signature if TypeError is raised. - """ - import inspect - - # Check if the backend's get_kv_cache_shape accepts cache_dtype_str - sig = inspect.signature(attn_backend.get_kv_cache_shape) - params = sig.parameters - - if 'cache_dtype_str' in params: - # Backend supports cache_dtype_str parameter - return attn_backend.get_kv_cache_shape( - num_blocks, - block_size, - num_kv_heads, - head_size, - cache_dtype_str=cache_dtype_str, - ) - else: - # Backend doesn't support cache_dtype_str, use basic signature - return attn_backend.get_kv_cache_shape( - num_blocks, - block_size, - num_kv_heads, - head_size, - ) - AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata] # list when ubatching is enabled PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict @@ -3713,17 +3674,8 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.model.set_aux_hidden_state_layers(aux_layers) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory - # TODO(yxa): AttributeError: module 'torch_npu.npu' has no attribute 'OutOfMemoryError' 
except Exception as e: - # Check if this is an OutOfMemoryError - # For CUDA: torch.cuda.OutOfMemoryError - # For NPU: torch_npu doesn't have OutOfMemoryError, but OOM raises RuntimeError - is_oom = False - if hasattr(current_platform.torch_device_fn, 'OutOfMemoryError'): - is_oom = isinstance(e, current_platform.torch_device_fn.OutOfMemoryError) - else: - # For NPU or other backends without OutOfMemoryError, check error message - is_oom = isinstance(e, RuntimeError) and 'out of memory' in str(e).lower() + is_oom = 'out of memory' in str(e).lower() if is_oom: msg = ( @@ -4844,37 +4796,6 @@ def initialize_metadata_builders( # because some of them change the threshold at init time. self.calculate_reorder_batch_threshold() - def _get_graph_support( - self, - builder_cls: type, - vllm_config: VllmConfig, - kv_cache_spec: Any, - ) -> AttentionCGSupport: - """ - Get graph support (CUDA Graph or ACL Graph) from builder class. - - Different backends use different approaches: - - NVIDIA/CUDA: get_cudagraph_support() method - - Ascend: aclgraph_support class variable - - This helper abstracts the difference for multi-backend support. - """ - # Try CUDA Graph support (NVIDIA/CUDA backends) - if hasattr(builder_cls, 'get_cudagraph_support'): - return builder_cls.get_cudagraph_support(vllm_config, kv_cache_spec) - - # Try ACL Graph support (Ascend backend) - if hasattr(builder_cls, 'aclgraph_support'): - return builder_cls.aclgraph_support - - # Default: no graph support - logger.warning( - f"Builder class {builder_cls.__name__} has no graph support " - "(neither get_cudagraph_support method nor aclgraph_support attribute). " - "Defaulting to AttentionCGSupport.NEVER." 
- ) - return AttentionCGSupport.NEVER - def _check_and_update_cudagraph_mode( self, attention_backends: list[set[type[AttentionBackend]]], @@ -4895,8 +4816,8 @@ def _check_and_update_cudagraph_mode( for attn_backend in attn_backend_set: builder_cls = attn_backend.get_builder_cls() - cg_support = self._get_graph_support( - builder_cls, self.vllm_config, kv_cache_group.kv_cache_spec + cg_support = builder_cls.get_cudagraph_support( + self.vllm_config, kv_cache_group.kv_cache_spec ) if cg_support.value < min_cg_support.value: min_cg_support = cg_support @@ -5288,13 +5209,12 @@ def _reshape_kv_cache_tensors( ) kernel_num_blocks = num_blocks * num_blocks_per_kv_block - kv_cache_shape = _get_kv_cache_shape_compat( - attn_backend, + kv_cache_shape = attn_backend.get_kv_cache_shape( kernel_num_blocks, kernel_block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, - self.cache_config.cache_dtype, + cache_dtype_str=self.cache_config.cache_dtype, ) dtype = kv_cache_spec.dtype try: From b5cc8d75e1754b67cddd95c7f14921670a70ac09 Mon Sep 17 00:00:00 2001 From: yxa Date: Tue, 20 Jan 2026 07:48:06 +0000 Subject: [PATCH 13/34] Adjust the code. --- vllm_fl/attention/backends/ascend/attention.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm_fl/attention/backends/ascend/attention.py b/vllm_fl/attention/backends/ascend/attention.py index 6ae158c3..f91490e1 100644 --- a/vllm_fl/attention/backends/ascend/attention.py +++ b/vllm_fl/attention/backends/ascend/attention.py @@ -60,9 +60,11 @@ torch.cuda.Event = torch.npu.Event torch.cuda.Stream = torch.npu.Stream logger.info("NPU compatibility enabled: torch.Event -> torch.npu.Event") -except ImportError: - torch_npu = None - logger.warning("torch_npu not available, Ascend attention backend will not work") +except ImportError as e: + raise ImportError( + "torch_npu is required for Ascend attention backend. " + "Please install torch_npu for NPU support." 
+ ) from e def is_torch_npu_available() -> bool: From 635c7683834b9a22835e08bddef63b8fd47c66df Mon Sep 17 00:00:00 2001 From: yxa Date: Tue, 20 Jan 2026 08:39:12 +0000 Subject: [PATCH 14/34] Place the attention mechanism in the `dispatch` directory. --- vllm_fl/attention/backends/__init__.py | 41 ------------------- vllm_fl/attention/backends/ascend/__init__.py | 40 ------------------ .../dispatch/backends/flaggems/flaggems.py | 4 +- .../backends/flaggems/impl/__init__.py | 18 ++++++++ .../backends/flaggems/impl}/attention.py | 0 .../flaggems/impl}/custom_attention.py | 2 +- .../backends/flaggems/impl}/mla.py | 0 .../dispatch/backends/vendor/ascend/ascend.py | 4 +- .../backends/vendor/ascend/impl/__init__.py | 22 ++++++++++ .../backends/vendor/ascend/impl}/attention.py | 2 +- .../vendor/ascend/impl}/attention_mask.py | 0 vllm_fl/dispatch/ops.py | 4 +- vllm_fl/platform.py | 7 ++-- 13 files changed, 51 insertions(+), 93 deletions(-) delete mode 100644 vllm_fl/attention/backends/__init__.py delete mode 100644 vllm_fl/attention/backends/ascend/__init__.py rename vllm_fl/{attention => dispatch/backends/flaggems/impl}/attention.py (100%) rename vllm_fl/{attention => dispatch/backends/flaggems/impl}/custom_attention.py (71%) rename vllm_fl/{attention => dispatch/backends/flaggems/impl}/mla.py (100%) rename vllm_fl/{attention/backends/ascend => dispatch/backends/vendor/ascend/impl}/attention.py (99%) rename vllm_fl/{attention/backends/ascend => dispatch/backends/vendor/ascend/impl}/attention_mask.py (100%) diff --git a/vllm_fl/attention/backends/__init__.py b/vllm_fl/attention/backends/__init__.py deleted file mode 100644 index 4bd98399..00000000 --- a/vllm_fl/attention/backends/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) 2026 BAAI. All rights reserved. - -""" -Attention backends for vllm-plugin-FL. - -This module provides attention backend implementations for different hardware platforms. 
-The dispatch mechanism automatically selects the appropriate backend based on the -available hardware and configuration. - -Available backends: -- ascend: Native Ascend NPU attention using torch_npu operators - - Uses torch_npu.npu_fused_infer_attention_score for prefill - - Uses torch_npu._npu_paged_attention for decode - - No dependency on vllm-ascend package -""" - -from vllm_fl.attention.backends.ascend import ( - AscendAttentionBackend, - AscendAttentionBackendImpl, - AscendAttentionMetadataBuilder, - AscendMetadata, - AscendAttentionState, - AscendMLABackend, - AttentionMaskBuilder, - get_attention_mask_builder, - is_torch_npu_available, -) - -__all__ = [ - # Ascend backend - "AscendAttentionBackend", - "AscendAttentionBackendImpl", - "AscendAttentionMetadataBuilder", - "AscendMetadata", - "AscendAttentionState", - "AscendMLABackend", - # Utilities - "AttentionMaskBuilder", - "get_attention_mask_builder", - "is_torch_npu_available", -] diff --git a/vllm_fl/attention/backends/ascend/__init__.py b/vllm_fl/attention/backends/ascend/__init__.py deleted file mode 100644 index d94ea445..00000000 --- a/vllm_fl/attention/backends/ascend/__init__.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) 2026 BAAI. All rights reserved. - -""" -Ascend NPU attention backend for vllm-plugin-FL. - -This package provides native Ascend NPU attention implementation using torch_npu -operators directly, without depending on vllm-ascend package. - -Modules: -- attention: Core attention backend classes (AscendAttentionBackend, etc.) 
-- attention_mask: Attention mask builder and utilities -""" - -from vllm_fl.attention.backends.ascend.attention import ( - AscendAttentionBackend, - AscendAttentionBackendImpl, - AscendAttentionMetadataBuilder, - AscendMetadata, - AscendAttentionState, - AscendMLABackend, - is_torch_npu_available, -) -from vllm_fl.attention.backends.ascend.attention_mask import ( - AttentionMaskBuilder, - get_attention_mask_builder, -) - -__all__ = [ - # Attention backend classes - "AscendAttentionBackend", - "AscendAttentionBackendImpl", - "AscendAttentionMetadataBuilder", - "AscendMetadata", - "AscendAttentionState", - "AscendMLABackend", - # Utilities - "AttentionMaskBuilder", - "get_attention_mask_builder", - "is_torch_npu_available", -] diff --git a/vllm_fl/dispatch/backends/flaggems/flaggems.py b/vllm_fl/dispatch/backends/flaggems/flaggems.py index 1b004d6c..e0ec1a04 100644 --- a/vllm_fl/dispatch/backends/flaggems/flaggems.py +++ b/vllm_fl/dispatch/backends/flaggems/flaggems.py @@ -127,5 +127,5 @@ def attention_backend(self, use_mla: bool = False) -> str: Fully qualified class path string """ if use_mla: - return "vllm_fl.attention.mla.MLAFLBackend" - return "vllm_fl.attention.attention.AttentionFLBackend" + return "vllm_fl.dispatch.backends.flaggems.impl.mla.MLAFLBackend" + return "vllm_fl.dispatch.backends.flaggems.impl.attention.AttentionFLBackend" diff --git a/vllm_fl/dispatch/backends/flaggems/impl/__init__.py b/vllm_fl/dispatch/backends/flaggems/impl/__init__.py index 6dcadfe4..87f1dd34 100644 --- a/vllm_fl/dispatch/backends/flaggems/impl/__init__.py +++ b/vllm_fl/dispatch/backends/flaggems/impl/__init__.py @@ -7,9 +7,27 @@ from .activation import silu_and_mul_flaggems from .normalization import rmsnorm_flaggems from .rotary import rotary_embedding_flaggems +from .attention import ( + AttentionFLBackend, + AttentionFLMetadata, + AttentionFLMetadataBuilder, + AttentionFLImpl, +) +from .mla import ( + MLAFLBackend, + MLAFLImpl, +) +from .custom_attention import 
register_attention __all__ = [ "silu_and_mul_flaggems", "rmsnorm_flaggems", "rotary_embedding_flaggems", + "AttentionFLBackend", + "AttentionFLMetadata", + "AttentionFLMetadataBuilder", + "AttentionFLImpl", + "MLAFLBackend", + "MLAFLImpl", + "register_attention", ] diff --git a/vllm_fl/attention/attention.py b/vllm_fl/dispatch/backends/flaggems/impl/attention.py similarity index 100% rename from vllm_fl/attention/attention.py rename to vllm_fl/dispatch/backends/flaggems/impl/attention.py diff --git a/vllm_fl/attention/custom_attention.py b/vllm_fl/dispatch/backends/flaggems/impl/custom_attention.py similarity index 71% rename from vllm_fl/attention/custom_attention.py rename to vllm_fl/dispatch/backends/flaggems/impl/custom_attention.py index fccd07de..5aa2860d 100644 --- a/vllm_fl/attention/custom_attention.py +++ b/vllm_fl/dispatch/backends/flaggems/impl/custom_attention.py @@ -6,6 +6,6 @@ def register_attention(): register_backend( backend=AttentionBackendEnum.TRITON_ATTN, - class_path="vllm_fl.attention.attention.AttentionFLBackend", + class_path="vllm_fl.dispatch.backends.flaggems.impl.attention.AttentionFLBackend", is_mamba=False, ) diff --git a/vllm_fl/attention/mla.py b/vllm_fl/dispatch/backends/flaggems/impl/mla.py similarity index 100% rename from vllm_fl/attention/mla.py rename to vllm_fl/dispatch/backends/flaggems/impl/mla.py diff --git a/vllm_fl/dispatch/backends/vendor/ascend/ascend.py b/vllm_fl/dispatch/backends/vendor/ascend/ascend.py index 3a5ea902..1b573af5 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/ascend.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/ascend.py @@ -143,5 +143,5 @@ def attention_backend(self, use_mla: bool = False) -> str: Fully qualified class path string """ if use_mla: - return "vllm_fl.attention.backends.ascend.AscendMLABackend" - return "vllm_fl.attention.backends.ascend.AscendAttentionBackend" + return "vllm_fl.dispatch.backends.vendor.ascend.impl.attention.AscendMLABackend" + return 
"vllm_fl.dispatch.backends.vendor.ascend.impl.attention.AscendAttentionBackend" diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/__init__.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/__init__.py index 1d474b0f..5da827fb 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/__init__.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/__init__.py @@ -7,9 +7,31 @@ from .activation import silu_and_mul_ascend from .normalization import rmsnorm_ascend from .rotary import rotary_embedding_ascend +from .attention import ( + AscendAttentionBackend, + AscendAttentionBackendImpl, + AscendAttentionMetadataBuilder, + AscendMetadata, + AscendAttentionState, + AscendMLABackend, + is_torch_npu_available, +) +from .attention_mask import ( + AttentionMaskBuilder, + get_attention_mask_builder, +) __all__ = [ "silu_and_mul_ascend", "rmsnorm_ascend", "rotary_embedding_ascend", + "AscendAttentionBackend", + "AscendAttentionBackendImpl", + "AscendAttentionMetadataBuilder", + "AscendMetadata", + "AscendAttentionState", + "AscendMLABackend", + "is_torch_npu_available", + "AttentionMaskBuilder", + "get_attention_mask_builder", ] diff --git a/vllm_fl/attention/backends/ascend/attention.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/attention.py similarity index 99% rename from vllm_fl/attention/backends/ascend/attention.py rename to vllm_fl/dispatch/backends/vendor/ascend/impl/attention.py index f91490e1..8f357a27 100644 --- a/vllm_fl/attention/backends/ascend/attention.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/attention.py @@ -40,7 +40,7 @@ from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.utils import AttentionCGSupport -from vllm_fl.attention.backends.ascend.attention_mask import ( +from vllm_fl.dispatch.backends.vendor.ascend.impl.attention_mask import ( AttentionMaskBuilder, get_attention_mask_builder, ) diff --git a/vllm_fl/attention/backends/ascend/attention_mask.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/attention_mask.py 
similarity index 100% rename from vllm_fl/attention/backends/ascend/attention_mask.py rename to vllm_fl/dispatch/backends/vendor/ascend/impl/attention_mask.py diff --git a/vllm_fl/dispatch/ops.py b/vllm_fl/dispatch/ops.py index aa5e89fd..08792951 100644 --- a/vllm_fl/dispatch/ops.py +++ b/vllm_fl/dispatch/ops.py @@ -141,7 +141,7 @@ def attention_backend(self, use_mla: bool = False) -> str: Returns: Fully qualified class path string, e.g.: - - "vllm_fl.attention.backends.ascend.AscendAttentionBackend" - - "vllm_fl.attention.attention.AttentionFLBackend" + - "vllm_fl.dispatch.backends.vendor.ascend.impl.attention.AscendAttentionBackend" + - "vllm_fl.dispatch.backends.flaggems.impl.attention.AttentionFLBackend" """ pass diff --git a/vllm_fl/platform.py b/vllm_fl/platform.py index 29e11a9a..c5bef2ff 100644 --- a/vllm_fl/platform.py +++ b/vllm_fl/platform.py @@ -41,10 +41,9 @@ def _get_backend( """Get backend priorities with lazy import to avoid circular dependency.""" if use_mla: raise NotImplementedError("NOT support mla now!") - # return "vllm_fl.attention.backends.mla.MLAFLBackend" else: if "USE_FLAGGEMS" in os.environ and os.environ["USE_FLAGGEMS"] == "1": - return [AttentionBackendEnum.TRITON_ATTN] #"vllm_fl.attention.attention.AttentionFLBackend" + return [AttentionBackendEnum.TRITON_ATTN] return [AttentionBackendEnum.FLASH_ATTN] @@ -194,9 +193,9 @@ def get_attn_backend_cls( if cls.device_type == "npu": if use_mla: - backend_path = "vllm_fl.attention.mla.MLAFLBackend" + backend_path = "vllm_fl.dispatch.backends.flaggems.impl.mla.MLAFLBackend" else: - backend_path = "vllm_fl.attention.attention.AttentionFLBackend" + backend_path = "vllm_fl.dispatch.backends.flaggems.impl.attention.AttentionFLBackend" else: # For CUDA and other devices, use vLLM native backend from vllm.attention.backends.registry import AttentionBackendEnum From 3a086d8eaeefb6daff139bace7c16efd5b2c3464 Mon Sep 17 00:00:00 2001 From: yxa Date: Fri, 23 Jan 2026 02:18:01 +0000 Subject: [PATCH 15/34] 
Fixed bugs, added functionality to read configuration files in dispatch. --- vllm_fl/dispatch/README.md | 129 ++++++++++++++++++++-- vllm_fl/dispatch/__init__.py | 46 ++++++++ vllm_fl/dispatch/manager.py | 68 +++++++++++- vllm_fl/dispatch/policy.py | 200 ++++++++++++++++++++++++++++++++++- 4 files changed, 425 insertions(+), 18 deletions(-) diff --git a/vllm_fl/dispatch/README.md b/vllm_fl/dispatch/README.md index c8ff34b7..add4d1e5 100644 --- a/vllm_fl/dispatch/README.md +++ b/vllm_fl/dispatch/README.md @@ -18,9 +18,29 @@ dispatch/ └── backends/ # Backend implementations ├── base.py # Backend abstract base class ├── flaggems/ # FlagGems backend (DEFAULT, priority 150) + │ ├── flaggems.py # Backend class + │ ├── register_ops.py # Registration function + │ └── impl/ # Operator implementations + │ ├── activation.py + │ ├── normalization.py + │ ├── rotary.py + │ ├── attention.py # AttentionFLBackend, AttentionFLImpl + │ ├── mla.py # MLAFLBackend, MLAFLImpl + │ └── custom_attention.py # Attention backend registration ├── reference/ # Reference backend (PyTorch, priority 50) └── vendor/ # Vendor-specific backends (priority 100) - └── ascend/ # Example: Huawei Ascend backend + ├── cuda/ # NVIDIA CUDA backend + │ └── impl/ + │ ├── activation.py + │ ├── normalization.py + │ └── rotary.py + └── ascend/ # Huawei Ascend NPU backend + └── impl/ + ├── activation.py + ├── normalization.py + ├── rotary.py + ├── attention.py # AscendAttentionBackend + └── attention_mask.py # Attention mask utilities ``` ## Core Concepts @@ -170,10 +190,82 @@ result = fn(query, key, cos, sin, position_ids) result = manager.call("silu_and_mul", x) ``` -## Environment Variables +## Configuration + +The dispatch system supports two ways to configure backend selection: +1. **Configuration file (YAML)** - Recommended for complex configurations +2. 
**Environment variables** - Simple, quick configuration + +**Priority**: Configuration file > Environment variables > Default values + +### Configuration File (YAML) + +Set the `VLLM_FL_CONFIG` environment variable to specify a YAML configuration file: + +```bash +export VLLM_FL_CONFIG=/path/to/vllm_fl_dispatch.yaml +``` + +#### Example Configuration File + +```yaml +# vllm_fl_dispatch.yaml + +# Preferred backend type: flaggems, vendor, or reference +prefer: vendor + +# Strict mode: +# true = fail immediately on error, no fallback +# false = try next backend on failure (default) +strict: false + +# Vendor whitelist (optional) +allow_vendors: + - cuda + +# Vendor blacklist (optional) +deny_vendors: + - ascend + +# Per-operator backend selection order (optional) +# Only the backends listed will be tried, in the specified order. +# If you only list 2 options, only those 2 will be attempted. +# +# Supported tokens: +# - flaggems : FlagGems default implementation +# - reference : PyTorch reference implementation +# - vendor : Any available vendor backend (auto-detect) +# - vendor:cuda : Only CUDA vendor backend +# - vendor:ascend : Only Ascend vendor backend +per_op: + rmsnorm: + - vendor # Try any available vendor first + - flaggems # Then try flaggems + # reference not listed, so it won't be used for rmsnorm + + silu_and_mul: + - vendor:cuda # Only try CUDA, not other vendors + - flaggems + - reference +``` + +#### Token Types Explained + +| Token | Description | +|-------|-------------| +| `flaggems` | FlagGems default implementation | +| `reference` | PyTorch reference implementation | +| `vendor` | Any available vendor backend (auto-detects hardware) | +| `vendor:cuda` | Only CUDA vendor backend | +| `vendor:ascend` | Only Ascend vendor backend | + +**Note**: When using `vendor` (without specifying a vendor name), the system automatically selects an available vendor backend based on hardware detection. 
+ +### Environment Variables | Variable | Description | Example | Behavior | |----------|-------------|---------|----------| +| `VLLM_FL_CONFIG` | Path to YAML config file | `/path/to/config.yaml` | Highest priority, overrides other env vars | | `VLLM_FL_PREFER` | Preferred backend (sets selection order) | `flaggems`, `vendor`, `reference` | Defines priority order, falls back if unavailable | | `VLLM_FL_STRICT` | Enable strict mode (auto-fallback on failure) | `1` or `0` | When `1`, tries alternatives if primary fails | | `VLLM_FL_DENY_VENDORS` | Denied vendors list (blacklist) | `vendor1,vendor2` | Excludes specified vendors from selection | @@ -204,15 +296,16 @@ export VLLM_FL_PLUGIN_MODULES=my_custom_backend export VLLM_FL_LOG_LEVEL=DEBUG ``` -### Environment Variable Priority +### Configuration Priority -The dispatch system applies environment variables in the following order: +The dispatch system applies configuration in the following order: -1. **`VLLM_FL_PER_OP`** - Highest priority, overrides default order for specific operators -2. **`VLLM_FL_ALLOW_VENDORS`** - Whitelist filter (if set, only these vendors are allowed) -3. **`VLLM_FL_DENY_VENDORS`** - Blacklist filter (these vendors are excluded) -4. **`VLLM_FL_PREFER`** - Default selection order for all operators -5. **`BackendPriority`** - Code-defined priority (used for tie-breaking within same kind) +1. **`VLLM_FL_CONFIG`** - Highest priority, YAML config file overrides all environment variables +2. **`VLLM_FL_PER_OP`** - Per-operator order overrides default order for specific operators +3. **`VLLM_FL_ALLOW_VENDORS`** - Whitelist filter (if set, only these vendors are allowed) +4. **`VLLM_FL_DENY_VENDORS`** - Blacklist filter (these vendors are excluded) +5. **`VLLM_FL_PREFER`** - Default selection order for all operators +6. 
**`BackendPriority`** - Code-defined priority (used for tie-breaking within same kind)
 
 **Priority values are spaced by 50 to allow future insertion of intermediate priorities:**
 - `BackendPriority.DEFAULT` = 150 (FlagGems)
@@ -271,6 +364,7 @@ Currently supported operators:
 | `silu_and_mul` | SiLU activation + element-wise multiplication | ✓ | ✓ | ✓ |
 | `rmsnorm` | RMS normalization | ✓ | ✓ | ✓ |
 | `rotary_embedding` | Rotary position embedding | ✓ | ✓ | ✓ |
+| `attention_backend` | Attention backend class path | ✓ | - | ✓ |
 
 ## Selection Process
 
@@ -299,11 +393,16 @@ When adding a new operator, modify these files:
 - `backends/flaggems/impl/*.py` - Add FlagGems implementation
 - `backends/flaggems/flaggems.py` - Add method to backend class
 - `backends/flaggems/register_ops.py` - Register OpImpl
-- `backends/reference/impl/*.py` - Add PyTorch implementation
+- `backends/reference/impl/*.py` - Add PyTorch implementation (if applicable)
 - `backends/reference/reference.py` - Add method to backend class
 - `backends/reference/register_ops.py` - Register OpImpl
+- `backends/vendor/<vendor>/impl/*.py` - Add vendor-specific implementation (optional)
+- `backends/vendor/<vendor>/<vendor>.py` - Add method to vendor backend class
+- `backends/vendor/<vendor>/register_ops.py` - Register vendor OpImpl
 - `ops.py` - Add abstract method declaration
 
+**Note:** Not all operators require a reference implementation. For example, `attention_backend` only has FlagGems and vendor implementations since it returns a backend class path rather than executing a computation.
+ ### Adding Vendor Backends The dispatch system supports three ways to integrate vendor backends: @@ -324,7 +423,8 @@ backends/vendor// ├── __init__.py ├── activation.py ├── normalization.py - └── rotary.py + ├── rotary.py + └── attention.py # (optional) Vendor-specific attention backend ``` **Step 1: Create Backend Class** (`.py`): @@ -456,9 +556,15 @@ export VLLM_FL_LOG_LEVEL=DEBUG - [ ] `impl_id` follows format: `vendor.` - [ ] Priority set to `BackendPriority.VENDOR` (100) - [ ] Error handling for missing dependencies +- [ ] (Optional) `attention_backend()` returns vendor-specific attention backend class path #### Current Vendor Backends +| Vendor | Device | Library | Attention Backend | +|--------|--------|---------|-------------------| +| `cuda` | NVIDIA GPU | `vllm._custom_ops` | - (uses vLLM native) | +| `ascend` | Huawei NPU | `torch_npu` | `AscendAttentionBackend` | + See `backends/vendor/template/` for a template to create new vendor backends. ## Multi-Process Safety @@ -481,6 +587,7 @@ OpManager supports multi-process environments: - `set_global_policy(policy)`: Set global policy - `reset_global_policy()`: Reset to environment variable defaults - `policy_context(policy)`: Temporary policy context +- `policy_from_config(config_path)`: Create policy from YAML config file ### Manager diff --git a/vllm_fl/dispatch/__init__.py b/vllm_fl/dispatch/__init__.py index 8658597c..9c48f079 100644 --- a/vllm_fl/dispatch/__init__.py +++ b/vllm_fl/dispatch/__init__.py @@ -19,6 +19,7 @@ result = fn(x, residual, weight, epsilon) Environment Variables: + VLLM_FL_CONFIG: Path to YAML configuration file (highest priority, overrides env vars) VLLM_FL_PREFER: Preferred backend ("flaggems", "vendor", "reference") VLLM_FL_STRICT: Enable strict mode ("1" or "0") VLLM_FL_DENY_VENDORS: Comma-separated list of denied vendors @@ -30,6 +31,49 @@ When enabled, prints: - Detailed list of registered operators and implementations at initialization - Selected backend for each operator 
call + +Configuration File (YAML): + When VLLM_FL_CONFIG is set, the dispatch system loads configuration from the + specified YAML file. Example: + + # vllm_fl_dispatch.yaml + + # Preferred backend type: flaggems, vendor, or reference + prefer: vendor + + # Strict mode: + # true = fail immediately on error, no fallback + # false = try next backend on failure (default) + strict: true + + # Vendor whitelist (optional) + allow_vendors: + - cuda + + # Vendor blacklist (optional) + deny_vendors: + - ascend + + # Per-operator backend selection order (optional) + # Only the backends listed will be tried, in the specified order. + # If you only list 2 options, only those 2 will be attempted. + # + # Supported tokens: + # - flaggems : FlagGems default implementation + # - reference : PyTorch reference implementation + # - vendor : Any available vendor backend (auto-detect) + # - vendor:cuda : Only CUDA vendor backend + # - vendor:ascend : Only Ascend vendor backend + per_op: + rmsnorm: + - vendor # Try any available vendor first + - flaggems # Then try flaggems + # reference not listed, so it won't be used + + silu_and_mul: + - vendor:cuda # Only try CUDA, not other vendors + - flaggems + - reference """ from .types import OpImpl, BackendImplKind, BackendPriority, match_token @@ -41,6 +85,7 @@ set_global_policy, reset_global_policy, policy_context, + policy_from_config, with_strict_mode, with_preference, with_allowed_vendors, @@ -104,6 +149,7 @@ def resolve_op(op_name: str): "set_global_policy", "reset_global_policy", "policy_context", + "policy_from_config", "with_strict_mode", "with_preference", "with_allowed_vendors", diff --git a/vllm_fl/dispatch/manager.py b/vllm_fl/dispatch/manager.py index 09575fe6..db75db4e 100644 --- a/vllm_fl/dispatch/manager.py +++ b/vllm_fl/dispatch/manager.py @@ -10,7 +10,7 @@ import os import threading from dataclasses import dataclass -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Dict, Optional, Set, Tuple 
from .registry import OpRegistry from .policy import SelectionPolicy, get_policy @@ -48,6 +48,7 @@ def __init__(self, registry: Optional[OpRegistry] = None) -> None: self._state = _OpManagerState() self._dispatch_cache: Dict[Tuple[str, str, int], Callable] = {} self._called_ops: Dict[str, str] = {} # Map op_name -> last_used_impl_id + self._failed_impls: Dict[str, Set[str]] = {} # Map op_name -> set of failed impl_ids # Register at_fork handler for multi-process safety try: @@ -69,6 +70,7 @@ def _reset_after_fork(self) -> None: self._state.policy_epoch += 1 self._dispatch_cache.clear() self._called_ops.clear() + self._failed_impls.clear() logger.debug("OpManager reset after fork") def bump_policy_epoch(self) -> None: @@ -80,8 +82,45 @@ def bump_policy_epoch(self) -> None: with self._lock: self._state.policy_epoch += 1 self._dispatch_cache.clear() + self._failed_impls.clear() logger.debug(f"Policy epoch bumped to {self._state.policy_epoch}") + def clear_failed_impls(self, op_name: Optional[str] = None) -> None: + """ + Clear the failed implementations cache. + + This allows previously failed implementations to be retried. + + Args: + op_name: If specified, only clear failed impls for this operator. + If None, clear all failed impls. + """ + with self._lock: + if op_name is None: + self._failed_impls.clear() + logger.debug("Cleared all failed implementations cache") + elif op_name in self._failed_impls: + del self._failed_impls[op_name] + logger.debug(f"Cleared failed implementations cache for op '{op_name}'") + + def get_failed_impls(self, op_name: Optional[str] = None) -> Dict[str, Set[str]]: + """ + Get the failed implementations cache. + + Args: + op_name: If specified, return failed impls only for this operator. + + Returns: + Dict mapping op_name to set of failed impl_ids. 
+ """ + with self._lock: + if op_name is None: + return {k: v.copy() for k, v in self._failed_impls.items()} + elif op_name in self._failed_impls: + return {op_name: self._failed_impls[op_name].copy()} + else: + return {} + def ensure_initialized(self) -> None: """ Ensure the manager is initialized in the current process. @@ -406,7 +445,22 @@ def call(self, op_name: str, *args, **kwargs): candidates = self.resolve_candidates(op_name) last_error = None - for idx, impl in enumerate(candidates): + # Get failed implementations for this op (skip them) + failed_impl_ids = self._failed_impls.get(op_name, set()) + + # Filter out failed implementations + available_candidates = [ + impl for impl in candidates if impl.impl_id not in failed_impl_ids + ] + + if not available_candidates: + # All implementations have failed before, raise error + raise RuntimeError( + f"All implementations for op='{op_name}' have failed previously. " + f"Failed impl_ids: {failed_impl_ids}" + ) + + for idx, impl in enumerate(available_candidates): try: # Log primary implementation or fallback attempts if idx == 0: @@ -444,7 +498,13 @@ def call(self, op_name: str, *args, **kwargs): except Exception as e: last_error = e - if idx < len(candidates) - 1: + # Mark this implementation as failed + with self._lock: + if op_name not in self._failed_impls: + self._failed_impls[op_name] = set() + self._failed_impls[op_name].add(impl.impl_id) + + if idx < len(available_candidates) - 1: # Not the last candidate, log warning and try next logger.warning( f"Implementation '{impl.impl_id}' failed for op '{op_name}': {e}" @@ -457,7 +517,7 @@ def call(self, op_name: str, *args, **kwargs): # All implementations failed raise RuntimeError( - f"All {len(candidates)} implementation(s) failed for op='{op_name}'. " + f"All {len(available_candidates)} implementation(s) failed for op='{op_name}'. 
" f"Last error: {last_error}" ) from last_error diff --git a/vllm_fl/dispatch/policy.py b/vllm_fl/dispatch/policy.py index 2f90bc9d..24f33166 100644 --- a/vllm_fl/dispatch/policy.py +++ b/vllm_fl/dispatch/policy.py @@ -10,7 +10,13 @@ import os import threading from dataclasses import dataclass, field -from typing import Dict, FrozenSet, List, Optional, Set, Tuple +from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple + +try: + import yaml + YAML_AVAILABLE = True +except ImportError: + YAML_AVAILABLE = False # Valid preference values for VLLM_FL_PREFER @@ -239,17 +245,155 @@ def _parse_per_op(value: str) -> Dict[str, List[str]]: return result + def _policy_from_config(self, config_path: str) -> Optional[SelectionPolicy]: + """ + Create a SelectionPolicy from a YAML configuration file. + + Args: + config_path: Path to the YAML configuration file. + + Returns: + SelectionPolicy if successfully loaded, None if file doesn't exist + or YAML is not available. + + Config file format (YAML): + # Preferred backend type: flaggems, vendor, or reference + prefer: vendor + + # Strict mode: + # true = fail immediately on error, no fallback + # false = try next backend on failure (default) + strict: true + + # Vendor whitelist (optional) + allow_vendors: + - cuda + + # Vendor blacklist (optional) + deny_vendors: + - ascend + + # Per-operator backend selection order (optional) + # Only the backends listed will be tried, in the specified order. + # If you only list 2 options, only those 2 will be attempted. 
+ # + # Supported tokens: + # - flaggems : FlagGems default implementation + # - reference : PyTorch reference implementation + # - vendor : Any available vendor backend (auto-detect) + # - vendor:cuda : Only CUDA vendor backend + # - vendor:ascend : Only Ascend vendor backend + per_op: + rmsnorm: + - vendor # Try any available vendor first + - flaggems # Then try flaggems + # reference not listed, so it won't be used + + silu_and_mul: + - vendor:cuda # Only try CUDA, not other vendors + - flaggems + - reference + """ + if not YAML_AVAILABLE: + import warnings + warnings.warn( + f"VLLM_FL_CONFIG is set to '{config_path}' but PyYAML is not installed. " + "Install it with: pip install pyyaml. Falling back to environment variables." + ) + return None + + if not os.path.isfile(config_path): + import warnings + warnings.warn( + f"Config file '{config_path}' not found. Falling back to environment variables." + ) + return None + + try: + with open(config_path, "r", encoding="utf-8") as f: + config: Dict[str, Any] = yaml.safe_load(f) or {} + except Exception as e: + import warnings + warnings.warn( + f"Failed to load config file '{config_path}': {e}. " + "Falling back to environment variables." 
+ ) + return None + + # Parse prefer + prefer_str = str(config.get("prefer", PREFER_DEFAULT)).strip().lower() + if prefer_str not in VALID_PREFER_VALUES: + prefer_str = PREFER_DEFAULT + + # Parse strict + strict_val = config.get("strict", False) + strict = bool(strict_val) + + # Parse deny_vendors + deny_vendors_raw = config.get("deny_vendors") + deny_vendors: Optional[Set[str]] = None + if deny_vendors_raw: + if isinstance(deny_vendors_raw, list): + deny_vendors = {str(v).strip() for v in deny_vendors_raw if v} + elif isinstance(deny_vendors_raw, str): + deny_vendors = self._parse_csv_set(deny_vendors_raw) + + # Parse allow_vendors + allow_vendors_raw = config.get("allow_vendors") + allow_vendors: Optional[Set[str]] = None + if allow_vendors_raw: + if isinstance(allow_vendors_raw, list): + allow_vendors = {str(v).strip() for v in allow_vendors_raw if v} + elif isinstance(allow_vendors_raw, str): + allow_vendors = self._parse_csv_set(allow_vendors_raw) + + # Parse per_op + per_op_raw = config.get("per_op") + per_op_order: Optional[Dict[str, List[str]]] = None + if per_op_raw and isinstance(per_op_raw, dict): + per_op_order = {} + for op_name, order in per_op_raw.items(): + if isinstance(order, list): + per_op_order[str(op_name)] = [str(o).strip() for o in order if o] + elif isinstance(order, str): + # Support string format: "vendor:cuda|flaggems" + per_op_order[str(op_name)] = [ + o.strip() for o in order.split("|") if o.strip() + ] + + return SelectionPolicy.from_dict( + prefer=prefer_str, + strict=strict, + per_op_order=per_op_order, + deny_vendors=deny_vendors, + allow_vendors=allow_vendors, + ) + def _policy_from_env(self) -> SelectionPolicy: """ - Create a SelectionPolicy from environment variables. + Create a SelectionPolicy from configuration file or environment variables. + + Priority: + 1. VLLM_FL_CONFIG: Path to YAML config file (if set and file exists) + 2. Environment variables (VLLM_FL_PREFER, etc.) + 3. 
Default values Environment variables: + - VLLM_FL_CONFIG: Path to YAML configuration file - VLLM_FL_PREFER: Preference (flaggems, vendor, reference) - VLLM_FL_STRICT: Enable strict mode (1 or 0) - VLLM_FL_DENY_VENDORS: Comma-separated list of denied vendors - VLLM_FL_ALLOW_VENDORS: Comma-separated list of allowed vendors - VLLM_FL_PER_OP: Per-op order (format: op1=a|b|c;op2=x|y) """ + # Priority 1: Check for config file + config_path = os.environ.get("VLLM_FL_CONFIG", "").strip() + if config_path: + policy = self._policy_from_config(config_path) + if policy is not None: + return policy + + # Priority 2: Environment variables prefer_str = os.environ.get("VLLM_FL_PREFER", "").strip().lower() if prefer_str and prefer_str in VALID_PREFER_VALUES: pass @@ -324,10 +468,60 @@ def reset_global_policy() -> None: def policy_from_env() -> SelectionPolicy: - """Create a SelectionPolicy from environment variables.""" + """Create a SelectionPolicy from configuration file or environment variables.""" return PolicyManager.get_instance()._policy_from_env() +def policy_from_config(config_path: str) -> Optional[SelectionPolicy]: + """ + Create a SelectionPolicy from a YAML configuration file. + + Args: + config_path: Path to the YAML configuration file. + + Returns: + SelectionPolicy if successfully loaded, None if file doesn't exist + or YAML is not available. + + Example config file (YAML): + # Preferred backend type: flaggems, vendor, or reference + prefer: vendor + + # Strict mode: true = fail immediately on error, false = try next backend + strict: true + + # Vendor whitelist (optional) + allow_vendors: + - cuda + + # Vendor blacklist (optional) + deny_vendors: + - ascend + + # Per-operator backend selection order (optional) + # Only the backends listed will be tried, in the specified order. + # If you only list 2 options, only those 2 will be attempted. 
+ # + # Supported tokens: + # - flaggems : FlagGems default implementation + # - reference : PyTorch reference implementation + # - vendor : Any available vendor backend (auto-detect) + # - vendor:cuda : Only CUDA vendor backend + # - vendor:ascend : Only Ascend vendor backend + per_op: + rmsnorm: + - vendor # Try any available vendor first + - flaggems # Then try flaggems + # reference not listed, so it won't be used for rmsnorm + + silu_and_mul: + - vendor:cuda # Only try CUDA, not other vendors + - flaggems + - reference + """ + return PolicyManager.get_instance()._policy_from_config(config_path) + + def policy_context(policy: SelectionPolicy) -> _PolicyContext: """ Create a context manager to temporarily override the policy. From a02a3a012e59896c1a7d001edb68e858eeee14b0 Mon Sep 17 00:00:00 2001 From: yxa Date: Sat, 24 Jan 2026 03:59:56 +0000 Subject: [PATCH 16/34] Modify the code based on PR feedback. --- vllm_fl/dispatch/README.md | 2 +- vllm_fl/dispatch/__init__.py | 2 +- vllm_fl/dispatch/policy.py | 62 +++++++++++++++--------------------- 3 files changed, 27 insertions(+), 39 deletions(-) diff --git a/vllm_fl/dispatch/README.md b/vllm_fl/dispatch/README.md index add4d1e5..ad941ed4 100644 --- a/vllm_fl/dispatch/README.md +++ b/vllm_fl/dispatch/README.md @@ -237,7 +237,7 @@ deny_vendors: # - vendor : Any available vendor backend (auto-detect) # - vendor:cuda : Only CUDA vendor backend # - vendor:ascend : Only Ascend vendor backend -per_op: +op_backends: rmsnorm: - vendor # Try any available vendor first - flaggems # Then try flaggems diff --git a/vllm_fl/dispatch/__init__.py b/vllm_fl/dispatch/__init__.py index 9c48f079..bbe4568b 100644 --- a/vllm_fl/dispatch/__init__.py +++ b/vllm_fl/dispatch/__init__.py @@ -64,7 +64,7 @@ # - vendor : Any available vendor backend (auto-detect) # - vendor:cuda : Only CUDA vendor backend # - vendor:ascend : Only Ascend vendor backend - per_op: + op_backends: rmsnorm: - vendor # Try any available vendor first - flaggems # Then 
try flaggems diff --git a/vllm_fl/dispatch/policy.py b/vllm_fl/dispatch/policy.py index 6a21d27e..6684f0bb 100644 --- a/vllm_fl/dispatch/policy.py +++ b/vllm_fl/dispatch/policy.py @@ -7,16 +7,15 @@ from __future__ import annotations import contextvars +import logging import os import threading from dataclasses import dataclass, field from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple -try: - import yaml - YAML_AVAILABLE = True -except ImportError: - YAML_AVAILABLE = False +import yaml + +logger = logging.getLogger(__name__) # Valid preference values for VLLM_FL_PREFER @@ -245,7 +244,7 @@ def _parse_per_op(value: str) -> Dict[str, List[str]]: return result - def _policy_from_config(self, config_path: str) -> Optional[SelectionPolicy]: + def _policy_from_config(self, config_path: str) -> SelectionPolicy: """ Create a SelectionPolicy from a YAML configuration file. @@ -253,8 +252,11 @@ def _policy_from_config(self, config_path: str) -> Optional[SelectionPolicy]: config_path: Path to the YAML configuration file. Returns: - SelectionPolicy if successfully loaded, None if file doesn't exist - or YAML is not available. + SelectionPolicy loaded from the config file. + + Raises: + FileNotFoundError: If the config file does not exist. + ValueError: If the config file cannot be parsed. 
Config file format (YAML): # Preferred backend type: flaggems, vendor, or reference @@ -283,7 +285,7 @@ def _policy_from_config(self, config_path: str) -> Optional[SelectionPolicy]: # - vendor : Any available vendor backend (auto-detect) # - vendor:cuda : Only CUDA vendor backend # - vendor:ascend : Only Ascend vendor backend - per_op: + op_backends: rmsnorm: - vendor # Try any available vendor first - flaggems # Then try flaggems @@ -294,31 +296,14 @@ def _policy_from_config(self, config_path: str) -> Optional[SelectionPolicy]: - flaggems - reference """ - if not YAML_AVAILABLE: - import warnings - warnings.warn( - f"VLLM_FL_CONFIG is set to '{config_path}' but PyYAML is not installed. " - "Install it with: pip install pyyaml. Falling back to environment variables." - ) - return None - if not os.path.isfile(config_path): - import warnings - warnings.warn( - f"Config file '{config_path}' not found. Falling back to environment variables." - ) - return None + raise FileNotFoundError(f"Config file '{config_path}' not found.") try: with open(config_path, "r", encoding="utf-8") as f: config: Dict[str, Any] = yaml.safe_load(f) or {} except Exception as e: - import warnings - warnings.warn( - f"Failed to load config file '{config_path}': {e}. " - "Falling back to environment variables." 
- ) - return None + raise ValueError(f"Failed to load config file '{config_path}': {e}") from e # Parse prefer prefer_str = str(config.get("prefer", PREFER_DEFAULT)).strip().lower() @@ -347,8 +332,8 @@ def _policy_from_config(self, config_path: str) -> Optional[SelectionPolicy]: elif isinstance(allow_vendors_raw, str): allow_vendors = self._parse_csv_set(allow_vendors_raw) - # Parse per_op - per_op_raw = config.get("per_op") + # Parse op_backends + per_op_raw = config.get("op_backends") per_op_order: Optional[Dict[str, List[str]]] = None if per_op_raw and isinstance(per_op_raw, dict): per_op_order = {} @@ -361,6 +346,8 @@ def _policy_from_config(self, config_path: str) -> Optional[SelectionPolicy]: o.strip() for o in order.split("|") if o.strip() ] + logger.info("Using custom config from '%s'", config_path) + return SelectionPolicy.from_dict( prefer=prefer_str, strict=strict, @@ -389,9 +376,7 @@ def _policy_from_env(self) -> SelectionPolicy: # Priority 1: Check for config file config_path = os.environ.get("VLLM_FL_CONFIG", "").strip() if config_path: - policy = self._policy_from_config(config_path) - if policy is not None: - return policy + return self._policy_from_config(config_path) # Priority 2: Environment variables prefer_str = os.environ.get("VLLM_FL_PREFER", "").strip().lower() @@ -477,7 +462,7 @@ def policy_from_env() -> SelectionPolicy: return PolicyManager.get_instance()._policy_from_env() -def policy_from_config(config_path: str) -> Optional[SelectionPolicy]: +def policy_from_config(config_path: str) -> SelectionPolicy: """ Create a SelectionPolicy from a YAML configuration file. @@ -485,8 +470,11 @@ def policy_from_config(config_path: str) -> Optional[SelectionPolicy]: config_path: Path to the YAML configuration file. Returns: - SelectionPolicy if successfully loaded, None if file doesn't exist - or YAML is not available. + SelectionPolicy loaded from the config file. + + Raises: + FileNotFoundError: If the config file does not exist. 
+ ValueError: If the config file cannot be parsed. Example config file (YAML): # Preferred backend type: flaggems, vendor, or reference @@ -513,7 +501,7 @@ def policy_from_config(config_path: str) -> Optional[SelectionPolicy]: # - vendor : Any available vendor backend (auto-detect) # - vendor:cuda : Only CUDA vendor backend # - vendor:ascend : Only Ascend vendor backend - per_op: + op_backends: rmsnorm: - vendor # Try any available vendor first - flaggems # Then try flaggems From 7322bfeede447dd4bd4c84d66e572be11da37186 Mon Sep 17 00:00:00 2001 From: yxa Date: Sat, 24 Jan 2026 14:38:13 +0000 Subject: [PATCH 17/34] Cancel the use of attention_backend in flagems --- .../dispatch/backends/flaggems/register_ops.py | 17 +++++++++-------- vllm_fl/dispatch/policy.py | 7 +------ 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/vllm_fl/dispatch/backends/flaggems/register_ops.py b/vllm_fl/dispatch/backends/flaggems/register_ops.py index e516ba1f..b55d2db2 100644 --- a/vllm_fl/dispatch/backends/flaggems/register_ops.py +++ b/vllm_fl/dispatch/backends/flaggems/register_ops.py @@ -65,14 +65,15 @@ def register_builtins(registry) -> None: priority=BackendPriority.DEFAULT, ), # Attention Backend - OpImpl( - op_name="attention_backend", - impl_id="default.flaggems", - kind=BackendImplKind.DEFAULT, - fn=_bind_is_available(backend.attention_backend, is_avail), - vendor=None, - priority=BackendPriority.DEFAULT, - ), + # TODO: attention_backend 暂时禁用,待调试成功后重新启用 + # OpImpl( + # op_name="attention_backend", + # impl_id="default.flaggems", + # kind=BackendImplKind.DEFAULT, + # fn=_bind_is_available(backend.attention_backend, is_avail), + # vendor=None, + # priority=BackendPriority.DEFAULT, + # ), ] registry.register_many(impls) diff --git a/vllm_fl/dispatch/policy.py b/vllm_fl/dispatch/policy.py index 6684f0bb..f878aca2 100644 --- a/vllm_fl/dispatch/policy.py +++ b/vllm_fl/dispatch/policy.py @@ -393,13 +393,8 @@ def _policy_from_env(self) -> SelectionPolicy: allow_str = 
os.environ.get("VLLM_FL_ALLOW_VENDORS", "").strip() allow_vendors = self._parse_csv_set(allow_str) if allow_str else None - # TODO(xinan): remove this - per_op_order = {"attention_backend": ["vendor", "reference", "flaggems"]} - per_op_str = os.environ.get("VLLM_FL_PER_OP", "").strip() - env_per_op = self._parse_per_op(per_op_str) if per_op_str else None - if env_per_op: - per_op_order.update(env_per_op) + per_op_order = self._parse_per_op(per_op_str) if per_op_str else None return SelectionPolicy.from_dict( prefer=prefer_str, From 896a968799cbce0fff1cfdd307e7e48dc8d90c43 Mon Sep 17 00:00:00 2001 From: yxa Date: Sun, 25 Jan 2026 03:34:24 +0000 Subject: [PATCH 18/34] remove chinese --- vllm_fl/dispatch/backends/flaggems/register_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_fl/dispatch/backends/flaggems/register_ops.py b/vllm_fl/dispatch/backends/flaggems/register_ops.py index b55d2db2..c3cc0b11 100644 --- a/vllm_fl/dispatch/backends/flaggems/register_ops.py +++ b/vllm_fl/dispatch/backends/flaggems/register_ops.py @@ -65,7 +65,6 @@ def register_builtins(registry) -> None: priority=BackendPriority.DEFAULT, ), # Attention Backend - # TODO: attention_backend 暂时禁用,待调试成功后重新启用 # OpImpl( # op_name="attention_backend", # impl_id="default.flaggems", From 1d090d2bef3a66782dc1cf7284914449fa9a3486 Mon Sep 17 00:00:00 2001 From: yxa Date: Tue, 3 Feb 2026 06:22:18 +0000 Subject: [PATCH 19/34] Enable FlagGems attention backend with CUDA availability check --- vllm_fl/dispatch/backends/flaggems/flaggems.py | 14 ++++++++++++-- .../dispatch/backends/flaggems/register_ops.py | 16 ++++++++-------- vllm_fl/dispatch/backends/vendor/cuda/cuda.py | 4 ---- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/vllm_fl/dispatch/backends/flaggems/flaggems.py b/vllm_fl/dispatch/backends/flaggems/flaggems.py index e0ec1a04..557dce54 100644 --- a/vllm_fl/dispatch/backends/flaggems/flaggems.py +++ b/vllm_fl/dispatch/backends/flaggems/flaggems.py @@ -126,6 +126,16 @@ def 
attention_backend(self, use_mla: bool = False) -> str: Returns: Fully qualified class path string """ + from vllm.attention.backends.registry import AttentionBackendEnum + + # TritonAttentionBackend requires CUDA, check if available + if not torch.cuda.is_available(): + raise RuntimeError( + "TritonAttentionBackend requires CUDA but CUDA is not available. " + "Falling back to vendor implementation." + ) + if use_mla: - return "vllm_fl.dispatch.backends.flaggems.impl.mla.MLAFLBackend" - return "vllm_fl.dispatch.backends.flaggems.impl.attention.AttentionFLBackend" + raise NotImplementedError("NOT support mla now!") + + return AttentionBackendEnum.TRITON_ATTN.get_path() diff --git a/vllm_fl/dispatch/backends/flaggems/register_ops.py b/vllm_fl/dispatch/backends/flaggems/register_ops.py index c3cc0b11..e516ba1f 100644 --- a/vllm_fl/dispatch/backends/flaggems/register_ops.py +++ b/vllm_fl/dispatch/backends/flaggems/register_ops.py @@ -65,14 +65,14 @@ def register_builtins(registry) -> None: priority=BackendPriority.DEFAULT, ), # Attention Backend - # OpImpl( - # op_name="attention_backend", - # impl_id="default.flaggems", - # kind=BackendImplKind.DEFAULT, - # fn=_bind_is_available(backend.attention_backend, is_avail), - # vendor=None, - # priority=BackendPriority.DEFAULT, - # ), + OpImpl( + op_name="attention_backend", + impl_id="default.flaggems", + kind=BackendImplKind.DEFAULT, + fn=_bind_is_available(backend.attention_backend, is_avail), + vendor=None, + priority=BackendPriority.DEFAULT, + ), ] registry.register_many(impls) diff --git a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py index 453b46ac..396fa9e5 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py @@ -117,9 +117,5 @@ def attention_backend(self, use_mla: bool = False) -> str: if use_mla: return AttentionBackendEnum.MLA.get_path() - # Check for TRITON_ATTN preference via environment variable - if 
os.environ.get("USE_FLAGGEMS", "0") == "1": - return AttentionBackendEnum.TRITON_ATTN.get_path() - # Default to FLASH_ATTN return AttentionBackendEnum.FLASH_ATTN.get_path() From 3dd3242af0a496db9b3c24b74d24f5022d89c8da Mon Sep 17 00:00:00 2001 From: yxa Date: Wed, 4 Feb 2026 08:47:46 +0000 Subject: [PATCH 20/34] [New Feature] Add platform-specific operator config for Ascend/CUDA --- vllm_fl/dispatch/README.md | 257 ++++++++++++++---- .../backends/vendor/ascend/impl/rotary.py | 73 +++-- vllm_fl/dispatch/config/__init__.py | 221 +++++++++++++++ vllm_fl/dispatch/config/ascend.yaml | 106 ++++++++ vllm_fl/dispatch/config/cuda.yaml | 57 ++++ vllm_fl/dispatch/policy.py | 77 ++++-- vllm_fl/ops/custom_ops.py | 11 +- vllm_fl/platform.py | 51 +--- vllm_fl/utils.py | 185 ++++++++----- 9 files changed, 810 insertions(+), 228 deletions(-) create mode 100644 vllm_fl/dispatch/config/__init__.py create mode 100644 vllm_fl/dispatch/config/ascend.yaml create mode 100644 vllm_fl/dispatch/config/cuda.yaml diff --git a/vllm_fl/dispatch/README.md b/vllm_fl/dispatch/README.md index 27935f10..3cbc246a 100644 --- a/vllm_fl/dispatch/README.md +++ b/vllm_fl/dispatch/README.md @@ -15,6 +15,10 @@ dispatch/ ├── ops.py # Backend base interface ├── discovery.py # Plugin discovery mechanism ├── logger_manager.py # Centralized logging configuration +├── config/ # Platform-specific configurations +│ ├── __init__.py # Config loader module +│ ├── ascend.yaml # Ascend NPU default configuration +│ └── cuda.yaml # CUDA default configuration └── backends/ # Backend implementations ├── base.py # Backend abstract base class ├── flaggems/ # FlagGems backend (DEFAULT, priority 150) @@ -192,15 +196,49 @@ result = manager.call("silu_and_mul", x) ## Configuration -The dispatch system supports two ways to configure backend selection: -1. **Configuration file (YAML)** - Recommended for complex configurations -2. 
**Environment variables** - Simple, quick configuration +The dispatch system supports multiple ways to configure backend selection: +1. **User-specified configuration file (YAML)** - Complete override +2. **Environment variables** - Override specific items +3. **Platform-specific configuration file** - Auto-detected defaults +4. **Built-in default values** -**Priority**: Configuration file > Environment variables > Default values +### Configuration Priority + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Configuration Priority │ +│ (Highest to Lowest) │ +├─────────────────────────────────────────────────────────────────┤ +│ 1. VLLM_FL_CONFIG │ User config file, complete override │ +│ 2. Environment Variables │ Override specific items │ +│ 3. Platform Config File │ ascend.yaml / cuda.yaml defaults │ +│ 4. Built-in Defaults │ Code-defined default values │ +└─────────────────────────────────────────────────────────────────┘ +``` + +**Key Points:** +- Environment variables can override specific items from platform config +- If user doesn't set any environment variable, platform config is used +- Users can also modify platform config files directly -### Configuration File (YAML) +### Platform-Specific Configuration -Set the `VLLM_FL_CONFIG` environment variable to specify a YAML configuration file: +The system automatically detects hardware and loads the corresponding configuration file from `config/` directory: + +| Platform | Config File | Auto-Detection | +|----------|-------------|----------------| +| Ascend NPU | `config/ascend.yaml` | `torch.npu.is_available()` | +| NVIDIA GPU | `config/cuda.yaml` | `torch.cuda.is_available()` | + +You can force a specific platform using `VLLM_FL_PLATFORM` environment variable: +```bash +export VLLM_FL_PLATFORM=ascend # Force Ascend config +export VLLM_FL_PLATFORM=cuda # Force CUDA config +``` + +### User-Specified Configuration File (YAML) + +Set the `VLLM_FL_CONFIG` environment variable to specify 
a YAML configuration file that completely overrides all other settings: ```bash export VLLM_FL_CONFIG=/path/to/vllm_fl_dispatch.yaml @@ -229,14 +267,6 @@ deny_vendors: # Per-operator backend selection order (optional) # Only the backends listed will be tried, in the specified order. -# If you only list 2 options, only those 2 will be attempted. -# -# Supported tokens: -# - flagos : FlagOS default implementation -# - reference : PyTorch reference implementation -# - vendor : Any available vendor backend (auto-detect) -# - vendor:cuda : Only CUDA vendor backend -# - vendor:ascend : Only Ascend vendor backend op_backends: rms_norm: - vendor # Try any available vendor first @@ -247,6 +277,18 @@ op_backends: - vendor:cuda # Only try CUDA, not other vendors - flagos - reference + +# FlagGems operator blacklist (optional) +# These operators will NOT use FlagGems implementation +flaggems_blacklist: + - to_copy + - zeros + - mm + +# OOT operator blacklist (optional) +# These operators will NOT be registered as OOT replacements +oot_blacklist: + - fused_moe ``` #### Token Types Explained @@ -264,39 +306,92 @@ op_backends: ### Environment Variables -| Variable | Description | Example | Behavior | -|----------|-------------|---------|----------| -| `VLLM_FL_CONFIG` | Path to YAML config file | `/path/to/config.yaml` | Highest priority, overrides other env vars | -| `VLLM_FL_PREFER` | Preferred backend (sets selection order) | `flagos`, `vendor`, `reference` | Defines priority order, falls back if unavailable | -| `VLLM_FL_PREFER_ENABLED` | Global backend switch | `true` or `false` | Default `true`, `false` disables all backends to keep native vllm | -| `VLLM_FL_FLAGOS_WHITELIST` | FlagGems ops to enable | `silu_and_mul,rms_norm` | Only these ops are enabled | -| `VLLM_FL_FLAGOS_BLACKLIST` | FlagGems ops to disable | `rotary_embedding` | Disabled even if otherwise selected | -| `VLLM_FL_STRICT` | Enable strict mode (auto-fallback on failure) | `1` or `0` | When `1`, tries 
alternatives if primary fails | -| `VLLM_FL_DENY_VENDORS` | Denied vendors list (blacklist) | `vendor1,vendor2` | Excludes specified vendors from selection | -| `VLLM_FL_ALLOW_VENDORS` | Allowed vendors whitelist | `vendor1,vendor2` | Only allows specified vendors (if set) | -| `VLLM_FL_PER_OP` | Per-operator selection order | `op1=a\|b\|c;op2=x\|y` | Overrides default order for specific ops | -| `VLLM_FL_PLUGIN_MODULES` | Plugin modules to load | `my_plugin,another_plugin` | Loads external plugin modules | -| `VLLM_FL_LOG_LEVEL` | Log level | `DEBUG`, `INFO`, `WARNING`, `ERROR` | Controls logging verbosity | +Environment variables can override specific items from platform config. If not set, values from platform config file are used. + +#### Core Configuration + +| Variable | Default | Description | +|----------|---------|-------------| +| `VLLM_FL_PREFER_ENABLED` | `true` | Global switch. Set `false` to disable all dispatch features | +| `VLLM_FL_CONFIG` | (none) | Path to YAML config file (complete override) | +| `VLLM_FL_PLATFORM` | (auto) | Force platform: `ascend`, `cuda` | + +#### Backend Selection + +| Variable | Default | Description | +|----------|---------|-------------| +| `VLLM_FL_PREFER` | `flagos` | Preferred backend: `flagos`, `vendor`, `reference` | +| `VLLM_FL_STRICT` | `0` | Strict mode: `1` = fail on error, `0` = try fallback | +| `VLLM_FL_PER_OP` | (none) | Per-operator order: `op1=a\|b\|c;op2=x\|y` | +| `VLLM_FL_ALLOW_VENDORS` | (none) | Vendor whitelist, comma-separated | +| `VLLM_FL_DENY_VENDORS` | (none) | Vendor blacklist, comma-separated | + +#### FlagGems Control + +| Variable | Default | Description | +|----------|---------|-------------| +| `USE_FLAGGEMS` | `true` | Enable/disable FlagGems | +| `VLLM_FL_FLAGOS_WHITELIST` | (none) | FlagGems ops whitelist (mutually exclusive with blacklist) | +| `VLLM_FL_FLAGOS_BLACKLIST` | (none) | FlagGems ops blacklist (mutually exclusive with whitelist) | + +**Priority**: `WHITELIST` > `BLACKLIST` 
(env) > `flaggems_blacklist` (config file) + +#### OOT Operator Control + +| Variable | Default | Description | +|----------|---------|-------------| +| `VLLM_FL_OOT_ENABLED` | `1` | Enable OOT operator registration | +| `VLLM_FL_OOT_WHITELIST` | (none) | OOT ops whitelist | +| `VLLM_FL_OOT_BLACKLIST` | (none) | OOT ops blacklist | + +**Priority**: `WHITELIST` > `BLACKLIST` (env) > `oot_blacklist` (config file) + +#### Debug & Logging + +| Variable | Default | Description | +|----------|---------|-------------| +| `VLLM_FL_LOG_LEVEL` | `INFO` | Log level: `DEBUG`, `INFO`, `WARNING`, `ERROR` | +| `VLLM_FL_DISPATCH_DEBUG` | `0` | Enable dispatch debug mode | + +#### Plugins + +| Variable | Default | Description | +|----------|---------|-------------| +| `VLLM_FL_PLUGIN_MODULES` | (none) | External plugin modules, comma-separated | +| `VLLM_FL_OP_CONFIG` | (none) | Operator config JSON file path | + +#### Other + +| Variable | Default | Description | +|----------|---------|-------------| +| `FLAGCX_PATH` | (none) | FlagCX library path (enables FlagCX communication backend) | +| `FLAGGEMS_ENABLE_OPLIST_PATH` | `/tmp/flaggems_enable_oplist.txt` | FlagGems enabled ops list file | ### Examples ```bash -# Prefer FlagGems implementation -export VLLM_FL_PREFER=flagos +# Use platform default config (auto-detected) +# Nothing to set - just run your application + +# Override only the prefer setting (other items from platform config) +export VLLM_FL_PREFER=vendor -# Enable strict mode (auto-fallback on failure) -export VLLM_FL_STRICT=1 +# Override FlagGems blacklist (overrides config file blacklist) +export VLLM_FL_FLAGOS_BLACKLIST="mm,to_copy,zeros" -# Deny specific vendors -export VLLM_FL_DENY_VENDORS=vendor_a,vendor_b +# Use whitelist instead (completely ignores any blacklist) +export VLLM_FL_FLAGOS_WHITELIST="silu_and_mul,rms_norm" -# Specify selection order for specific operator +# Specify per-operator order export VLLM_FL_PER_OP="rms_norm=vendor|flagos|reference" -# Load 
external plugins -export VLLM_FL_PLUGIN_MODULES=my_custom_backend +# Use completely custom config file +export VLLM_FL_CONFIG=/path/to/my_config.yaml -# Set log level +# Force specific platform +export VLLM_FL_PLATFORM=ascend + +# Enable debug logging export VLLM_FL_LOG_LEVEL=DEBUG ``` @@ -312,39 +407,87 @@ op_backends: - reference ``` -### Configuration Priority +### Configuration Priority Details The dispatch system applies configuration in the following order: -1. **`VLLM_FL_CONFIG`** - Highest priority, YAML config file overrides all environment variables -2. **`VLLM_FL_PER_OP`** - Per-operator order overrides default order for specific operators -3. **`VLLM_FL_ALLOW_VENDORS`** - Whitelist filter (if set, only these vendors are allowed) -4. **`VLLM_FL_DENY_VENDORS`** - Blacklist filter (these vendors are excluded) -5. **`VLLM_FL_PREFER`** - Default selection order for all operators -6. **`BackendPriority`** - Code-defined priority (used for tie-breaking within same kind) +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ Configuration Resolution │ +├─────────────────────────────────────────────────────────────────────┤ +│ │ +│ VLLM_FL_CONFIG set? │ +│ │ │ +│ ├── Yes ──▶ Use user config file (complete override) │ +│ │ │ +│ └── No ──▶ For each setting: │ +│ │ │ +│ ├── Env var set? 
──▶ Use env var value │ +│ │ │ +│ └── Not set ──▶ Use platform config value │ +│ │ │ +│ └── Not found ──▶ Default│ +│ │ +└─────────────────────────────────────────────────────────────────────┘ +``` -**Priority values are spaced by 50 to allow future insertion of intermediate priorities:** -- `BackendPriority.DEFAULT` = 150 (FlagGems) -- `BackendPriority.VENDOR` = 100 (Vendor-specific) -- `BackendPriority.REFERENCE` = 50 (PyTorch) +#### Whitelist vs Blacklist Priority + +For FlagGems and OOT operators: + +``` +WHITELIST (env) ──▶ Completely overrides blacklist + │ + └── Not set ──▶ BLACKLIST (env) ──▶ Overrides config blacklist + │ + └── Not set ──▶ Config file blacklist + │ + └── Not set ──▶ Allow all +``` + +**Important Notes:** +- Whitelist and blacklist environment variables are mutually exclusive (error if both set) +- If whitelist is set, it completely ignores any blacklist (env or config) +- Environment blacklist overrides config file blacklist (not merged) #### Example: Combined Environment Variables ```bash -export VLLM_FL_PREFER=flagos # Default: flagos → vendor → reference -export VLLM_FL_DENY_VENDORS=vendor_a # Exclude vendor_a -export VLLM_FL_PER_OP="rms_norm=vendor|reference" # Override for rms_norm only +# Platform config (ascend.yaml) has: +# prefer: flagos +# flaggems_blacklist: [to_copy, zeros, mm, ...] + +# User overrides only prefer, blacklist still from config +export VLLM_FL_PREFER=vendor + +# Result: +# prefer: vendor (from env) +# flaggems_blacklist: [to_copy, zeros, mm, ...] 
(from config)
+```
 
-**Result:**
-- **`rms_norm` operator**: Uses `vendor → reference` order (PER_OP overrides PREFER), excluding vendor_a
-- **Other operators** (e.g., `silu_and_mul`): Uses `flagos → vendor → reference` order (from PREFER), excluding vendor_a
+```bash
+# User wants to override blacklist too
+export VLLM_FL_PREFER=vendor
+export VLLM_FL_FLAGOS_BLACKLIST="custom_op1,custom_op2"
+
+# Result:
+# prefer: vendor (from env)
+# flaggems_blacklist: [custom_op1, custom_op2] (from env, config ignored)
+```
 
 #### Important Notes
 
-- **`VLLM_FL_PREFER` sets preference, not exclusivity**: It defines the selection order but will fall back to other backends if the preferred one is unavailable.
-- **To force a specific backend**: Combine `PREFER` with `DENY_VENDORS` or use `PER_OP` to exclude unwanted backends.
-- **`VLLM_FL_STRICT=1`**: Enables automatic fallback when the primary implementation fails at runtime (not just unavailable).
+- **Environment variables override, not merge**: Setting an env var replaces the config value entirely
+- **`VLLM_FL_PREFER` sets preference, not exclusivity**: It defines the selection order but will fall back to other backends if the preferred one is unavailable
+- **To force a specific backend**: Combine `PREFER` with `DENY_VENDORS` or use `PER_OP` to exclude unwanted backends
+- **`VLLM_FL_STRICT=1`**: Fails immediately when the primary implementation errors, instead of trying fallback backends
+
+#### Backend Priority Values
+
+Priority values are spaced by 50 to allow future insertion of intermediate priorities:
+- `BackendPriority.DEFAULT` = 150 (FlagGems)
+- `BackendPriority.VENDOR` = 100 (Vendor-specific)
+- `BackendPriority.REFERENCE` = 50 (PyTorch)
 
 ## Policy Context Management
 
diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py
index 443dd562..d92cae19 100644
--- a/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py
+++ 
b/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py @@ -2,6 +2,7 @@ """ Ascend rotary embedding operator implementations. +Based on vllm-ascend official implementation. """ from __future__ import annotations @@ -22,12 +23,12 @@ def rotary_embedding_ascend( Apply rotary position embedding using Ascend NPU. Args: - query: Query tensor - key: Key tensor - cos: Cosine cache - sin: Sine cache - position_ids: Position indices - rotary_interleaved: Whether to use interleaved rotary + query: Query tensor [num_tokens, num_heads, rotary_dim] + key: Key tensor [num_tokens, num_kv_heads, rotary_dim] + cos: Cosine cache [max_seq_len, rotary_dim // 2] + sin: Sine cache [max_seq_len, rotary_dim // 2] + position_ids: Position indices [num_tokens] + rotary_interleaved: Whether to use interleaved rotary (False = neox style) inplace: Whether to modify tensors in-place Returns: @@ -35,42 +36,38 @@ def rotary_embedding_ascend( """ import torch_npu - # Get cos/sin for the positions - if position_ids.dim() == 1: - cos_selected = cos[position_ids] - sin_selected = sin[position_ids] - else: - cos_selected = cos[position_ids] - sin_selected = sin[position_ids] + # query/key shape: [num_tokens, num_heads, rotary_dim] + num_tokens = query.shape[0] + rotary_dim = query.shape[-1] - # Prepare cos/sin shape for npu_rotary_mul: [1, seq_len, 1, head_dim] - head_dim = query.shape[-1] - rotary_dim = cos_selected.shape[-1] + # Reconstruct cos_sin_cache from separate cos and sin + # cos/sin: [max_seq_len, rotary_dim // 2] + # cos_sin_cache: [max_seq_len, rotary_dim] where first half is cos, second half is sin + cos_sin_cache = torch.cat([cos, sin], dim=-1) - # Duplicate cos/sin if needed to match head_dim - if rotary_dim != head_dim: - cos_selected = torch.cat([cos_selected, cos_selected], dim=-1) - sin_selected = torch.cat([sin_selected, sin_selected], dim=-1) - - # Reshape cos/sin to [1, seq_len, 1, head_dim] - cos_selected = cos_selected.reshape(1, -1, 1, head_dim) - sin_selected = 
sin_selected.reshape(1, -1, 1, head_dim) - - # Reshape query/key to [1, seq_len, num_heads, head_dim] + # Save original shapes query_shape = query.shape key_shape = key.shape - if query.dim() == 3: - query = query.unsqueeze(0) - if key.dim() == 3: - key = key.unsqueeze(0) - - # Apply rotary embedding using NPU kernel - q_embed = torch_npu.npu_rotary_mul(query, cos_selected, sin_selected) - k_embed = torch_npu.npu_rotary_mul(key, cos_selected, sin_selected) - - # Restore original shape - q_embed = q_embed.view(query_shape) - k_embed = k_embed.view(key_shape) + # Flatten query/key for _npu_rotary_embedding: [num_tokens, num_heads * rotary_dim] + query_flat = query.contiguous().view(num_tokens, -1) + key_flat = key.contiguous().view(num_tokens, -1) + + # is_neox_style is the opposite of rotary_interleaved + is_neox_style = not rotary_interleaved + + # Apply rotary embedding using NPU kernel (in-place operation) + torch_npu._npu_rotary_embedding( + position_ids, + query_flat, + key_flat, + rotary_dim, # head_size = rotary_dim + cos_sin_cache, + is_neox_style, + ) + + # Restore original shapes + q_embed = query_flat.view(query_shape) + k_embed = key_flat.view(key_shape) return q_embed, k_embed diff --git a/vllm_fl/dispatch/config/__init__.py b/vllm_fl/dispatch/config/__init__.py new file mode 100644 index 00000000..578bee08 --- /dev/null +++ b/vllm_fl/dispatch/config/__init__.py @@ -0,0 +1,221 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Hardware-specific operator configuration loader. + +This module provides automatic loading of operator configurations based on +the detected hardware platform. + +Configuration Priority (highest to lowest): +1. VLLM_FL_CONFIG: User-specified config file path (complete override) +2. 
Environment variables: Override specific items from platform config + - VLLM_FL_PREFER: Backend preference (flagos, vendor, reference) + - VLLM_FL_STRICT: Strict mode (1 or 0) + - VLLM_FL_PER_OP: Per-operator backend order + - VLLM_FL_FLAGOS_BLACKLIST: FlagGems operator blacklist + - VLLM_FL_OOT_BLACKLIST: OOT operator blacklist +3. Platform-specific config file: Default values (auto-detected) +4. Built-in default values + +Supported platforms: +- ascend: Huawei Ascend NPU +- cuda: NVIDIA GPU +- (more platforms can be added) + +Configuration files are stored in this directory as YAML files: +- ascend.yaml +- cuda.yaml +""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import Any, Optional + +import yaml + +# Directory containing this file (config/) +_CONFIG_DIR = Path(__file__).parent + + +def get_platform_name() -> str: + """ + Detect the current hardware platform. + + Returns: + Platform name string: 'ascend', 'cuda', or 'unknown' + """ + try: + import torch + if hasattr(torch, 'npu') and torch.npu.is_available(): + return 'ascend' + if torch.cuda.is_available(): + return 'cuda' + except ImportError: + pass + + # Check environment variable override + platform_override = os.environ.get('VLLM_FL_PLATFORM', '').strip().lower() + if platform_override: + return platform_override + + return 'unknown' + + +def get_config_path(platform: Optional[str] = None) -> Optional[Path]: + """ + Get the configuration file path for the specified or detected platform. + + Args: + platform: Platform name. If None, auto-detect. + + Returns: + Path to the config file, or None if not found. 
+ """ + if platform is None: + platform = get_platform_name() + + # Try platform-specific config + config_file = _CONFIG_DIR / f"{platform}.yaml" + if config_file.exists(): + return config_file + + return None + + +def load_platform_config(platform: Optional[str] = None) -> Optional[dict[str, Any]]: + """ + Load the configuration for the specified or detected platform. + + Args: + platform: Platform name. If None, auto-detect. + + Returns: + Configuration dictionary, or None if no config found. + """ + config_path = get_config_path(platform) + if config_path is None: + return None + + try: + with open(config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + return config if isinstance(config, dict) else None + except Exception: + return None + + +def get_per_op_order(config: Optional[dict] = None) -> Optional[dict[str, list[str]]]: + """ + Extract per-op backend order from config. + + Args: + config: Configuration dict. If None, load from platform config. + + Returns: + Dict mapping op names to backend order lists. + """ + if config is None: + config = load_platform_config() + if config is None: + return None + + per_op = config.get('per_op', {}) + if not isinstance(per_op, dict): + return None + + result = {} + for op_name, backends in per_op.items(): + if isinstance(backends, list): + result[op_name] = [str(b) for b in backends] + elif isinstance(backends, str): + result[op_name] = [backends] + + return result if result else None + + +def get_flaggems_blacklist(config: Optional[dict] = None) -> Optional[list[str]]: + """ + Extract FlagGems operator blacklist from config. + + Args: + config: Configuration dict. If None, load from platform config. + + Returns: + List of blacklisted FlagGems operator names. 
+ """ + if config is None: + config = load_platform_config() + if config is None: + return None + + blacklist = config.get('flaggems_blacklist', []) + if isinstance(blacklist, list): + return [str(op) for op in blacklist] + return None + + +def get_oot_blacklist(config: Optional[dict] = None) -> Optional[list[str]]: + """ + Extract OOT operator blacklist from config. + + Args: + config: Configuration dict. If None, load from platform config. + + Returns: + List of blacklisted OOT operator names. + """ + if config is None: + config = load_platform_config() + if config is None: + return None + + blacklist = config.get('oot_blacklist', []) + if isinstance(blacklist, list): + return [str(op) for op in blacklist] + return None + + +def get_effective_config() -> dict[str, Any]: + """ + Get the effective configuration, considering environment variable overrides. + + Priority: + 1. VLLM_FL_CONFIG environment variable (user-specified config file) + 2. Platform-specific config file (auto-detected) + 3. Default config file + 4. Empty config (no restrictions) + + Returns: + Effective configuration dictionary. 
+ """ + # Check for user-specified config file + user_config_path = os.environ.get('VLLM_FL_CONFIG', '').strip() + if user_config_path and os.path.isfile(user_config_path): + try: + with open(user_config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + if isinstance(config, dict): + return config + except Exception: + pass + + # Load platform config + platform_config = load_platform_config() + if platform_config: + return platform_config + + # Return empty config + return {} + + +__all__ = [ + 'get_platform_name', + 'get_config_path', + 'load_platform_config', + 'get_per_op_order', + 'get_flaggems_blacklist', + 'get_oot_blacklist', + 'get_effective_config', +] diff --git a/vllm_fl/dispatch/config/ascend.yaml b/vllm_fl/dispatch/config/ascend.yaml new file mode 100644 index 00000000..5d7f7b54 --- /dev/null +++ b/vllm_fl/dispatch/config/ascend.yaml @@ -0,0 +1,106 @@ +# vLLM-FL Dispatch Configuration for Ascend NPU +# Auto-loaded when running on Ascend hardware + +# Preferred default backend type: flaggems, vendor, reference +prefer: flagos + +# Strict Mode: +# true = Raise an error immediately on failure; do not attempt other backends. +# false = Attempt the next available backend in sequence upon failure (Default). +strict: false + +# Vendor Whitelist (Optional, allows all if not set) +# allow_vendors: +# - ascend +# - cuda + +# Vendor Blacklist (Optional) +# deny_vendors: +# - cuda + +# Per-operator backend execution order (Optional) +# Only the backends listed here will be attempted, in the order specified. 
+#
+# Supported tokens:
+# - flagos : Default FlagGems implementation
+# - reference : PyTorch reference implementation
+# - vendor : Any available vendor backend (auto-detected)
+# - vendor:ascend : Ascend-specific vendor backend
+op_backends:
+  # attention_backend: prioritize vendor as FlagGems implementation is incomplete
+  attention_backend:
+    - vendor
+    - flagos
+    - reference
+
+  # rms_norm: prioritize vendor (torch_npu) implementation
+  rms_norm:
+    - vendor
+    - flagos
+    - reference
+
+  # silu_and_mul: prioritize flagos (FlagGems) implementation
+  silu_and_mul:
+    - flagos
+    - vendor
+    - reference
+
+  # rotary_embedding: prioritize vendor (torch_npu._npu_rotary_embedding)
+  rotary_embedding:
+    - vendor
+    - flagos
+    - reference
+
+# FlagGems operator blacklist
+# These operators will NOT use FlagGems implementation even if FlagGems is enabled.
+flaggems_blacklist:
+  - to_copy
+  - _to_copy
+  - zeros
+  - acos
+  - copy_
+  - fill_scalar_
+  - sum_dim
+  - exponential_
+  - mm
+  - resolve_neg
+  - resolve_conj
+  - eq_scalar
+  - floor_divide
+  - cumsum
+  - cumsum_out
+  - mul
+  - reciprocal
+  - repeat
+  - randn
+  - add
+  - ge_scalar
+  - sub
+  - bitwise_and
+  - bitwise_not
+  - slice_scatter
+  - fill_tensor_
+  - conv1d
+  - conv2d
+  - uniform_
+  - prod
+  - max
+  - amax
+  - cat
+  - stack
+  - flatten
+  - reshape
+  - view
+  - tensor
+  - rand
+  - rand_like
+  - gather
+  - _log_softmax
+  - _flash_attention_forward
+  - scatter
+  - scatter_
+
+# OOT (Out-of-Tree) operator blacklist
+# These operators will NOT be registered as OOT replacements.
+# oot_blacklist:
+#   - fused_moe
diff --git a/vllm_fl/dispatch/config/cuda.yaml b/vllm_fl/dispatch/config/cuda.yaml
new file mode 100644
index 00000000..6ad2d6dc
--- /dev/null
+++ b/vllm_fl/dispatch/config/cuda.yaml
@@ -0,0 +1,57 @@
+# vLLM-FL Dispatch Configuration for CUDA
+# Auto-loaded when running on NVIDIA GPU hardware
+
+# Preferred default backend type: flagos, vendor, reference
+prefer: flagos
+
+# Strict Mode:
+# true = Raise an error immediately on failure; do not attempt other backends.
+# false = Attempt the next available backend in sequence upon failure (Default).
+strict: false
+
+# Vendor Whitelist (Optional, allows all if not set)
+# allow_vendors:
+#   - cuda
+
+# Vendor Blacklist (Optional)
+# deny_vendors:
+#   - ascend
+
+# Per-operator backend execution order (Optional)
+# Only the backends listed here will be attempted, in the order specified.
+#
+# Supported tokens:
+# - flagos : Default FlagGems implementation (Triton)
+# - reference : PyTorch reference implementation
+# - vendor : Any available vendor backend (auto-detected)
+# - vendor:cuda : CUDA-specific vendor backend
op_backends:
+  # attention_backend: prioritize flagos (Triton attention)
+  attention_backend:
+    - flagos
+    - vendor
+    - reference
+
+  # rms_norm: prioritize flagos (Triton)
+  rms_norm:
+    - flagos
+    - vendor
+    - reference
+
+  # silu_and_mul: prioritize flagos
+  silu_and_mul:
+    - flagos
+    - vendor
+    - reference
+
+  # rotary_embedding: prioritize flagos
+  rotary_embedding:
+    - flagos
+    - vendor
+    - reference
+
+# FlagGems operator blacklist
+# flaggems_blacklist: []
+
+# OOT (Out-of-Tree) operator blacklist
+# oot_blacklist: []
diff --git a/vllm_fl/dispatch/policy.py b/vllm_fl/dispatch/policy.py
index cc2c534f..f1523e80 100644
--- a/vllm_fl/dispatch/policy.py
+++ b/vllm_fl/dispatch/policy.py
@@ -377,50 +377,87 @@ def _policy_from_env(self) -> SelectionPolicy:
         """
         Create a SelectionPolicy from configuration file or environment variables.
 
-        Priority:
-        1. 
VLLM_FL_CONFIG: Path to YAML config file (if set and file exists) - 2. Environment variables (VLLM_FL_PREFER, etc.) - 3. Default values + Priority (highest to lowest): + 1. VLLM_FL_CONFIG: Path to YAML config file (if set, completely overrides) + 2. Environment variables: Override specific items from platform config + 3. Platform-specific config file: Default values (auto-detected) + 4. Built-in default values Environment variables: - - VLLM_FL_CONFIG: Path to YAML configuration file + - VLLM_FL_CONFIG: Path to YAML configuration file (complete override) - VLLM_FL_PREFER: Preference (flagos, vendor, reference) - VLLM_FL_STRICT: Enable strict mode (1 or 0) - VLLM_FL_DENY_VENDORS: Comma-separated list of denied vendors - VLLM_FL_ALLOW_VENDORS: Comma-separated list of allowed vendors - VLLM_FL_PER_OP: Per-op order (format: op1=a|b|c;op2=x|y) """ - # Priority 1: Check for config file + # Priority 1: Check for user-specified config file (complete override) config_path = os.environ.get("VLLM_FL_CONFIG", "").strip() - if config_path: + if config_path and os.path.isfile(config_path): return self._policy_from_config(config_path) - # Priority 2: Environment variables - prefer_str = os.environ.get("VLLM_FL_PREFER", "").strip().lower() - if prefer_str and prefer_str in VALID_PREFER_VALUES: - pass + # Priority 3: Load platform-specific config as base defaults + from vllm_fl.dispatch.config import get_config_path + platform_config_path = get_config_path() + platform_policy = None + if platform_config_path: + try: + platform_policy = self._policy_from_config(str(platform_config_path)) + except Exception as e: + logger.warning("Failed to load platform config: %s", e) + + # Priority 2: Environment variables override platform config + # Get values from environment variables + env_prefer_str = os.environ.get("VLLM_FL_PREFER", "").strip().lower() + env_strict_str = os.environ.get("VLLM_FL_STRICT", "").strip() + env_deny_str = os.environ.get("VLLM_FL_DENY_VENDORS", "").strip() + 
env_allow_str = os.environ.get("VLLM_FL_ALLOW_VENDORS", "").strip() + env_per_op_str = os.environ.get("VLLM_FL_PER_OP", "").strip() + + # Determine final values: env var > platform config > default + if env_prefer_str and env_prefer_str in VALID_PREFER_VALUES: + prefer_str = env_prefer_str + elif platform_policy: + prefer_str = platform_policy.prefer else: prefer_str = PREFER_DEFAULT - strict = os.environ.get("VLLM_FL_STRICT", "0").strip() == "1" + if env_strict_str: + strict = env_strict_str == "1" + elif platform_policy: + strict = platform_policy.strict + else: + strict = False - deny_str = os.environ.get("VLLM_FL_DENY_VENDORS", "").strip() - deny_vendors = self._parse_csv_set(deny_str) if deny_str else None + if env_deny_str: + deny_vendors = self._parse_csv_set(env_deny_str) + elif platform_policy and platform_policy.deny_vendors: + deny_vendors = set(platform_policy.deny_vendors) + else: + deny_vendors = None - allow_str = os.environ.get("VLLM_FL_ALLOW_VENDORS", "").strip() - allow_vendors = self._parse_csv_set(allow_str) if allow_str else None + if env_allow_str: + allow_vendors = self._parse_csv_set(env_allow_str) + elif platform_policy and platform_policy.allow_vendors: + allow_vendors = set(platform_policy.allow_vendors) + else: + allow_vendors = None + # Per-op order: env var > op_config > platform config op_config = get_op_config() if op_config: - env_per_op = self._parse_op_config(op_config) + per_op_order = self._parse_op_config(op_config) + elif env_per_op_str: + per_op_order = self._parse_per_op(env_per_op_str) + elif platform_policy and platform_policy.per_op_order: + per_op_order = platform_policy.per_op_order_dict else: - per_op_str = os.environ.get("VLLM_FL_PER_OP", "").strip() - env_per_op = self._parse_per_op(per_op_str) if per_op_str else None + per_op_order = None return SelectionPolicy.from_dict( prefer=prefer_str, strict=strict, - per_op_order=env_per_op, + per_op_order=per_op_order, deny_vendors=deny_vendors, allow_vendors=allow_vendors, 
) diff --git a/vllm_fl/ops/custom_ops.py b/vllm_fl/ops/custom_ops.py index aa035bcf..e73c3153 100644 --- a/vllm_fl/ops/custom_ops.py +++ b/vllm_fl/ops/custom_ops.py @@ -37,13 +37,19 @@ def register_oot_ops(whitelist: Optional[List[str]] = None) -> None: whitelist: If provided, only register operators in this list. If None, check VLLM_FL_OOT_WHITELIST env var. If neither is set, register all operators. + + Operators in VLLM_FL_OOT_BLACKLIST or platform config oot_blacklist + will be excluded from registration. """ - from vllm_fl.utils import get_oot_whitelist, is_oot_enabled, use_flaggems_op + from vllm_fl.utils import get_oot_blacklist, get_oot_whitelist, is_oot_enabled, use_flaggems_op # Check if OOT registration is enabled if not is_oot_enabled(): return + # Get blacklist (from env var or platform config) + blacklist = get_oot_blacklist() or [] + # Determine which operators to register env_whitelist = get_oot_whitelist() if env_whitelist is not None: @@ -53,6 +59,9 @@ def register_oot_ops(whitelist: Optional[List[str]] = None) -> None: else: ops_to_register = list(OOT_OPS.keys()) + # Apply blacklist + ops_to_register = [op for op in ops_to_register if op not in blacklist] + for op_name in ops_to_register: if op_name not in OOT_OPS: logger.warning(f"OOT op '{op_name}' not found in OOT_OPS, skipping.") diff --git a/vllm_fl/platform.py b/vllm_fl/platform.py index f0c07104..8a7c2697 100644 --- a/vllm_fl/platform.py +++ b/vllm_fl/platform.py @@ -160,48 +160,15 @@ def get_attn_backend_cls( use_mla = attn_selector_config.use_mla - try: - backend_path = call_op("attention_backend", use_mla=use_mla) - - logger.info_once( - "Using attention backend via dispatch (use_mla=%s): %s", - use_mla, - backend_path, - scope="local", - ) - return backend_path - - except RuntimeError as e: - # Fallback: if dispatch fails, use device-type based selection - logger.warning( - "Dispatch mechanism failed for attention_backend, " - "falling back to device-type based selection: %s", - e, - ) 
- - if cls.device_type == "npu": - if use_mla: - backend_path = ( - "vllm_fl.dispatch.backends.flaggems.impl.mla.MLAFLBackend" - ) - else: - backend_path = "vllm_fl.dispatch.backends.flaggems.impl.attention.AttentionFLBackend" - else: - # For CUDA and other devices, use vLLM native backend - from vllm.attention.backends.registry import AttentionBackendEnum - - if use_mla: - backend_path = AttentionBackendEnum.MLA.get_path() - else: - backend_path = AttentionBackendEnum.FLASH_ATTN.get_path() - - logger.info_once( - "Using fallback attention backend (use_mla=%s): %s", - use_mla, - backend_path, - scope="local", - ) - return backend_path + backend_path = call_op("attention_backend", use_mla=use_mla) + + logger.info_once( + "Using attention backend via dispatch (use_mla=%s): %s", + use_mla, + backend_path, + scope="local", + ) + return backend_path @classmethod def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: diff --git a/vllm_fl/utils.py b/vllm_fl/utils.py index 2f570010..30c6491d 100644 --- a/vllm_fl/utils.py +++ b/vllm_fl/utils.py @@ -23,25 +23,92 @@ def use_flaggems(default: bool = True) -> bool: return value.lower() in ("true", "1") -def use_flaggems_op(op_name: str, default: bool = True) -> bool: - if not use_flaggems(default=default): - return False +def get_flag_gems_whitelist_blacklist() -> Tuple[ + Optional[list[str]], Optional[list[str]] +]: + """ + Get FlagGems operator whitelist and blacklist. + + Priority (highest to lowest): + 1. VLLM_FL_FLAGOS_WHITELIST env var: Only these ops use FlagGems + 2. VLLM_FL_FLAGOS_BLACKLIST env var: These ops don't use FlagGems + 3. Platform config flaggems_blacklist: Default blacklist from config file + + Note: VLLM_FL_FLAGOS_WHITELIST and VLLM_FL_FLAGOS_BLACKLIST cannot be set + simultaneously. If whitelist is set, it completely overrides any blacklist. + + Returns: + Tuple[Optional[list[str]], Optional[list[str]]]: + A tuple of (whitelist, blacklist). Each is None if not set. 
+ + Raises: + ValueError: If both VLLM_FL_FLAGOS_WHITELIST and VLLM_FL_FLAGOS_BLACKLIST + are set simultaneously. + """ whitelist_str = os.environ.get("VLLM_FL_FLAGOS_WHITELIST", "") blacklist_str = os.environ.get("VLLM_FL_FLAGOS_BLACKLIST", "") - if not whitelist_str and not blacklist_str: - return True - whitelist = {op.strip() for op in whitelist_str.split(",") if op.strip()} - blacklist = {op.strip() for op in blacklist_str.split(",") if op.strip()} - if op_name in whitelist and op_name in blacklist: + + # Check if both env vars are set (conflict) + if whitelist_str and blacklist_str: raise ValueError( - "VLLM_FL_FLAGOS_WHITELIST and VLLM_FL_FLAGOS_BLACKLIST both contain " - f"{op_name!r}. Please remove the conflict." + "Cannot set both VLLM_FL_FLAGOS_WHITELIST and VLLM_FL_FLAGOS_BLACKLIST " + "simultaneously. Please set only one of them." ) - if op_name in blacklist: + + whitelist = None + blacklist = None + + # Priority 1: Whitelist from env var (completely overrides blacklist) + if whitelist_str: + whitelist = [op.strip() for op in whitelist_str.split(",") if op.strip()] + return whitelist, None # Whitelist overrides any blacklist + + # Priority 2: Blacklist from env var + if blacklist_str: + blacklist = [op.strip() for op in blacklist_str.split(",") if op.strip()] + return None, blacklist + + # Priority 3: Blacklist from platform config + try: + from vllm_fl.dispatch.config import get_flaggems_blacklist + config_blacklist = get_flaggems_blacklist() + if config_blacklist: + blacklist = config_blacklist + except Exception: + pass + + return whitelist, blacklist + + +def use_flaggems_op(op_name: str, default: bool = True) -> bool: + """ + Check if FlagGems should be used for a specific operator. + + Priority (highest to lowest): + 1. VLLM_FL_FLAGOS_WHITELIST env var: Only these ops use FlagGems + 2. VLLM_FL_FLAGOS_BLACKLIST env var: These ops don't use FlagGems + 3. Platform config flaggems_blacklist: Default blacklist from config file + 4. 
Default: Use FlagGems for all ops + + Note: Whitelist and blacklist (env vars) cannot be set simultaneously. + If whitelist is set, it completely overrides the config file blacklist. + """ + if not use_flaggems(default=default): return False - if not whitelist: - return True - return op_name in whitelist + + # Get whitelist/blacklist with proper priority + whitelist, blacklist = get_flag_gems_whitelist_blacklist() + + # If whitelist is set, only allow ops in whitelist + if whitelist is not None: + return op_name in whitelist + + # If blacklist is set (from env or config), deny ops in blacklist + if blacklist is not None: + return op_name not in blacklist + + # Default: allow all ops + return True def _load_op_config_from_env() -> None: @@ -108,62 +175,6 @@ def get_supported_device(self): return True -def get_flag_gems_whitelist_blacklist() -> Tuple[ - Optional[list[str]], Optional[list[str]] -]: - """ - Get FlagGems operator whitelist and blacklist from environment variables. - - Reads VLLM_FL_FLAGOS_WHITELIST and VLLM_FL_FLAGOS_BLACKLIST environment variables, - parses comma-separated operator names, and returns them as lists. - - Note: VLLM_FL_FLAGOS_WHITELIST and VLLM_FL_FLAGOS_BLACKLIST cannot be set simultaneously. - If both are set, a ValueError will be raised. - - Returns: - Tuple[Optional[list[str]], Optional[list[str]]]: - A tuple of (whitelist, blacklist). Each is None if not set, - or a list of operator names (stripped of whitespace) if set. - - Raises: - ValueError: If both VLLM_FL_FLAGOS_WHITELIST and VLLM_FL_FLAGOS_BLACKLIST - are set simultaneously. 
- - Example: - >>> # Set whitelist only: - >>> # export VLLM_FL_FLAGOS_WHITELIST="silu_and_mul,rms_norm" - >>> whitelist, blacklist = get_flag_gems_whitelist_blacklist() - >>> # whitelist: ["silu_and_mul", "rms_norm"] - >>> # blacklist: None - - >>> # Set blacklist only: - >>> # export VLLM_FL_FLAGOS_BLACKLIST="index,index_put_" - >>> whitelist, blacklist = get_flag_gems_whitelist_blacklist() - >>> # whitelist: None - >>> # blacklist: ["index", "index_put_"] - """ - whitelist_str = os.environ.get("VLLM_FL_FLAGOS_WHITELIST", "") - blacklist_str = os.environ.get("VLLM_FL_FLAGOS_BLACKLIST", "") - - # Check if both are set - if whitelist_str and blacklist_str: - raise ValueError( - "Cannot set both VLLM_FL_FLAGOS_WHITELIST and VLLM_FL_FLAGOS_BLACKLIST " - "simultaneously. Please set only one of them." - ) - - whitelist = None - blacklist = None - - if whitelist_str: - whitelist = [op.strip() for op in whitelist_str.split(",") if op.strip()] - - if blacklist_str: - blacklist = [op.strip() for op in blacklist_str.split(",") if op.strip()] - - return whitelist, blacklist - - def get_flaggems_all_ops() -> list[str]: """ Get all FlagGems operator names from flag_gems._FULL_CONFIG. @@ -203,6 +214,40 @@ def get_oot_whitelist() -> Optional[list[str]]: return [op.strip() for op in whitelist_str.split(",") if op.strip()] +def get_oot_blacklist() -> Optional[list[str]]: + """ + Get OOT operator blacklist from environment variable or platform config. + + Priority (highest to lowest): + 1. VLLM_FL_OOT_WHITELIST env var: If set, blacklist is ignored + 2. VLLM_FL_OOT_BLACKLIST env var: These ops won't be registered + 3. Platform config oot_blacklist: Default blacklist from config file + + Returns: + List of OOT operator names to NOT register, or None if not set. 
+ """ + # If whitelist is set, blacklist is ignored + whitelist_str = os.environ.get("VLLM_FL_OOT_WHITELIST", "") + if whitelist_str: + return None + + # Priority 2: Blacklist from env var + blacklist_str = os.environ.get("VLLM_FL_OOT_BLACKLIST", "") + if blacklist_str: + return [op.strip() for op in blacklist_str.split(",") if op.strip()] + + # Priority 3: Blacklist from platform config + try: + from vllm_fl.dispatch.config import get_oot_blacklist as config_get_oot_blacklist + config_blacklist = config_get_oot_blacklist() + if config_blacklist: + return config_blacklist + except Exception: + pass + + return None + + def is_oot_enabled() -> bool: """ Check if OOT registration is enabled. From 97f6e39503a6138f2b08c78f4f12e3b041b42013 Mon Sep 17 00:00:00 2001 From: xin2an Date: Wed, 4 Feb 2026 18:59:26 +0800 Subject: [PATCH 21/34] Update copyright year from 2025 to 2026 --- vllm_fl/dispatch/config/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_fl/dispatch/config/__init__.py b/vllm_fl/dispatch/config/__init__.py index 578bee08..a999a5e4 100644 --- a/vllm_fl/dispatch/config/__init__.py +++ b/vllm_fl/dispatch/config/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 BAAI. All rights reserved. +# Copyright (c) 2026 BAAI. All rights reserved. """ Hardware-specific operator configuration loader. 
From 0fb006a556fe9179086ea45f338e1f732f343096 Mon Sep 17 00:00:00 2001 From: yxa Date: Thu, 5 Feb 2026 02:46:53 +0000 Subject: [PATCH 22/34] flaggems_blacklist > flagos_blacklist, add utils.py file --- vllm_fl/dispatch/README.md | 10 +- vllm_fl/dispatch/config/__init__.py | 215 ++-------------------------- vllm_fl/dispatch/config/ascend.yaml | 6 +- vllm_fl/dispatch/config/cuda.yaml | 4 +- vllm_fl/dispatch/config/utils.py | 210 +++++++++++++++++++++++++++ vllm_fl/utils.py | 8 +- 6 files changed, 235 insertions(+), 218 deletions(-) create mode 100644 vllm_fl/dispatch/config/utils.py diff --git a/vllm_fl/dispatch/README.md b/vllm_fl/dispatch/README.md index 3cbc246a..6a382901 100644 --- a/vllm_fl/dispatch/README.md +++ b/vllm_fl/dispatch/README.md @@ -280,7 +280,7 @@ op_backends: # FlagGems operator blacklist (optional) # These operators will NOT use FlagGems implementation -flaggems_blacklist: +flagos_blacklist: - to_copy - zeros - mm @@ -334,7 +334,7 @@ Environment variables can override specific items from platform config. If not s | `VLLM_FL_FLAGOS_WHITELIST` | (none) | FlagGems ops whitelist (mutually exclusive with blacklist) | | `VLLM_FL_FLAGOS_BLACKLIST` | (none) | FlagGems ops blacklist (mutually exclusive with whitelist) | -**Priority**: `WHITELIST` > `BLACKLIST` (env) > `flaggems_blacklist` (config file) +**Priority**: `WHITELIST` > `BLACKLIST` (env) > `flagos_blacklist` (config file) #### OOT Operator Control @@ -455,14 +455,14 @@ WHITELIST (env) ──▶ Completely overrides blacklist ```bash # Platform config (ascend.yaml) has: # prefer: flagos -# flaggems_blacklist: [to_copy, zeros, mm, ...] +# flagos_blacklist: [to_copy, zeros, mm, ...] # User overrides only prefer, blacklist still from config export VLLM_FL_PREFER=vendor # Result: # prefer: vendor (from env) -# flaggems_blacklist: [to_copy, zeros, mm, ...] (from config) +# flagos_blacklist: [to_copy, zeros, mm, ...] 
(from config) ``` ```bash @@ -472,7 +472,7 @@ export VLLM_FL_FLAGOS_BLACKLIST="custom_op1,custom_op2" # Result: # prefer: vendor (from env) -# flaggems_blacklist: [custom_op1, custom_op2] (from env, config ignored) +# flagos_blacklist: [custom_op1, custom_op2] (from env, config ignored) ``` #### Important Notes diff --git a/vllm_fl/dispatch/config/__init__.py b/vllm_fl/dispatch/config/__init__.py index a999a5e4..c60b1ce8 100644 --- a/vllm_fl/dispatch/config/__init__.py +++ b/vllm_fl/dispatch/config/__init__.py @@ -4,218 +4,25 @@ Hardware-specific operator configuration loader. This module provides automatic loading of operator configurations based on -the detected hardware platform. - -Configuration Priority (highest to lowest): -1. VLLM_FL_CONFIG: User-specified config file path (complete override) -2. Environment variables: Override specific items from platform config - - VLLM_FL_PREFER: Backend preference (flagos, vendor, reference) - - VLLM_FL_STRICT: Strict mode (1 or 0) - - VLLM_FL_PER_OP: Per-operator backend order - - VLLM_FL_FLAGOS_BLACKLIST: FlagGems operator blacklist - - VLLM_FL_OOT_BLACKLIST: OOT operator blacklist -3. Platform-specific config file: Default values (auto-detected) -4. Built-in default values - -Supported platforms: -- ascend: Huawei Ascend NPU -- cuda: NVIDIA GPU -- (more platforms can be added) - -Configuration files are stored in this directory as YAML files: -- ascend.yaml -- cuda.yaml +the detected hardware platform. See utils.py for implementation details. """ -from __future__ import annotations - -import os -from pathlib import Path -from typing import Any, Optional - -import yaml - -# Directory containing this file (config/) -_CONFIG_DIR = Path(__file__).parent - - -def get_platform_name() -> str: - """ - Detect the current hardware platform. 
- - Returns: - Platform name string: 'ascend', 'cuda', or 'unknown' - """ - try: - import torch - if hasattr(torch, 'npu') and torch.npu.is_available(): - return 'ascend' - if torch.cuda.is_available(): - return 'cuda' - except ImportError: - pass - - # Check environment variable override - platform_override = os.environ.get('VLLM_FL_PLATFORM', '').strip().lower() - if platform_override: - return platform_override - - return 'unknown' - - -def get_config_path(platform: Optional[str] = None) -> Optional[Path]: - """ - Get the configuration file path for the specified or detected platform. - - Args: - platform: Platform name. If None, auto-detect. - - Returns: - Path to the config file, or None if not found. - """ - if platform is None: - platform = get_platform_name() - - # Try platform-specific config - config_file = _CONFIG_DIR / f"{platform}.yaml" - if config_file.exists(): - return config_file - - return None - - -def load_platform_config(platform: Optional[str] = None) -> Optional[dict[str, Any]]: - """ - Load the configuration for the specified or detected platform. - - Args: - platform: Platform name. If None, auto-detect. - - Returns: - Configuration dictionary, or None if no config found. - """ - config_path = get_config_path(platform) - if config_path is None: - return None - - try: - with open(config_path, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - return config if isinstance(config, dict) else None - except Exception: - return None - - -def get_per_op_order(config: Optional[dict] = None) -> Optional[dict[str, list[str]]]: - """ - Extract per-op backend order from config. - - Args: - config: Configuration dict. If None, load from platform config. - - Returns: - Dict mapping op names to backend order lists. 
- """ - if config is None: - config = load_platform_config() - if config is None: - return None - - per_op = config.get('per_op', {}) - if not isinstance(per_op, dict): - return None - - result = {} - for op_name, backends in per_op.items(): - if isinstance(backends, list): - result[op_name] = [str(b) for b in backends] - elif isinstance(backends, str): - result[op_name] = [backends] - - return result if result else None - - -def get_flaggems_blacklist(config: Optional[dict] = None) -> Optional[list[str]]: - """ - Extract FlagGems operator blacklist from config. - - Args: - config: Configuration dict. If None, load from platform config. - - Returns: - List of blacklisted FlagGems operator names. - """ - if config is None: - config = load_platform_config() - if config is None: - return None - - blacklist = config.get('flaggems_blacklist', []) - if isinstance(blacklist, list): - return [str(op) for op in blacklist] - return None - - -def get_oot_blacklist(config: Optional[dict] = None) -> Optional[list[str]]: - """ - Extract OOT operator blacklist from config. - - Args: - config: Configuration dict. If None, load from platform config. - - Returns: - List of blacklisted OOT operator names. - """ - if config is None: - config = load_platform_config() - if config is None: - return None - - blacklist = config.get('oot_blacklist', []) - if isinstance(blacklist, list): - return [str(op) for op in blacklist] - return None - - -def get_effective_config() -> dict[str, Any]: - """ - Get the effective configuration, considering environment variable overrides. - - Priority: - 1. VLLM_FL_CONFIG environment variable (user-specified config file) - 2. Platform-specific config file (auto-detected) - 3. Default config file - 4. Empty config (no restrictions) - - Returns: - Effective configuration dictionary. 
- """ - # Check for user-specified config file - user_config_path = os.environ.get('VLLM_FL_CONFIG', '').strip() - if user_config_path and os.path.isfile(user_config_path): - try: - with open(user_config_path, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - if isinstance(config, dict): - return config - except Exception: - pass - - # Load platform config - platform_config = load_platform_config() - if platform_config: - return platform_config - - # Return empty config - return {} - +from vllm_fl.dispatch.config.utils import ( + get_config_path, + get_effective_config, + get_flagos_blacklist, + get_oot_blacklist, + get_per_op_order, + get_platform_name, + load_platform_config, +) __all__ = [ 'get_platform_name', 'get_config_path', 'load_platform_config', 'get_per_op_order', - 'get_flaggems_blacklist', + 'get_flagos_blacklist', 'get_oot_blacklist', 'get_effective_config', ] diff --git a/vllm_fl/dispatch/config/ascend.yaml b/vllm_fl/dispatch/config/ascend.yaml index 5d7f7b54..2caa51f2 100644 --- a/vllm_fl/dispatch/config/ascend.yaml +++ b/vllm_fl/dispatch/config/ascend.yaml @@ -51,9 +51,9 @@ op_backends: - flagos - reference -# FlagGems operator blacklist -# These operators will NOT use FlagGems implementation even if FlagGems is enabled. -flaggems_blacklist: +# FlagOS operator blacklist +# These operators will NOT use FlagOS implementation even if FlagOS is enabled. 
+flagos_blacklist: - to_copy - _to_copy - zeros diff --git a/vllm_fl/dispatch/config/cuda.yaml b/vllm_fl/dispatch/config/cuda.yaml index 6ad2d6dc..3c8676ff 100644 --- a/vllm_fl/dispatch/config/cuda.yaml +++ b/vllm_fl/dispatch/config/cuda.yaml @@ -50,8 +50,8 @@ op_backends: - vendor - reference -# FlagGems operator blacklist -# flaggems_blacklist: [] +# FlagOS operator blacklist +# flagos_blacklist: [] # OOT (Out-of-Tree) operator blacklist # oot_blacklist: [] diff --git a/vllm_fl/dispatch/config/utils.py b/vllm_fl/dispatch/config/utils.py new file mode 100644 index 00000000..a0e41db6 --- /dev/null +++ b/vllm_fl/dispatch/config/utils.py @@ -0,0 +1,210 @@ +# Copyright (c) 2026 BAAI. All rights reserved. + +""" +Hardware-specific operator configuration loader utilities. + +This module provides automatic loading of operator configurations based on +the detected hardware platform. + +Configuration Priority (highest to lowest): +1. VLLM_FL_CONFIG: User-specified config file path (complete override) +2. Environment variables: Override specific items from platform config + - VLLM_FL_PREFER: Backend preference (flagos, vendor, reference) + - VLLM_FL_STRICT: Strict mode (1 or 0) + - VLLM_FL_PER_OP: Per-operator backend order + - VLLM_FL_FLAGOS_BLACKLIST: FlagOS operator blacklist + - VLLM_FL_OOT_BLACKLIST: OOT operator blacklist +3. Platform-specific config file: Default values (auto-detected) +4. Built-in default values + +Supported platforms: +- ascend: Huawei Ascend NPU +- cuda: NVIDIA GPU +- (more platforms can be added) + +Configuration files are stored in this directory as YAML files: +- ascend.yaml +- cuda.yaml +""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import Any, Optional + +import yaml + +# Directory containing config files (config/) +_CONFIG_DIR = Path(__file__).parent + + +def get_platform_name() -> str: + """ + Detect the current hardware platform. 
+ + Returns: + Platform name string: 'ascend', 'cuda', or 'unknown' + """ + try: + import torch + if hasattr(torch, 'npu') and torch.npu.is_available(): + return 'ascend' + if torch.cuda.is_available(): + return 'cuda' + except ImportError: + pass + + # Check environment variable override + platform_override = os.environ.get('VLLM_FL_PLATFORM', '').strip().lower() + if platform_override: + return platform_override + + return 'unknown' + + +def get_config_path(platform: Optional[str] = None) -> Optional[Path]: + """ + Get the configuration file path for the specified or detected platform. + + Args: + platform: Platform name. If None, auto-detect. + + Returns: + Path to the config file, or None if not found. + """ + if platform is None: + platform = get_platform_name() + + # Try platform-specific config + config_file = _CONFIG_DIR / f"{platform}.yaml" + if config_file.exists(): + return config_file + + return None + + +def load_platform_config(platform: Optional[str] = None) -> Optional[dict[str, Any]]: + """ + Load the configuration for the specified or detected platform. + + Args: + platform: Platform name. If None, auto-detect. + + Returns: + Configuration dictionary, or None if no config found. + """ + config_path = get_config_path(platform) + if config_path is None: + return None + + try: + with open(config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + return config if isinstance(config, dict) else None + except Exception: + return None + + +def get_per_op_order(config: Optional[dict] = None) -> Optional[dict[str, list[str]]]: + """ + Extract per-op backend order from config. + + Args: + config: Configuration dict. If None, load from platform config. + + Returns: + Dict mapping op names to backend order lists. 
+ """ + if config is None: + config = load_platform_config() + if config is None: + return None + + per_op = config.get('per_op', {}) + if not isinstance(per_op, dict): + return None + + result = {} + for op_name, backends in per_op.items(): + if isinstance(backends, list): + result[op_name] = [str(b) for b in backends] + elif isinstance(backends, str): + result[op_name] = [backends] + + return result if result else None + + +def get_flagos_blacklist(config: Optional[dict] = None) -> Optional[list[str]]: + """ + Extract FlagOS operator blacklist from config. + + Args: + config: Configuration dict. If None, load from platform config. + + Returns: + List of blacklisted FlagOS operator names. + """ + if config is None: + config = load_platform_config() + if config is None: + return None + + blacklist = config.get('flagos_blacklist', []) + if isinstance(blacklist, list): + return [str(op) for op in blacklist] + return None + + +def get_oot_blacklist(config: Optional[dict] = None) -> Optional[list[str]]: + """ + Extract OOT operator blacklist from config. + + Args: + config: Configuration dict. If None, load from platform config. + + Returns: + List of blacklisted OOT operator names. + """ + if config is None: + config = load_platform_config() + if config is None: + return None + + blacklist = config.get('oot_blacklist', []) + if isinstance(blacklist, list): + return [str(op) for op in blacklist] + return None + + +def get_effective_config() -> dict[str, Any]: + """ + Get the effective configuration, considering environment variable overrides. + + Priority: + 1. VLLM_FL_CONFIG environment variable (user-specified config file) + 2. Platform-specific config file (auto-detected) + 3. Default config file + 4. Empty config (no restrictions) + + Returns: + Effective configuration dictionary. 
+ """ + # Check for user-specified config file + user_config_path = os.environ.get('VLLM_FL_CONFIG', '').strip() + if user_config_path and os.path.isfile(user_config_path): + try: + with open(user_config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + if isinstance(config, dict): + return config + except Exception: + pass + + # Load platform config + platform_config = load_platform_config() + if platform_config: + return platform_config + + # Return empty config + return {} diff --git a/vllm_fl/utils.py b/vllm_fl/utils.py index 30c6491d..87bca810 100644 --- a/vllm_fl/utils.py +++ b/vllm_fl/utils.py @@ -32,7 +32,7 @@ def get_flag_gems_whitelist_blacklist() -> Tuple[ Priority (highest to lowest): 1. VLLM_FL_FLAGOS_WHITELIST env var: Only these ops use FlagGems 2. VLLM_FL_FLAGOS_BLACKLIST env var: These ops don't use FlagGems - 3. Platform config flaggems_blacklist: Default blacklist from config file + 3. Platform config flagos_blacklist: Default blacklist from config file Note: VLLM_FL_FLAGOS_WHITELIST and VLLM_FL_FLAGOS_BLACKLIST cannot be set simultaneously. If whitelist is set, it completely overrides any blacklist. @@ -70,8 +70,8 @@ def get_flag_gems_whitelist_blacklist() -> Tuple[ # Priority 3: Blacklist from platform config try: - from vllm_fl.dispatch.config import get_flaggems_blacklist - config_blacklist = get_flaggems_blacklist() + from vllm_fl.dispatch.config import get_flagos_blacklist + config_blacklist = get_flagos_blacklist() if config_blacklist: blacklist = config_blacklist except Exception: @@ -87,7 +87,7 @@ def use_flaggems_op(op_name: str, default: bool = True) -> bool: Priority (highest to lowest): 1. VLLM_FL_FLAGOS_WHITELIST env var: Only these ops use FlagGems 2. VLLM_FL_FLAGOS_BLACKLIST env var: These ops don't use FlagGems - 3. Platform config flaggems_blacklist: Default blacklist from config file + 3. Platform config flagos_blacklist: Default blacklist from config file 4. 
Default: Use FlagGems for all ops Note: Whitelist and blacklist (env vars) cannot be set simultaneously. From 57f596d3adca277cfb69a79d484e4d58c5c782eb Mon Sep 17 00:00:00 2001 From: yxa Date: Thu, 5 Feb 2026 02:57:34 +0000 Subject: [PATCH 23/34] Delete TRITON_ATTN --- vllm_fl/dispatch/backends/vendor/cuda/cuda.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py index 3a31c06b..ee698b43 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py @@ -117,8 +117,5 @@ def attention_backend(self, use_mla: bool = False) -> str: if use_mla: return AttentionBackendEnum.MLA.get_path() - # Use TRITON_ATTN when use_flaggems_op allows (e.g. USE_FLAGGEMS=1 / whitelist) - if use_flaggems_op("triton_attn"): - return AttentionBackendEnum.TRITON_ATTN.get_path() # Default to FLASH_ATTN return AttentionBackendEnum.FLASH_ATTN.get_path() From 511c8c511ab4f7838f620b4033f1837ac57114b9 Mon Sep 17 00:00:00 2001 From: xin2an Date: Tue, 10 Feb 2026 15:15:14 +0800 Subject: [PATCH 24/34] Fix CUDA backend vendor detection to exclude CUDA-alike devices (MACA, MUSA, etc.) 
--- .../dispatch/backends/flaggems/flaggems.py | 21 +++--- .../backends/flaggems/impl/activation.py | 3 +- .../backends/flaggems/impl/normalization.py | 12 ++-- .../dispatch/backends/flaggems/impl/rotary.py | 2 + .../backends/reference/impl/activation.py | 3 +- .../backends/reference/impl/normalization.py | 12 ++-- .../backends/reference/impl/rotary.py | 2 + .../dispatch/backends/reference/reference.py | 21 +++--- .../dispatch/backends/vendor/ascend/ascend.py | 21 +++--- .../backends/vendor/ascend/impl/activation.py | 3 +- .../vendor/ascend/impl/normalization.py | 12 ++-- .../backends/vendor/ascend/impl/rotary.py | 2 + vllm_fl/dispatch/backends/vendor/cuda/cuda.py | 69 ++++++++++++++++--- .../backends/vendor/cuda/impl/activation.py | 3 +- .../vendor/cuda/impl/normalization.py | 12 ++-- .../backends/vendor/cuda/impl/rotary.py | 2 + vllm_fl/dispatch/config/cuda.yaml | 38 +++++++++- vllm_fl/ops/activation.py | 2 +- vllm_fl/ops/layernorm.py | 2 +- vllm_fl/ops/rotary_embedding.py | 1 + 20 files changed, 180 insertions(+), 63 deletions(-) diff --git a/vllm_fl/dispatch/backends/flaggems/flaggems.py b/vllm_fl/dispatch/backends/flaggems/flaggems.py index 291b6550..53ba9f00 100644 --- a/vllm_fl/dispatch/backends/flaggems/flaggems.py +++ b/vllm_fl/dispatch/backends/flaggems/flaggems.py @@ -42,11 +42,12 @@ def is_available(self) -> bool: # ==================== Operator Implementations ==================== - def silu_and_mul(self, x: torch.Tensor) -> torch.Tensor: + def silu_and_mul(self, instance, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication. 
Args: + instance: The calling instance (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: @@ -54,33 +55,32 @@ def silu_and_mul(self, x: torch.Tensor) -> torch.Tensor: """ from .impl.activation import silu_and_mul_flaggems - return silu_and_mul_flaggems(x) + return silu_and_mul_flaggems(instance, x) def rms_norm( self, + instance, x: torch.Tensor, - residual: Optional[torch.Tensor], - weight: torch.Tensor, - epsilon: float, + residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """ RMS normalization. Args: + instance: The calling instance (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor - weight: Normalization weight - epsilon: Small constant for numerical stability Returns: Normalized tensor, or tuple of (normalized, residual) if residual is provided """ from .impl.normalization import rms_norm_flaggems - return rms_norm_flaggems(x, residual, weight, epsilon) + return rms_norm_flaggems(instance, x, residual) def rotary_embedding( self, + instance, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -93,6 +93,7 @@ def rotary_embedding( Apply rotary position embedding. Args: + instance: The calling instance (for interface consistency) query: Query tensor key: Key tensor cos: Cosine cache @@ -107,6 +108,7 @@ def rotary_embedding( from .impl.rotary import rotary_embedding_flaggems return rotary_embedding_flaggems( + instance, query, key, cos, @@ -116,11 +118,12 @@ def rotary_embedding( inplace=inplace, ) - def attention_backend(self, use_mla: bool = False) -> str: + def attention_backend(self, instance, use_mla: bool = False) -> str: """ Get the attention backend class path for FlagGems. 
Args: + instance: The calling instance (for interface consistency) use_mla: Whether to use Multi-head Latent Attention (MLA) Returns: diff --git a/vllm_fl/dispatch/backends/flaggems/impl/activation.py b/vllm_fl/dispatch/backends/flaggems/impl/activation.py index 96672819..db226f5b 100644 --- a/vllm_fl/dispatch/backends/flaggems/impl/activation.py +++ b/vllm_fl/dispatch/backends/flaggems/impl/activation.py @@ -9,11 +9,12 @@ import torch -def silu_and_mul_flaggems(x: torch.Tensor) -> torch.Tensor: +def silu_and_mul_flaggems(instance, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication using FlagGems. Args: + instance: The calling instance (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: diff --git a/vllm_fl/dispatch/backends/flaggems/impl/normalization.py b/vllm_fl/dispatch/backends/flaggems/impl/normalization.py index eb0927ad..2b09df43 100644 --- a/vllm_fl/dispatch/backends/flaggems/impl/normalization.py +++ b/vllm_fl/dispatch/backends/flaggems/impl/normalization.py @@ -12,23 +12,25 @@ def rms_norm_flaggems( + instance, x: torch.Tensor, - residual: Optional[torch.Tensor], - weight: torch.Tensor, - epsilon: float, + residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """ RMS normalization using FlagGems. 
Args: + instance: The calling instance (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor - weight: Normalization weight - epsilon: Small constant for numerical stability Returns: Normalized tensor, or tuple of (normalized, residual) if residual is provided """ from flag_gems.modules.normalization import gems_rms_forward + # Get weight and epsilon from instance + weight = instance.weight + epsilon = instance.variance_epsilon + return gems_rms_forward(x, residual, weight, epsilon) diff --git a/vllm_fl/dispatch/backends/flaggems/impl/rotary.py b/vllm_fl/dispatch/backends/flaggems/impl/rotary.py index e7c8ddd8..138c143a 100644 --- a/vllm_fl/dispatch/backends/flaggems/impl/rotary.py +++ b/vllm_fl/dispatch/backends/flaggems/impl/rotary.py @@ -10,6 +10,7 @@ def rotary_embedding_flaggems( + instance, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -22,6 +23,7 @@ def rotary_embedding_flaggems( Apply rotary position embedding using FlagGems. Args: + instance: The calling instance (for interface consistency) query: Query tensor key: Key tensor cos: Cosine cache diff --git a/vllm_fl/dispatch/backends/reference/impl/activation.py b/vllm_fl/dispatch/backends/reference/impl/activation.py index 6dd89289..a3faa5f8 100644 --- a/vllm_fl/dispatch/backends/reference/impl/activation.py +++ b/vllm_fl/dispatch/backends/reference/impl/activation.py @@ -10,11 +10,12 @@ import torch.nn.functional as F -def silu_and_mul_torch(x: torch.Tensor) -> torch.Tensor: +def silu_and_mul_torch(instance, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication using PyTorch. 
Args: + instance: The calling instance (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: diff --git a/vllm_fl/dispatch/backends/reference/impl/normalization.py b/vllm_fl/dispatch/backends/reference/impl/normalization.py index 130fdaab..7ff3c8cd 100644 --- a/vllm_fl/dispatch/backends/reference/impl/normalization.py +++ b/vllm_fl/dispatch/backends/reference/impl/normalization.py @@ -12,23 +12,25 @@ def rms_norm_torch( + instance, x: torch.Tensor, - residual: Optional[torch.Tensor], - weight: torch.Tensor, - epsilon: float, + residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """ RMS normalization using PyTorch. Args: + instance: The calling instance (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor - weight: Normalization weight - epsilon: Small constant for numerical stability Returns: Normalized tensor, or tuple of (normalized, residual) if residual is provided """ + # Get weight and epsilon from instance + weight = instance.weight + epsilon = instance.variance_epsilon + if residual is not None: x = x + residual residual = x diff --git a/vllm_fl/dispatch/backends/reference/impl/rotary.py b/vllm_fl/dispatch/backends/reference/impl/rotary.py index 64afb0ce..707585a7 100644 --- a/vllm_fl/dispatch/backends/reference/impl/rotary.py +++ b/vllm_fl/dispatch/backends/reference/impl/rotary.py @@ -10,6 +10,7 @@ def rotary_embedding_torch( + instance, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -22,6 +23,7 @@ def rotary_embedding_torch( Apply rotary position embedding using PyTorch. 
Args: + instance: The calling instance (for interface consistency) query: Query tensor [batch, num_heads, seq_len, head_dim] or [seq_len, num_heads, head_dim] key: Key tensor [batch, num_heads, seq_len, head_dim] or [seq_len, num_heads, head_dim] cos: Cosine cache [max_seq_len, rotary_dim] where rotary_dim = head_dim or head_dim // 2 diff --git a/vllm_fl/dispatch/backends/reference/reference.py b/vllm_fl/dispatch/backends/reference/reference.py index 69b87a5c..228b09f0 100644 --- a/vllm_fl/dispatch/backends/reference/reference.py +++ b/vllm_fl/dispatch/backends/reference/reference.py @@ -44,11 +44,12 @@ def is_available(self) -> bool: # ==================== Operator Implementations ==================== - def silu_and_mul(self, x: torch.Tensor) -> torch.Tensor: + def silu_and_mul(self, instance, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication. Args: + instance: The calling instance (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: @@ -56,33 +57,32 @@ def silu_and_mul(self, x: torch.Tensor) -> torch.Tensor: """ from .impl.activation import silu_and_mul_torch - return silu_and_mul_torch(x) + return silu_and_mul_torch(instance, x) def rms_norm( self, + instance, x: torch.Tensor, - residual: Optional[torch.Tensor], - weight: torch.Tensor, - epsilon: float, + residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """ RMS normalization. 
Args: + instance: The calling instance (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor - weight: Normalization weight - epsilon: Small constant for numerical stability Returns: Normalized tensor, or tuple of (normalized, residual) if residual is provided """ from .impl.normalization import rms_norm_torch - return rms_norm_torch(x, residual, weight, epsilon) + return rms_norm_torch(instance, x, residual) def rotary_embedding( self, + instance, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -95,6 +95,7 @@ def rotary_embedding( Apply rotary position embedding. Args: + instance: The calling instance (for interface consistency) query: Query tensor key: Key tensor cos: Cosine cache @@ -109,6 +110,7 @@ def rotary_embedding( from .impl.rotary import rotary_embedding_torch return rotary_embedding_torch( + instance, query, key, cos, @@ -118,7 +120,7 @@ def rotary_embedding( inplace=inplace, ) - def attention_backend(self, use_mla: bool = False) -> str: + def attention_backend(self, instance, use_mla: bool = False) -> str: """ Get the attention backend class path for reference (vLLM native). @@ -126,6 +128,7 @@ def attention_backend(self, use_mla: bool = False) -> str: which serves as a fallback implementation. Args: + instance: The calling instance (for interface consistency) use_mla: Whether to use Multi-head Latent Attention (MLA) Returns: diff --git a/vllm_fl/dispatch/backends/vendor/ascend/ascend.py b/vllm_fl/dispatch/backends/vendor/ascend/ascend.py index 8c41222e..fc838bd9 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/ascend.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/ascend.py @@ -51,11 +51,12 @@ def is_available(self) -> bool: # ==================== Operator Implementations ==================== - def silu_and_mul(self, x: torch.Tensor) -> torch.Tensor: + def silu_and_mul(self, instance, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication. 
Args: + instance: The calling instance (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: @@ -63,33 +64,32 @@ def silu_and_mul(self, x: torch.Tensor) -> torch.Tensor: """ from .impl.activation import silu_and_mul_ascend - return silu_and_mul_ascend(x) + return silu_and_mul_ascend(instance, x) def rms_norm( self, + instance, x: torch.Tensor, - residual: Optional[torch.Tensor], - weight: torch.Tensor, - epsilon: float, + residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """ RMS normalization. Args: + instance: The calling instance (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor - weight: Normalization weight - epsilon: Small constant for numerical stability Returns: Normalized tensor, or tuple of (normalized, residual) if residual is provided """ from .impl.normalization import rms_norm_ascend - return rms_norm_ascend(x, residual, weight, epsilon) + return rms_norm_ascend(instance, x, residual) def rotary_embedding( self, + instance, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -102,6 +102,7 @@ def rotary_embedding( Apply rotary position embedding. Args: + instance: The calling instance (for interface consistency) query: Query tensor key: Key tensor cos: Cosine cache @@ -116,6 +117,7 @@ def rotary_embedding( from .impl.rotary import rotary_embedding_ascend return rotary_embedding_ascend( + instance, query, key, cos, @@ -125,7 +127,7 @@ def rotary_embedding( inplace=inplace, ) - def attention_backend(self, use_mla: bool = False) -> str: + def attention_backend(self, instance, use_mla: bool = False) -> str: """ Get the attention backend class path for Ascend NPU. @@ -137,6 +139,7 @@ def attention_backend(self, use_mla: bool = False) -> str: torch_npu operators without depending on vllm-ascend package. 
Args: + instance: The calling instance (for interface consistency) use_mla: Whether to use Multi-head Latent Attention (MLA) Returns: diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py index 320b04c1..26bb1373 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py @@ -9,11 +9,12 @@ import torch -def silu_and_mul_ascend(x: torch.Tensor) -> torch.Tensor: +def silu_and_mul_ascend(instance, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication using Ascend NPU. Args: + instance: The calling instance (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py index d99a5143..b0b2a940 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py @@ -12,25 +12,27 @@ def rms_norm_ascend( + instance, x: torch.Tensor, - residual: Optional[torch.Tensor], - weight: torch.Tensor, - epsilon: float, + residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """ RMS normalization using Ascend NPU. 
Args: + instance: The calling instance (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor to add before normalization - weight: Normalization weight - epsilon: Small constant for numerical stability Returns: Normalized tensor, or tuple of (normalized, residual) if residual is provided """ import torch_npu + # Get weight and epsilon from instance + weight = instance.weight + epsilon = instance.variance_epsilon + if residual is not None: x, _, residual = torch_npu.npu_add_rms_norm(x, residual, weight, epsilon) return x, residual diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py index d92cae19..ce856daa 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py @@ -11,6 +11,7 @@ def rotary_embedding_ascend( + instance, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -23,6 +24,7 @@ def rotary_embedding_ascend( Apply rotary position embedding using Ascend NPU. 
Args: + instance: The calling instance (for interface consistency) query: Query tensor [num_tokens, num_heads, rotary_dim] key: Key tensor [num_tokens, num_kv_heads, rotary_dim] cos: Cosine cache [max_seq_len, rotary_dim // 2] diff --git a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py index a3fdfb88..3665c206 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py @@ -31,49 +31,85 @@ def name(self) -> str: @property def vendor(self) -> Optional[str]: - return "cuda" + return "nvidia" def is_available(self) -> bool: """Check if CUDA hardware and libraries are available.""" if CudaBackend._available is None: try: # Check if CUDA device is available - if torch.cuda.is_available() and torch.cuda.device_count() > 0: + if not torch.cuda.is_available() or torch.cuda.device_count() == 0: + CudaBackend._available = False + return False + + # Check if this is a real NVIDIA GPU (not CUDA-alike hardware) + # Check device name to exclude CUDA-alike vendors + device_name = torch.cuda.get_device_name(0).upper() + + # Exclude CUDA-alike vendors by device name + # Note: MACA is the device name (like ROCm), METAX is the vendor name + cuda_alike_device_names = ["MUSA", "MOORE", "MACA", "ILUVATAR", + "HYGON", "DCU", "KUNLUN", "CAMBRICON"] + for device_keyword in cuda_alike_device_names: + if device_keyword in device_name: + CudaBackend._available = False + return False + + # Verify it's NVIDIA or has CUDA in the name + if "NVIDIA" in device_name or "CUDA" in device_name: CudaBackend._available = True else: + # If device name doesn't contain NVIDIA or CUDA, + # it might be a CUDA-alike device CudaBackend._available = False + except Exception: CudaBackend._available = False return CudaBackend._available # ==================== Operator Implementations ==================== - def silu_and_mul(self, x: torch.Tensor) -> torch.Tensor: + def silu_and_mul(self, instance, x: torch.Tensor) -> 
torch.Tensor: """ SiLU activation followed by element-wise multiplication. Uses vLLM's native CUDA implementation. + + Args: + instance: The calling instance (for interface consistency) + x: Input tensor of shape [..., 2*d] + + Returns: + Output tensor of shape [..., d] """ from .impl.activation import silu_and_mul_cuda - return silu_and_mul_cuda(x) + return silu_and_mul_cuda(instance, x) def rms_norm( self, + instance, x: torch.Tensor, - residual: Optional[torch.Tensor], - weight: torch.Tensor, - epsilon: float, + residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """ RMS normalization using vLLM's CUDA implementation. + + Args: + instance: The calling instance (e.g., RMSNorm layer) + x: Input tensor + residual: Optional residual tensor + + Returns: + Normalized tensor, or tuple of (normalized, residual) if residual is provided """ from .impl.normalization import rms_norm_cuda - return rms_norm_cuda(x, residual, weight, epsilon) + return rms_norm_cuda(instance, x, residual) def rotary_embedding( self, + instance, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -84,10 +120,24 @@ def rotary_embedding( ) -> tuple[torch.Tensor, torch.Tensor]: """ Apply rotary position embedding using vLLM's CUDA implementation. + + Args: + instance: The calling instance (for interface consistency) + query: Query tensor + key: Key tensor + cos: Cosine cache + sin: Sine cache + position_ids: Position indices + rotary_interleaved: Whether to use interleaved rotary + inplace: Whether to modify tensors in-place + + Returns: + Tuple of (embedded_query, embedded_key) """ from .impl.rotary import rotary_embedding_cuda return rotary_embedding_cuda( + instance, query, key, cos, @@ -97,7 +147,7 @@ def rotary_embedding( inplace=inplace, ) - def attention_backend(self, use_mla: bool = False) -> str: + def attention_backend(self, instance, use_mla: bool = False) -> str: """ Get the attention backend class path for CUDA. 
@@ -106,6 +156,7 @@ def attention_backend(self, use_mla: bool = False) -> str: - TRITON_ATTN (when use_flaggems_op("triton_attn") is True) Args: + instance: The calling instance (for interface consistency) use_mla: Whether to use Multi-head Latent Attention (MLA) Returns: diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py index dc89b4c3..09e35fcd 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py @@ -9,13 +9,14 @@ import torch -def silu_and_mul_cuda(x: torch.Tensor) -> torch.Tensor: +def silu_and_mul_cuda(instance, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication using CUDA. Uses vLLM's optimized CUDA kernel when available. Args: + instance: The calling instance (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py index 7671e5be..f6e0f651 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py @@ -12,10 +12,9 @@ def rms_norm_cuda( + instance, x: torch.Tensor, - residual: Optional[torch.Tensor], - weight: torch.Tensor, - epsilon: float, + residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """ RMS normalization using CUDA. @@ -23,10 +22,9 @@ def rms_norm_cuda( Uses vLLM's optimized CUDA kernel when available. 
Args: + instance: The calling instance (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor to add before normalization - weight: Normalization weight - epsilon: Small constant for numerical stability Returns: Normalized tensor, or tuple of (normalized, residual) if residual is provided @@ -34,6 +32,10 @@ def rms_norm_cuda( from vllm._custom_ops import rms_norm as vllm_rms_norm from vllm._custom_ops import fused_add_rms_norm as vllm_fused_add_rms_norm + # Get weight and epsilon from instance + weight = instance.weight + epsilon = instance.variance_epsilon + if residual is not None: vllm_fused_add_rms_norm(x, residual, weight, epsilon) return x, residual diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py index 97711b62..f816a213 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py @@ -10,6 +10,7 @@ def rotary_embedding_cuda( + instance, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -24,6 +25,7 @@ def rotary_embedding_cuda( Uses vLLM's optimized CUDA kernel when available. 
Args: + instance: The calling instance (for interface consistency) query: Query tensor key: Key tensor cos: Cosine cache diff --git a/vllm_fl/dispatch/config/cuda.yaml b/vllm_fl/dispatch/config/cuda.yaml index 3c8676ff..eec16991 100644 --- a/vllm_fl/dispatch/config/cuda.yaml +++ b/vllm_fl/dispatch/config/cuda.yaml @@ -51,7 +51,43 @@ op_backends: - reference # FlagOS operator blacklist -# flagos_blacklist: [] +flagos_blacklist: + - to_copy + - _to_copy + - zeros + - copy_ + - fill_scalar_ + - sum_dim + - exponential_ + - mm + - resolve_neg + - resolve_conj + - eq_scalar + - floor_divide + - cumsum + - mul + - reciprocal + - repeat + - randn + - add + - ge_scalar + - sub + - bitwise_and + - bitwise_not + - slice_scatter + - fill_tensor_ + - conv1d + - conv2d + - uniform_ + - prod + - max + - amax + - cat + - stack + - flatten + - reshape + - view + - tensor # OOT (Out-of-Tree) operator blacklist # oot_blacklist: [] diff --git a/vllm_fl/ops/activation.py b/vllm_fl/ops/activation.py index 049a983e..e895c5b0 100644 --- a/vllm_fl/ops/activation.py +++ b/vllm_fl/ops/activation.py @@ -10,6 +10,6 @@ def __init__(self): super().__init__() def forward_oot(self, x: torch.Tensor) -> torch.Tensor: - return call_op("silu_and_mul", x) + return call_op("silu_and_mul", self, x) __all__ = ["SiluAndMulFL"] diff --git a/vllm_fl/ops/layernorm.py b/vllm_fl/ops/layernorm.py index 4cb82077..ea13c7c4 100644 --- a/vllm_fl/ops/layernorm.py +++ b/vllm_fl/ops/layernorm.py @@ -22,7 +22,7 @@ def forward_oot( x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - return call_op("rms_norm", x, residual, self.weight, self.variance_epsilon) + return call_op("rms_norm", self, x, residual) __all__ = ["RMSNormFL"] diff --git a/vllm_fl/ops/rotary_embedding.py b/vllm_fl/ops/rotary_embedding.py index 5fa19cd6..98694bf9 100644 --- a/vllm_fl/ops/rotary_embedding.py +++ b/vllm_fl/ops/rotary_embedding.py @@ -46,6 +46,7 @@ def forward_oot( 
q_embed, k_embed = call_op( "rotary_embedding", + self, query_rot, key_rot, cos, From 2b1bfe9f020c70492f2c96f7ee302e2cf7d5617a Mon Sep 17 00:00:00 2001 From: xin2an Date: Tue, 10 Feb 2026 15:28:45 +0800 Subject: [PATCH 25/34] Fix attention_backend bug --- vllm_fl/dispatch/backends/flaggems/flaggems.py | 3 +-- vllm_fl/dispatch/backends/reference/reference.py | 3 +-- vllm_fl/dispatch/backends/vendor/ascend/ascend.py | 3 +-- vllm_fl/dispatch/backends/vendor/cuda/cuda.py | 3 +-- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/vllm_fl/dispatch/backends/flaggems/flaggems.py b/vllm_fl/dispatch/backends/flaggems/flaggems.py index 53ba9f00..1beb8bda 100644 --- a/vllm_fl/dispatch/backends/flaggems/flaggems.py +++ b/vllm_fl/dispatch/backends/flaggems/flaggems.py @@ -118,12 +118,11 @@ def rotary_embedding( inplace=inplace, ) - def attention_backend(self, instance, use_mla: bool = False) -> str: + def attention_backend(self, use_mla: bool = False) -> str: """ Get the attention backend class path for FlagGems. Args: - instance: The calling instance (for interface consistency) use_mla: Whether to use Multi-head Latent Attention (MLA) Returns: diff --git a/vllm_fl/dispatch/backends/reference/reference.py b/vllm_fl/dispatch/backends/reference/reference.py index 228b09f0..533f0f3a 100644 --- a/vllm_fl/dispatch/backends/reference/reference.py +++ b/vllm_fl/dispatch/backends/reference/reference.py @@ -120,7 +120,7 @@ def rotary_embedding( inplace=inplace, ) - def attention_backend(self, instance, use_mla: bool = False) -> str: + def attention_backend(self, use_mla: bool = False) -> str: """ Get the attention backend class path for reference (vLLM native). @@ -128,7 +128,6 @@ def attention_backend(self, instance, use_mla: bool = False) -> str: which serves as a fallback implementation. 
Args: - instance: The calling instance (for interface consistency) use_mla: Whether to use Multi-head Latent Attention (MLA) Returns: diff --git a/vllm_fl/dispatch/backends/vendor/ascend/ascend.py b/vllm_fl/dispatch/backends/vendor/ascend/ascend.py index fc838bd9..3a16037c 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/ascend.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/ascend.py @@ -127,7 +127,7 @@ def rotary_embedding( inplace=inplace, ) - def attention_backend(self, instance, use_mla: bool = False) -> str: + def attention_backend(self, use_mla: bool = False) -> str: """ Get the attention backend class path for Ascend NPU. @@ -139,7 +139,6 @@ def attention_backend(self, instance, use_mla: bool = False) -> str: torch_npu operators without depending on vllm-ascend package. Args: - instance: The calling instance (for interface consistency) use_mla: Whether to use Multi-head Latent Attention (MLA) Returns: diff --git a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py index 3665c206..3d2f7ab0 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py @@ -147,7 +147,7 @@ def rotary_embedding( inplace=inplace, ) - def attention_backend(self, instance, use_mla: bool = False) -> str: + def attention_backend(self, use_mla: bool = False) -> str: """ Get the attention backend class path for CUDA. @@ -156,7 +156,6 @@ def attention_backend(self, instance, use_mla: bool = False) -> str: - TRITON_ATTN (when use_flaggems_op("triton_attn") is True) Args: - instance: The calling instance (for interface consistency) use_mla: Whether to use Multi-head Latent Attention (MLA) Returns: From cdf2443f5b3a8f6ef2d86f37a02223814d05069c Mon Sep 17 00:00:00 2001 From: xin2an Date: Tue, 10 Feb 2026 17:16:27 +0800 Subject: [PATCH 26/34] Modify CUDA backend vendor detection and add PTG configuration. 
--- vllm_fl/dispatch/auto_register.py | 183 ++++++++++++++++++ vllm_fl/dispatch/backends/vendor/cuda/cuda.py | 46 +++-- vllm_fl/dispatch/config/cuda.yaml | 38 +--- vllm_fl/dispatch/config/ptg.yaml | 93 +++++++++ vllm_fl/platform.py | 1 + 5 files changed, 305 insertions(+), 56 deletions(-) create mode 100644 vllm_fl/dispatch/auto_register.py create mode 100644 vllm_fl/dispatch/config/ptg.yaml diff --git a/vllm_fl/dispatch/auto_register.py b/vllm_fl/dispatch/auto_register.py new file mode 100644 index 00000000..5c9ac45b --- /dev/null +++ b/vllm_fl/dispatch/auto_register.py @@ -0,0 +1,183 @@ +# Copyright (c) 2026 BAAI. All rights reserved. + +""" +Automatic operator registration utilities. + +This module provides utilities to automatically register operator implementations +from Backend classes without manually listing each operator in register_ops.py. +""" + +from __future__ import annotations + +import functools +import inspect +from typing import TYPE_CHECKING, List + +from .types import OpImpl, BackendImplKind, BackendPriority + +if TYPE_CHECKING: + from .backends.base import Backend + from .registry import OpRegistry + + +def _bind_is_available(fn, is_available_fn): + """Wrap a function and bind _is_available attribute for OpImpl.is_available() check.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + wrapper._is_available = is_available_fn + return wrapper + + +def _is_operator_method(name: str, method) -> bool: + """ + Check if a method is an operator implementation. 
+ + Excludes: + - Private methods (starting with _) + - Abstract methods + - Properties + - Class methods + - Static methods + - Backend infrastructure methods (is_available, name, vendor) + + Args: + name: Method name + method: Method object + + Returns: + True if this is an operator implementation method + """ + # Skip private methods + if name.startswith("_"): + return False + + # Skip infrastructure methods + if name in ("is_available", "name", "vendor"): + return False + + # Skip if not callable + if not callable(method): + return False + + # Skip properties, classmethods, staticmethods + if isinstance(inspect.getattr_static(method.__class__, name, None), + (property, classmethod, staticmethod)): + return False + + return True + + +def auto_register_backend( + backend: Backend, + registry: OpRegistry, + kind: BackendImplKind, + priority: int = None, + vendor: str = None, +) -> List[OpImpl]: + """ + Automatically register all operator implementations from a Backend instance. + + This function inspects the backend class and automatically creates OpImpl + registrations for all public methods (excluding infrastructure methods like + is_available, name, vendor). 
+ + Args: + backend: Backend instance to register + registry: Registry to register into + kind: Backend implementation kind (DEFAULT, REFERENCE, VENDOR) + priority: Priority for selection (defaults based on kind) + vendor: Vendor name (required if kind is VENDOR) + + Returns: + List of registered OpImpl instances + + Example: + ```python + from vllm_fl.dispatch.auto_register import auto_register_backend + from vllm_fl.dispatch.types import BackendImplKind, BackendPriority + + def register_builtins(registry): + from .cuda import CudaBackend + + backend = CudaBackend() + auto_register_backend( + backend=backend, + registry=registry, + kind=BackendImplKind.VENDOR, + priority=BackendPriority.VENDOR, + vendor="cuda", + ) + ``` + """ + # Set default priority based on kind + if priority is None: + if kind == BackendImplKind.DEFAULT: + priority = BackendPriority.DEFAULT + elif kind == BackendImplKind.VENDOR: + priority = BackendPriority.VENDOR + elif kind == BackendImplKind.REFERENCE: + priority = BackendPriority.REFERENCE + else: + priority = 0 + + # Validate vendor for VENDOR kind + if kind == BackendImplKind.VENDOR and not vendor: + # Try to get vendor from backend + vendor = backend.vendor + if not vendor: + raise ValueError( + f"Backend kind is VENDOR but no vendor name provided. " + f"Either pass vendor parameter or implement backend.vendor property." 
+ ) + + # Get backend name for impl_id + backend_name = backend.name + + # Collect all operator methods + impls = [] + is_avail = backend.is_available + + # Inspect backend class for operator methods + for name in dir(backend): + # Skip private and infrastructure methods + if name.startswith("_") or name in ("is_available", "name", "vendor"): + continue + + try: + method = getattr(backend, name) + except Exception: + continue + + # Check if this is an operator method + if not _is_operator_method(name, method): + continue + + # Create impl_id based on kind + if kind == BackendImplKind.VENDOR: + impl_id = f"vendor.{vendor}" + elif kind == BackendImplKind.DEFAULT: + impl_id = f"default.{backend_name}" + elif kind == BackendImplKind.REFERENCE: + impl_id = f"reference.{backend_name}" + else: + impl_id = f"{backend_name}.{name}" + + # Create OpImpl + impl = OpImpl( + op_name=name, + impl_id=impl_id, + kind=kind, + fn=_bind_is_available(method, is_avail), + vendor=vendor, + priority=priority, + ) + + impls.append(impl) + + # Register all implementations + registry.register_many(impls) + + return impls diff --git a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py index 3d2f7ab0..b6d76088 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py @@ -34,7 +34,13 @@ def vendor(self) -> Optional[str]: return "nvidia" def is_available(self) -> bool: - """Check if CUDA hardware and libraries are available.""" + """ + Check if CUDA hardware and libraries are available. + + This method uses the platform's vendor information from FlagGems + to determine if the device is a real NVIDIA GPU, decoupling from + CUDA-alike devices (MACA, MUSA, etc.) which have their own vendor names. 
+ """ if CudaBackend._available is None: try: # Check if CUDA device is available @@ -42,26 +48,28 @@ def is_available(self) -> bool: CudaBackend._available = False return False - # Check if this is a real NVIDIA GPU (not CUDA-alike hardware) - # Check device name to exclude CUDA-alike vendors - device_name = torch.cuda.get_device_name(0).upper() + # Use current_platform's vendor information to check if this is NVIDIA + # This decouples from device name string matching and properly + # distinguishes NVIDIA GPUs from CUDA-alike devices + try: + from vllm.platforms import current_platform + + # Only enable CUDA backend for NVIDIA vendor + # CUDA-alike devices (MACA/metax, MUSA/mthreads, etc.) + # have their own vendor names and should use their own backends + if hasattr(current_platform, 'vendor_name') and current_platform.vendor_name == "nvidia": + CudaBackend._available = True + else: + CudaBackend._available = False - # Exclude CUDA-alike vendors by device name - # Note: MACA is the device name (like ROCm), METAX is the vendor name - cuda_alike_device_names = ["MUSA", "MOORE", "MACA", "ILUVATAR", - "HYGON", "DCU", "KUNLUN", "CAMBRICON"] - for device_keyword in cuda_alike_device_names: - if device_keyword in device_name: + except Exception: + # Fallback: if platform detection fails, check device name + # This ensures backward compatibility + device_name = torch.cuda.get_device_name(0).upper() + if "NVIDIA" in device_name: + CudaBackend._available = True + else: CudaBackend._available = False - return False - - # Verify it's NVIDIA or has CUDA in the name - if "NVIDIA" in device_name or "CUDA" in device_name: - CudaBackend._available = True - else: - # If device name doesn't contain NVIDIA or CUDA, - # it might be a CUDA-alike device - CudaBackend._available = False except Exception: CudaBackend._available = False diff --git a/vllm_fl/dispatch/config/cuda.yaml b/vllm_fl/dispatch/config/cuda.yaml index eec16991..3c8676ff 100644 --- a/vllm_fl/dispatch/config/cuda.yaml 
+++ b/vllm_fl/dispatch/config/cuda.yaml @@ -51,43 +51,7 @@ op_backends: - reference # FlagOS operator blacklist -flagos_blacklist: - - to_copy - - _to_copy - - zeros - - copy_ - - fill_scalar_ - - sum_dim - - exponential_ - - mm - - resolve_neg - - resolve_conj - - eq_scalar - - floor_divide - - cumsum - - mul - - reciprocal - - repeat - - randn - - add - - ge_scalar - - sub - - bitwise_and - - bitwise_not - - slice_scatter - - fill_tensor_ - - conv1d - - conv2d - - uniform_ - - prod - - max - - amax - - cat - - stack - - flatten - - reshape - - view - - tensor +# flagos_blacklist: [] # OOT (Out-of-Tree) operator blacklist # oot_blacklist: [] diff --git a/vllm_fl/dispatch/config/ptg.yaml b/vllm_fl/dispatch/config/ptg.yaml new file mode 100644 index 00000000..eec16991 --- /dev/null +++ b/vllm_fl/dispatch/config/ptg.yaml @@ -0,0 +1,93 @@ +# vLLM-FL Dispatch Configuration for CUDA +# Auto-loaded when running on NVIDIA GPU hardware + +# Preferred default backend type: flaggems, vendor, reference +prefer: flagos + +# Strict Mode: +# true = Raise an error immediately on failure; do not attempt other backends. +# false = Attempt the next available backend in sequence upon failure (Default). +strict: false + +# Vendor Whitelist (Optional, allows all if not set) +# allow_vendors: +# - cuda + +# Vendor Blacklist (Optional) +# deny_vendors: +# - ascend + +# Per-operator backend execution order (Optional) +# Only the backends listed here will be attempted, in the order specified. 
+# +# Supported tokens: +# - flaggems : Default FlagGems implementation (Triton) +# - reference : PyTorch reference implementation +# - vendor : Any available vendor backend (auto-detected) +# - vendor:cuda : CUDA-specific vendor backend +op_backends: + # attention_backend: prioritize flaggems (Triton attention) + attention_backend: + - flagos + - vendor + - reference + + # rms_norm: prioritize flaggems (Triton) + rms_norm: + - flagos + - vendor + - reference + + # silu_and_mul: prioritize flaggems + silu_and_mul: + - flagos + - vendor + - reference + + # rotary_embedding: prioritize flaggems + rotary_embedding: + - flagos + - vendor + - reference + +# FlagOS operator blacklist +flagos_blacklist: + - to_copy + - _to_copy + - zeros + - copy_ + - fill_scalar_ + - sum_dim + - exponential_ + - mm + - resolve_neg + - resolve_conj + - eq_scalar + - floor_divide + - cumsum + - mul + - reciprocal + - repeat + - randn + - add + - ge_scalar + - sub + - bitwise_and + - bitwise_not + - slice_scatter + - fill_tensor_ + - conv1d + - conv2d + - uniform_ + - prod + - max + - amax + - cat + - stack + - flatten + - reshape + - view + - tensor + +# OOT (Out-of-Tree) operator blacklist +# oot_blacklist: [] diff --git a/vllm_fl/platform.py b/vllm_fl/platform.py index 8a7c2697..0e6a577e 100644 --- a/vllm_fl/platform.py +++ b/vllm_fl/platform.py @@ -39,6 +39,7 @@ class PlatformFL(Platform): device_type = device_info.device_type dispatch_key = device_info.dispatch_key torch_device_fn = device_info.torch_device_fn + vendor_name = device_info.vendor_name ray_device_key: str = "flagos" dist_backend: str = "flagcx" if "FLAGCX_PATH" in os.environ else "nccl" ### TODO(lms): dispatch device_control_env_var From c759f50f8410a405c9cdd730edb83fcb79cbe817 Mon Sep 17 00:00:00 2001 From: xin2an Date: Tue, 10 Feb 2026 17:47:09 +0800 Subject: [PATCH 27/34] Add metax backend --- .../backends/vendor/metax/__init__.py | 9 + .../backends/vendor/metax/impl/__init__.py | 7 + 
.../backends/vendor/metax/impl/activation.py | 27 +++ .../vendor/metax/impl/normalization.py | 46 +++++ .../backends/vendor/metax/impl/rotary.py | 56 ++++++ .../dispatch/backends/vendor/metax/metax.py | 169 ++++++++++++++++++ .../backends/vendor/metax/register_ops.py | 78 ++++++++ vllm_fl/dispatch/config/metax.yaml | 55 ++++++ 8 files changed, 447 insertions(+) create mode 100644 vllm_fl/dispatch/backends/vendor/metax/__init__.py create mode 100644 vllm_fl/dispatch/backends/vendor/metax/impl/__init__.py create mode 100644 vllm_fl/dispatch/backends/vendor/metax/impl/activation.py create mode 100644 vllm_fl/dispatch/backends/vendor/metax/impl/normalization.py create mode 100644 vllm_fl/dispatch/backends/vendor/metax/impl/rotary.py create mode 100644 vllm_fl/dispatch/backends/vendor/metax/metax.py create mode 100644 vllm_fl/dispatch/backends/vendor/metax/register_ops.py create mode 100644 vllm_fl/dispatch/config/metax.yaml diff --git a/vllm_fl/dispatch/backends/vendor/metax/__init__.py b/vllm_fl/dispatch/backends/vendor/metax/__init__.py new file mode 100644 index 00000000..e2b14c53 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/metax/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2026 BAAI. All rights reserved. + +""" +METAX (Moore Threads) backend for vllm-plugin-FL dispatch. +""" + +from .metax import MetaxBackend + +__all__ = ["MetaxBackend"] diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/__init__.py b/vllm_fl/dispatch/backends/vendor/metax/impl/__init__.py new file mode 100644 index 00000000..6d2aa4c0 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2026 BAAI. All rights reserved. + +""" +METAX implementation module. 
+""" + +__all__ = ["activation", "normalization", "rotary"] diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/activation.py b/vllm_fl/dispatch/backends/vendor/metax/impl/activation.py new file mode 100644 index 00000000..5d531521 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/activation.py @@ -0,0 +1,27 @@ +# Copyright (c) 2026 BAAI. All rights reserved. + +""" +METAX activation operator implementations. +""" + +from __future__ import annotations + +import torch + + +def silu_and_mul_metax(instance, x: torch.Tensor) -> torch.Tensor: + """ + SiLU activation followed by element-wise multiplication using METAX/MACA. + + Args: + instance: The calling instance (for interface consistency) + x: Input tensor of shape [..., 2*d] + + Returns: + Output tensor of shape [..., d] + """ + # TODO: Implement METAX-specific optimized version + # For now, use PyTorch reference implementation + d = x.shape[-1] // 2 + x1, x2 = x[..., :d], x[..., d:] + return torch.nn.functional.silu(x1) * x2 diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/normalization.py b/vllm_fl/dispatch/backends/vendor/metax/impl/normalization.py new file mode 100644 index 00000000..d11beee0 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/normalization.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 BAAI. All rights reserved. + +""" +METAX normalization operator implementations. +""" + +from __future__ import annotations + +from typing import Optional, Union + +import torch + + +def rms_norm_metax( + instance, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """ + RMS normalization using METAX/MACA. 
+ + Args: + instance: The calling instance (e.g., RMSNorm layer) + x: Input tensor + residual: Optional residual tensor + + Returns: + Normalized tensor, or tuple of (normalized, residual) if residual is provided + """ + # Get weight and epsilon from instance + weight = instance.weight + epsilon = instance.variance_epsilon + + # TODO: Implement METAX-specific optimized version + # For now, use PyTorch reference implementation + if residual is not None: + x = x + residual + residual = x + + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + epsilon) + x = x * weight + + if residual is not None: + return x, residual + return x diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/rotary.py b/vllm_fl/dispatch/backends/vendor/metax/impl/rotary.py new file mode 100644 index 00000000..ddbe5bd2 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/rotary.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 BAAI. All rights reserved. + +""" +METAX rotary embedding operator implementations. +""" + +from __future__ import annotations + +import torch + + +def rotary_embedding_metax( + instance, + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor, + rotary_interleaved: bool = False, + inplace: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embedding using METAX/MACA. 
+ + Args: + instance: The calling instance (for interface consistency) + query: Query tensor + key: Key tensor + cos: Cosine cache + sin: Sine cache + position_ids: Position indices + rotary_interleaved: Whether to use interleaved rotary + inplace: Whether to modify tensors in-place + + Returns: + Tuple of (embedded_query, embedded_key) + """ + # TODO: Implement METAX-specific optimized version + # For now, use PyTorch reference implementation + + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # Gather cos and sin based on position_ids + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + return apply_rotary_pos_emb(query, key, cos, sin, position_ids) diff --git a/vllm_fl/dispatch/backends/vendor/metax/metax.py b/vllm_fl/dispatch/backends/vendor/metax/metax.py new file mode 100644 index 00000000..57f5bcc7 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/metax/metax.py @@ -0,0 +1,169 @@ +# Copyright (c) 2026 BAAI. All rights reserved. + +""" +METAX backend implementation. + +This backend provides operator implementations for Moore Threads METAX GPUs. +METAX uses MACA (Moore Threads Accelerated Computing Architecture) which is +CUDA-compatible. +""" + +from __future__ import annotations + +from typing import Optional, Union + +import torch + +from vllm_fl.dispatch.backends.base import Backend + + +class MetaxBackend(Backend): + """ + METAX backend for operator implementations. + + This backend uses MACA libraries to provide high-performance + operator implementations for Moore Threads METAX GPUs. 
+ """ + + _available: Optional[bool] = None + + @property + def name(self) -> str: + return "metax" + + @property + def vendor(self) -> Optional[str]: + return "metax" + + def is_available(self) -> bool: + """ + Check if METAX hardware and libraries are available. + + This method uses the platform's vendor information to determine + if the device is a METAX GPU. + """ + if MetaxBackend._available is None: + try: + # Check if CUDA device is available (MACA is CUDA-compatible) + if not torch.cuda.is_available() or torch.cuda.device_count() == 0: + MetaxBackend._available = False + return False + + # Use current_platform's vendor information to check if this is METAX + try: + from vllm.platforms import current_platform + + # Only enable METAX backend for metax vendor + if hasattr(current_platform, 'vendor_name') and current_platform.vendor_name == "metax": + MetaxBackend._available = True + else: + MetaxBackend._available = False + + except Exception: + # Fallback: check device name for MACA/METAX keywords + device_name = torch.cuda.get_device_name(0).upper() + if "MACA" in device_name or "METAX" in device_name or "MOORE" in device_name: + MetaxBackend._available = True + else: + MetaxBackend._available = False + + except Exception: + MetaxBackend._available = False + return MetaxBackend._available + + # ==================== Operator Implementations ==================== + + def silu_and_mul(self, instance, x: torch.Tensor) -> torch.Tensor: + """ + SiLU activation followed by element-wise multiplication. + + Args: + instance: The calling instance (for interface consistency) + x: Input tensor of shape [..., 2*d] + + Returns: + Output tensor of shape [..., d] + """ + from .impl.activation import silu_and_mul_metax + + return silu_and_mul_metax(instance, x) + + def rms_norm( + self, + instance, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """ + RMS normalization. 
+ + Args: + instance: The calling instance (e.g., RMSNorm layer) + x: Input tensor + residual: Optional residual tensor + + Returns: + Normalized tensor, or tuple of (normalized, residual) if residual is provided + """ + from .impl.normalization import rms_norm_metax + + return rms_norm_metax(instance, x, residual) + + def rotary_embedding( + self, + instance, + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor, + rotary_interleaved: bool = False, + inplace: bool = True, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embedding. + + Args: + instance: The calling instance (for interface consistency) + query: Query tensor + key: Key tensor + cos: Cosine cache + sin: Sine cache + position_ids: Position indices + rotary_interleaved: Whether to use interleaved rotary + inplace: Whether to modify tensors in-place + + Returns: + Tuple of (embedded_query, embedded_key) + """ + from .impl.rotary import rotary_embedding_metax + + return rotary_embedding_metax( + instance, + query, + key, + cos, + sin, + position_ids, + rotary_interleaved=rotary_interleaved, + inplace=inplace, + ) + + def attention_backend(self, use_mla: bool = False) -> str: + """ + Get the attention backend class path for METAX. + + Args: + use_mla: Whether to use Multi-head Latent Attention (MLA) + + Returns: + Fully qualified class path string + """ + from vllm.attention.backends.registry import AttentionBackendEnum + + if use_mla: + # TODO: Implement METAX MLA backend + return AttentionBackendEnum.FLASHMLA.get_path() + + # Default to FLASH_ATTN (MACA is CUDA-compatible) + return AttentionBackendEnum.FLASH_ATTN.get_path() diff --git a/vllm_fl/dispatch/backends/vendor/metax/register_ops.py b/vllm_fl/dispatch/backends/vendor/metax/register_ops.py new file mode 100644 index 00000000..20e725ae --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/metax/register_ops.py @@ -0,0 +1,78 @@ +# Copyright (c) 2026 BAAI. 
All rights reserved. + +""" +METAX backend operator registrations. + +This module registers all VENDOR (METAX) implementations. +""" + +from __future__ import annotations + +import functools + +from vllm_fl.dispatch.types import OpImpl, BackendImplKind, BackendPriority + + +def _bind_is_available(fn, is_available_fn): + """Wrap a function and bind _is_available attribute for OpImpl.is_available() check.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + wrapper._is_available = is_available_fn + return wrapper + + +def register_builtins(registry) -> None: + """ + Register all METAX (VENDOR) operator implementations. + + Args: + registry: Registry to register into + """ + from .metax import MetaxBackend + + backend = MetaxBackend() + is_avail = backend.is_available + + impls = [ + # Activation + OpImpl( + op_name="silu_and_mul", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.silu_and_mul, is_avail), + vendor="metax", + priority=BackendPriority.VENDOR, + ), + # Normalization + OpImpl( + op_name="rms_norm", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rms_norm, is_avail), + vendor="metax", + priority=BackendPriority.VENDOR, + ), + # Rotary Embedding + OpImpl( + op_name="rotary_embedding", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rotary_embedding, is_avail), + vendor="metax", + priority=BackendPriority.VENDOR, + ), + # Attention Backend + OpImpl( + op_name="attention_backend", + impl_id="vendor.metax", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.attention_backend, is_avail), + vendor="metax", + priority=BackendPriority.VENDOR, + ), + ] + + registry.register_many(impls) diff --git a/vllm_fl/dispatch/config/metax.yaml b/vllm_fl/dispatch/config/metax.yaml new file mode 100644 index 00000000..c9ebdca4 --- /dev/null +++ b/vllm_fl/dispatch/config/metax.yaml @@ -0,0 +1,55 @@ +# vLLM-FL 
Dispatch Configuration for METAX (Moore Threads) +# Auto-loaded when running on METAX GPU hardware + +# Preferred default backend type: flaggems, vendor, reference +prefer: vendor + +# Strict Mode: +# true = Raise an error immediately on failure; do not attempt other backends. +# false = Attempt the next available backend in sequence upon failure (Default). +strict: false + +# Vendor Whitelist (Optional, allows all if not set) +# allow_vendors: +# - metax + +# Vendor Blacklist (Optional) +# deny_vendors: +# - cuda + +# Per-operator backend execution order (Optional) +# Only the backends listed here will be attempted, in the order specified. +# +# Supported tokens: +# - flaggems : Default FlagGems implementation (Triton) +# - reference : PyTorch reference implementation +# - vendor : Any available vendor backend (auto-detected) +# - vendor:metax : METAX-specific vendor backend +op_backends: + # Prefer vendor implementation for all operators + silu_and_mul: + - vendor:metax + - flaggems + - reference + + rms_norm: + - vendor:metax + - flaggems + - reference + + rotary_embedding: + - vendor:metax + - flaggems + - reference + + attention_backend: + - vendor:metax + - flaggems + +# FlagOS operator blacklist (Optional) +# These operators will not use FlagGems implementation +# flagos_blacklist: [] + +# OOT operator blacklist (Optional) +# These OOT operators will not be registered +# oot_blacklist: [] From 467df01c3591a197824944e70d7b93a9982ddc28 Mon Sep 17 00:00:00 2001 From: xin2an Date: Tue, 10 Feb 2026 18:03:33 +0800 Subject: [PATCH 28/34] delete auto_register.py --- vllm_fl/dispatch/auto_register.py | 183 ------------------------------ 1 file changed, 183 deletions(-) delete mode 100644 vllm_fl/dispatch/auto_register.py diff --git a/vllm_fl/dispatch/auto_register.py b/vllm_fl/dispatch/auto_register.py deleted file mode 100644 index 5c9ac45b..00000000 --- a/vllm_fl/dispatch/auto_register.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright (c) 2026 BAAI. 
All rights reserved. - -""" -Automatic operator registration utilities. - -This module provides utilities to automatically register operator implementations -from Backend classes without manually listing each operator in register_ops.py. -""" - -from __future__ import annotations - -import functools -import inspect -from typing import TYPE_CHECKING, List - -from .types import OpImpl, BackendImplKind, BackendPriority - -if TYPE_CHECKING: - from .backends.base import Backend - from .registry import OpRegistry - - -def _bind_is_available(fn, is_available_fn): - """Wrap a function and bind _is_available attribute for OpImpl.is_available() check.""" - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - return fn(*args, **kwargs) - - wrapper._is_available = is_available_fn - return wrapper - - -def _is_operator_method(name: str, method) -> bool: - """ - Check if a method is an operator implementation. - - Excludes: - - Private methods (starting with _) - - Abstract methods - - Properties - - Class methods - - Static methods - - Backend infrastructure methods (is_available, name, vendor) - - Args: - name: Method name - method: Method object - - Returns: - True if this is an operator implementation method - """ - # Skip private methods - if name.startswith("_"): - return False - - # Skip infrastructure methods - if name in ("is_available", "name", "vendor"): - return False - - # Skip if not callable - if not callable(method): - return False - - # Skip properties, classmethods, staticmethods - if isinstance(inspect.getattr_static(method.__class__, name, None), - (property, classmethod, staticmethod)): - return False - - return True - - -def auto_register_backend( - backend: Backend, - registry: OpRegistry, - kind: BackendImplKind, - priority: int = None, - vendor: str = None, -) -> List[OpImpl]: - """ - Automatically register all operator implementations from a Backend instance. 
- - This function inspects the backend class and automatically creates OpImpl - registrations for all public methods (excluding infrastructure methods like - is_available, name, vendor). - - Args: - backend: Backend instance to register - registry: Registry to register into - kind: Backend implementation kind (DEFAULT, REFERENCE, VENDOR) - priority: Priority for selection (defaults based on kind) - vendor: Vendor name (required if kind is VENDOR) - - Returns: - List of registered OpImpl instances - - Example: - ```python - from vllm_fl.dispatch.auto_register import auto_register_backend - from vllm_fl.dispatch.types import BackendImplKind, BackendPriority - - def register_builtins(registry): - from .cuda import CudaBackend - - backend = CudaBackend() - auto_register_backend( - backend=backend, - registry=registry, - kind=BackendImplKind.VENDOR, - priority=BackendPriority.VENDOR, - vendor="cuda", - ) - ``` - """ - # Set default priority based on kind - if priority is None: - if kind == BackendImplKind.DEFAULT: - priority = BackendPriority.DEFAULT - elif kind == BackendImplKind.VENDOR: - priority = BackendPriority.VENDOR - elif kind == BackendImplKind.REFERENCE: - priority = BackendPriority.REFERENCE - else: - priority = 0 - - # Validate vendor for VENDOR kind - if kind == BackendImplKind.VENDOR and not vendor: - # Try to get vendor from backend - vendor = backend.vendor - if not vendor: - raise ValueError( - f"Backend kind is VENDOR but no vendor name provided. " - f"Either pass vendor parameter or implement backend.vendor property." 
- ) - - # Get backend name for impl_id - backend_name = backend.name - - # Collect all operator methods - impls = [] - is_avail = backend.is_available - - # Inspect backend class for operator methods - for name in dir(backend): - # Skip private and infrastructure methods - if name.startswith("_") or name in ("is_available", "name", "vendor"): - continue - - try: - method = getattr(backend, name) - except Exception: - continue - - # Check if this is an operator method - if not _is_operator_method(name, method): - continue - - # Create impl_id based on kind - if kind == BackendImplKind.VENDOR: - impl_id = f"vendor.{vendor}" - elif kind == BackendImplKind.DEFAULT: - impl_id = f"default.{backend_name}" - elif kind == BackendImplKind.REFERENCE: - impl_id = f"reference.{backend_name}" - else: - impl_id = f"{backend_name}.{name}" - - # Create OpImpl - impl = OpImpl( - op_name=name, - impl_id=impl_id, - kind=kind, - fn=_bind_is_available(method, is_avail), - vendor=vendor, - priority=priority, - ) - - impls.append(impl) - - # Register all implementations - registry.register_many(impls) - - return impls From a15dd97bf5c0007635569248bd902b7dfc30b15a Mon Sep 17 00:00:00 2001 From: xin2an Date: Wed, 11 Feb 2026 11:58:48 +0800 Subject: [PATCH 29/34] Revised according to feedback --- README.md | 4 +-- vllm_fl/dispatch/backends/vendor/cuda/cuda.py | 28 +++------------ .../backends/vendor/metax/impl/activation.py | 14 -------- .../vendor/metax/impl/normalization.py | 27 --------------- .../backends/vendor/metax/impl/rotary.py | 34 ------------------- .../dispatch/backends/vendor/metax/metax.py | 24 +++---------- vllm_fl/dispatch/config/metax.yaml | 8 ++--- 7 files changed, 16 insertions(+), 123 deletions(-) diff --git a/README.md b/README.md index 84601817..bdc6b5de 100644 --- a/README.md +++ b/README.md @@ -103,8 +103,8 @@ if __name__ == '__main__': For dispatch environment variable usage, see [environment variables usage](./vllm_fl/dispatch/README.md#environment-variables). 
-### Using CudaCommunication library -If you want to use the original CudaCommunication, you can unset the following environment variables. +### Using Cuda Communication library +If you want to use the original Cuda Communication, you can unset the following environment variables. ```sh unset FLAGCX_PATH ``` diff --git a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py index b6d76088..ef5bf67b 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py @@ -48,29 +48,11 @@ def is_available(self) -> bool: CudaBackend._available = False return False - # Use current_platform's vendor information to check if this is NVIDIA - # This decouples from device name string matching and properly - # distinguishes NVIDIA GPUs from CUDA-alike devices - try: - from vllm.platforms import current_platform - - # Only enable CUDA backend for NVIDIA vendor - # CUDA-alike devices (MACA/metax, MUSA/mthreads, etc.) - # have their own vendor names and should use their own backends - if hasattr(current_platform, 'vendor_name') and current_platform.vendor_name == "nvidia": - CudaBackend._available = True - else: - CudaBackend._available = False - - except Exception: - # Fallback: if platform detection fails, check device name - # This ensures backward compatibility - device_name = torch.cuda.get_device_name(0).upper() - if "NVIDIA" in device_name: - CudaBackend._available = True - else: - CudaBackend._available = False - + from vllm.platforms import current_platform + if hasattr(current_platform, 'vendor_name') and current_platform.vendor_name == "nvidia": + CudaBackend._available = True + else: + CudaBackend._available = False except Exception: CudaBackend._available = False return CudaBackend._available diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/activation.py b/vllm_fl/dispatch/backends/vendor/metax/impl/activation.py index 5d531521..154c8e2b 100644 --- 
a/vllm_fl/dispatch/backends/vendor/metax/impl/activation.py +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/activation.py @@ -10,18 +10,4 @@ def silu_and_mul_metax(instance, x: torch.Tensor) -> torch.Tensor: - """ - SiLU activation followed by element-wise multiplication using METAX/MACA. - Args: - instance: The calling instance (for interface consistency) - x: Input tensor of shape [..., 2*d] - - Returns: - Output tensor of shape [..., d] - """ - # TODO: Implement METAX-specific optimized version - # For now, use PyTorch reference implementation - d = x.shape[-1] // 2 - x1, x2 = x[..., :d], x[..., d:] - return torch.nn.functional.silu(x1) * x2 diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/normalization.py b/vllm_fl/dispatch/backends/vendor/metax/impl/normalization.py index d11beee0..0db662a6 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/impl/normalization.py +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/normalization.py @@ -16,31 +16,4 @@ def rms_norm_metax( x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - """ - RMS normalization using METAX/MACA. 
- Args: - instance: The calling instance (e.g., RMSNorm layer) - x: Input tensor - residual: Optional residual tensor - - Returns: - Normalized tensor, or tuple of (normalized, residual) if residual is provided - """ - # Get weight and epsilon from instance - weight = instance.weight - epsilon = instance.variance_epsilon - - # TODO: Implement METAX-specific optimized version - # For now, use PyTorch reference implementation - if residual is not None: - x = x + residual - residual = x - - variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + epsilon) - x = x * weight - - if residual is not None: - return x, residual - return x diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/rotary.py b/vllm_fl/dispatch/backends/vendor/metax/impl/rotary.py index ddbe5bd2..68aecabc 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/impl/rotary.py +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/rotary.py @@ -19,38 +19,4 @@ def rotary_embedding_metax( rotary_interleaved: bool = False, inplace: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary position embedding using METAX/MACA. 
- Args: - instance: The calling instance (for interface consistency) - query: Query tensor - key: Key tensor - cos: Cosine cache - sin: Sine cache - position_ids: Position indices - rotary_interleaved: Whether to use interleaved rotary - inplace: Whether to modify tensors in-place - - Returns: - Tuple of (embedded_query, embedded_key) - """ - # TODO: Implement METAX-specific optimized version - # For now, use PyTorch reference implementation - - def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # Gather cos and sin based on position_ids - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - return apply_rotary_pos_emb(query, key, cos, sin, position_ids) diff --git a/vllm_fl/dispatch/backends/vendor/metax/metax.py b/vllm_fl/dispatch/backends/vendor/metax/metax.py index 57f5bcc7..87eb0385 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/metax.py +++ b/vllm_fl/dispatch/backends/vendor/metax/metax.py @@ -44,29 +44,15 @@ def is_available(self) -> bool: """ if MetaxBackend._available is None: try: - # Check if CUDA device is available (MACA is CUDA-compatible) if not torch.cuda.is_available() or torch.cuda.device_count() == 0: MetaxBackend._available = False return False - # Use current_platform's vendor information to check if this is METAX - try: - from vllm.platforms import current_platform - - # Only enable METAX backend for metax vendor - if hasattr(current_platform, 'vendor_name') and current_platform.vendor_name == "metax": - MetaxBackend._available = True - else: - MetaxBackend._available = False - - except Exception: - # Fallback: check device name for MACA/METAX keywords - 
device_name = torch.cuda.get_device_name(0).upper() - if "MACA" in device_name or "METAX" in device_name or "MOORE" in device_name: - MetaxBackend._available = True - else: - MetaxBackend._available = False - + from vllm.platforms import current_platform + if hasattr(current_platform, 'vendor_name') and current_platform.vendor_name == "metax": + MetaxBackend._available = True + else: + MetaxBackend._available = False except Exception: MetaxBackend._available = False return MetaxBackend._available diff --git a/vllm_fl/dispatch/config/metax.yaml b/vllm_fl/dispatch/config/metax.yaml index c9ebdca4..b8141c26 100644 --- a/vllm_fl/dispatch/config/metax.yaml +++ b/vllm_fl/dispatch/config/metax.yaml @@ -28,22 +28,22 @@ strict: false op_backends: # Prefer vendor implementation for all operators silu_and_mul: - - vendor:metax + - vendor - flaggems - reference rms_norm: - - vendor:metax + - vendor - flaggems - reference rotary_embedding: - - vendor:metax + - vendor - flaggems - reference attention_backend: - - vendor:metax + - vendor - flaggems # FlagOS operator blacklist (Optional) From 415eb1554268beee9de08b89e36429f39970319e Mon Sep 17 00:00:00 2001 From: xin2an Date: Wed, 11 Feb 2026 12:02:15 +0800 Subject: [PATCH 30/34] instance > obj --- vllm_fl/dispatch/backends/flaggems/flaggems.py | 18 +++++++++--------- .../backends/flaggems/impl/activation.py | 4 ++-- .../backends/flaggems/impl/normalization.py | 10 +++++----- .../dispatch/backends/flaggems/impl/rotary.py | 4 ++-- .../backends/reference/impl/activation.py | 4 ++-- .../backends/reference/impl/normalization.py | 10 +++++----- .../dispatch/backends/reference/impl/rotary.py | 4 ++-- .../dispatch/backends/reference/reference.py | 18 +++++++++--------- .../dispatch/backends/vendor/ascend/ascend.py | 18 +++++++++--------- .../backends/vendor/ascend/impl/activation.py | 4 ++-- .../vendor/ascend/impl/attention_mask.py | 8 ++++---- .../vendor/ascend/impl/normalization.py | 10 +++++----- 
.../backends/vendor/ascend/impl/rotary.py | 4 ++-- vllm_fl/dispatch/backends/vendor/cuda/cuda.py | 18 +++++++++--------- .../backends/vendor/cuda/impl/activation.py | 4 ++-- .../backends/vendor/cuda/impl/normalization.py | 10 +++++----- .../backends/vendor/cuda/impl/rotary.py | 4 ++-- .../backends/vendor/metax/impl/activation.py | 2 +- .../vendor/metax/impl/normalization.py | 2 +- .../backends/vendor/metax/impl/rotary.py | 2 +- .../dispatch/backends/vendor/metax/metax.py | 18 +++++++++--------- 21 files changed, 88 insertions(+), 88 deletions(-) diff --git a/vllm_fl/dispatch/backends/flaggems/flaggems.py b/vllm_fl/dispatch/backends/flaggems/flaggems.py index 1beb8bda..952892f4 100644 --- a/vllm_fl/dispatch/backends/flaggems/flaggems.py +++ b/vllm_fl/dispatch/backends/flaggems/flaggems.py @@ -42,12 +42,12 @@ def is_available(self) -> bool: # ==================== Operator Implementations ==================== - def silu_and_mul(self, instance, x: torch.Tensor) -> torch.Tensor: + def silu_and_mul(self, obj, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication. Args: - instance: The calling instance (for interface consistency) + obj: The calling obj (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: @@ -55,11 +55,11 @@ def silu_and_mul(self, instance, x: torch.Tensor) -> torch.Tensor: """ from .impl.activation import silu_and_mul_flaggems - return silu_and_mul_flaggems(instance, x) + return silu_and_mul_flaggems(obj, x) def rms_norm( self, - instance, + obj, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: @@ -67,7 +67,7 @@ def rms_norm( RMS normalization. 
Args: - instance: The calling instance (e.g., RMSNorm layer) + obj: The calling obj (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor @@ -76,11 +76,11 @@ def rms_norm( """ from .impl.normalization import rms_norm_flaggems - return rms_norm_flaggems(instance, x, residual) + return rms_norm_flaggems(obj, x, residual) def rotary_embedding( self, - instance, + obj, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -93,7 +93,7 @@ def rotary_embedding( Apply rotary position embedding. Args: - instance: The calling instance (for interface consistency) + obj: The calling obj (for interface consistency) query: Query tensor key: Key tensor cos: Cosine cache @@ -108,7 +108,7 @@ def rotary_embedding( from .impl.rotary import rotary_embedding_flaggems return rotary_embedding_flaggems( - instance, + obj, query, key, cos, diff --git a/vllm_fl/dispatch/backends/flaggems/impl/activation.py b/vllm_fl/dispatch/backends/flaggems/impl/activation.py index db226f5b..08886fe1 100644 --- a/vllm_fl/dispatch/backends/flaggems/impl/activation.py +++ b/vllm_fl/dispatch/backends/flaggems/impl/activation.py @@ -9,12 +9,12 @@ import torch -def silu_and_mul_flaggems(instance, x: torch.Tensor) -> torch.Tensor: +def silu_and_mul_flaggems(obj, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication using FlagGems. 
Args: - instance: The calling instance (for interface consistency) + obj: The calling obj (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: diff --git a/vllm_fl/dispatch/backends/flaggems/impl/normalization.py b/vllm_fl/dispatch/backends/flaggems/impl/normalization.py index 2b09df43..c2e69050 100644 --- a/vllm_fl/dispatch/backends/flaggems/impl/normalization.py +++ b/vllm_fl/dispatch/backends/flaggems/impl/normalization.py @@ -12,7 +12,7 @@ def rms_norm_flaggems( - instance, + obj, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: @@ -20,7 +20,7 @@ def rms_norm_flaggems( RMS normalization using FlagGems. Args: - instance: The calling instance (e.g., RMSNorm layer) + obj: The calling obj (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor @@ -29,8 +29,8 @@ def rms_norm_flaggems( """ from flag_gems.modules.normalization import gems_rms_forward - # Get weight and epsilon from instance - weight = instance.weight - epsilon = instance.variance_epsilon + # Get weight and epsilon from obj + weight = obj.weight + epsilon = obj.variance_epsilon return gems_rms_forward(x, residual, weight, epsilon) diff --git a/vllm_fl/dispatch/backends/flaggems/impl/rotary.py b/vllm_fl/dispatch/backends/flaggems/impl/rotary.py index 138c143a..b4cb5c30 100644 --- a/vllm_fl/dispatch/backends/flaggems/impl/rotary.py +++ b/vllm_fl/dispatch/backends/flaggems/impl/rotary.py @@ -10,7 +10,7 @@ def rotary_embedding_flaggems( - instance, + obj, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -23,7 +23,7 @@ def rotary_embedding_flaggems( Apply rotary position embedding using FlagGems. 
Args: - instance: The calling instance (for interface consistency) + obj: The calling obj (for interface consistency) query: Query tensor key: Key tensor cos: Cosine cache diff --git a/vllm_fl/dispatch/backends/reference/impl/activation.py b/vllm_fl/dispatch/backends/reference/impl/activation.py index a3faa5f8..87f15061 100644 --- a/vllm_fl/dispatch/backends/reference/impl/activation.py +++ b/vllm_fl/dispatch/backends/reference/impl/activation.py @@ -10,12 +10,12 @@ import torch.nn.functional as F -def silu_and_mul_torch(instance, x: torch.Tensor) -> torch.Tensor: +def silu_and_mul_torch(obj, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication using PyTorch. Args: - instance: The calling instance (for interface consistency) + obj: The calling obj (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: diff --git a/vllm_fl/dispatch/backends/reference/impl/normalization.py b/vllm_fl/dispatch/backends/reference/impl/normalization.py index 7ff3c8cd..828018bc 100644 --- a/vllm_fl/dispatch/backends/reference/impl/normalization.py +++ b/vllm_fl/dispatch/backends/reference/impl/normalization.py @@ -12,7 +12,7 @@ def rms_norm_torch( - instance, + obj, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: @@ -20,16 +20,16 @@ def rms_norm_torch( RMS normalization using PyTorch. 
Args: - instance: The calling instance (e.g., RMSNorm layer) + obj: The calling obj (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor Returns: Normalized tensor, or tuple of (normalized, residual) if residual is provided """ - # Get weight and epsilon from instance - weight = instance.weight - epsilon = instance.variance_epsilon + # Get weight and epsilon from obj + weight = obj.weight + epsilon = obj.variance_epsilon if residual is not None: x = x + residual diff --git a/vllm_fl/dispatch/backends/reference/impl/rotary.py b/vllm_fl/dispatch/backends/reference/impl/rotary.py index 707585a7..16125e08 100644 --- a/vllm_fl/dispatch/backends/reference/impl/rotary.py +++ b/vllm_fl/dispatch/backends/reference/impl/rotary.py @@ -10,7 +10,7 @@ def rotary_embedding_torch( - instance, + obj, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -23,7 +23,7 @@ def rotary_embedding_torch( Apply rotary position embedding using PyTorch. Args: - instance: The calling instance (for interface consistency) + obj: The calling obj (for interface consistency) query: Query tensor [batch, num_heads, seq_len, head_dim] or [seq_len, num_heads, head_dim] key: Key tensor [batch, num_heads, seq_len, head_dim] or [seq_len, num_heads, head_dim] cos: Cosine cache [max_seq_len, rotary_dim] where rotary_dim = head_dim or head_dim // 2 diff --git a/vllm_fl/dispatch/backends/reference/reference.py b/vllm_fl/dispatch/backends/reference/reference.py index 533f0f3a..653e905c 100644 --- a/vllm_fl/dispatch/backends/reference/reference.py +++ b/vllm_fl/dispatch/backends/reference/reference.py @@ -44,12 +44,12 @@ def is_available(self) -> bool: # ==================== Operator Implementations ==================== - def silu_and_mul(self, instance, x: torch.Tensor) -> torch.Tensor: + def silu_and_mul(self, obj, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication. 
Args: - instance: The calling instance (for interface consistency) + obj: The calling obj (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: @@ -57,11 +57,11 @@ def silu_and_mul(self, instance, x: torch.Tensor) -> torch.Tensor: """ from .impl.activation import silu_and_mul_torch - return silu_and_mul_torch(instance, x) + return silu_and_mul_torch(obj, x) def rms_norm( self, - instance, + obj, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: @@ -69,7 +69,7 @@ def rms_norm( RMS normalization. Args: - instance: The calling instance (e.g., RMSNorm layer) + obj: The calling obj (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor @@ -78,11 +78,11 @@ def rms_norm( """ from .impl.normalization import rms_norm_torch - return rms_norm_torch(instance, x, residual) + return rms_norm_torch(obj, x, residual) def rotary_embedding( self, - instance, + obj, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -95,7 +95,7 @@ def rotary_embedding( Apply rotary position embedding. 
Args: - instance: The calling instance (for interface consistency) + obj: The calling obj (for interface consistency) query: Query tensor key: Key tensor cos: Cosine cache @@ -110,7 +110,7 @@ def rotary_embedding( from .impl.rotary import rotary_embedding_torch return rotary_embedding_torch( - instance, + obj, query, key, cos, diff --git a/vllm_fl/dispatch/backends/vendor/ascend/ascend.py b/vllm_fl/dispatch/backends/vendor/ascend/ascend.py index 3a16037c..6646f589 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/ascend.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/ascend.py @@ -51,12 +51,12 @@ def is_available(self) -> bool: # ==================== Operator Implementations ==================== - def silu_and_mul(self, instance, x: torch.Tensor) -> torch.Tensor: + def silu_and_mul(self, obj, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication. Args: - instance: The calling instance (for interface consistency) + obj: The calling obj (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: @@ -64,11 +64,11 @@ def silu_and_mul(self, instance, x: torch.Tensor) -> torch.Tensor: """ from .impl.activation import silu_and_mul_ascend - return silu_and_mul_ascend(instance, x) + return silu_and_mul_ascend(obj, x) def rms_norm( self, - instance, + obj, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: @@ -76,7 +76,7 @@ def rms_norm( RMS normalization. 
Args: - instance: The calling instance (e.g., RMSNorm layer) + obj: The calling obj (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor @@ -85,11 +85,11 @@ def rms_norm( """ from .impl.normalization import rms_norm_ascend - return rms_norm_ascend(instance, x, residual) + return rms_norm_ascend(obj, x, residual) def rotary_embedding( self, - instance, + obj, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -102,7 +102,7 @@ def rotary_embedding( Apply rotary position embedding. Args: - instance: The calling instance (for interface consistency) + obj: The calling obj (for interface consistency) query: Query tensor key: Key tensor cos: Cosine cache @@ -117,7 +117,7 @@ def rotary_embedding( from .impl.rotary import rotary_embedding_ascend return rotary_embedding_ascend( - instance, + obj, query, key, cos, diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py index 26bb1373..72ad09a6 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py @@ -9,12 +9,12 @@ import torch -def silu_and_mul_ascend(instance, x: torch.Tensor) -> torch.Tensor: +def silu_and_mul_ascend(obj, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication using Ascend NPU. 
Args: - instance: The calling instance (for interface consistency) + obj: The calling obj (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/attention_mask.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/attention_mask.py index 0fc3d043..6c9af534 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/attention_mask.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/attention_mask.py @@ -143,23 +143,23 @@ def clear_cache(cls) -> None: cls._pcp_mla_mask_dtype = None -# Global instance cache for convenience +# Global obj cache for convenience _builder_instance: Optional[AttentionMaskBuilder] = None _builder_device: Optional[torch.device] = None def get_attention_mask_builder(device: torch.device) -> AttentionMaskBuilder: """ - Get or create a global AttentionMaskBuilder instance. + Get or create a global AttentionMaskBuilder obj. This function provides a convenient way to access the mask builder - without managing instance lifecycle. + without managing obj lifecycle. Args: device: The device for the mask builder. Returns: - AttentionMaskBuilder instance. + AttentionMaskBuilder obj. """ global _builder_instance, _builder_device if _builder_instance is None or _builder_device != device: diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py index b0b2a940..8bcc2672 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py @@ -12,7 +12,7 @@ def rms_norm_ascend( - instance, + obj, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: @@ -20,7 +20,7 @@ def rms_norm_ascend( RMS normalization using Ascend NPU. 
Args: - instance: The calling instance (e.g., RMSNorm layer) + obj: The calling obj (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor to add before normalization @@ -29,9 +29,9 @@ def rms_norm_ascend( """ import torch_npu - # Get weight and epsilon from instance - weight = instance.weight - epsilon = instance.variance_epsilon + # Get weight and epsilon from obj + weight = obj.weight + epsilon = obj.variance_epsilon if residual is not None: x, _, residual = torch_npu.npu_add_rms_norm(x, residual, weight, epsilon) diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py index ce856daa..6fa6e3f9 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py @@ -11,7 +11,7 @@ def rotary_embedding_ascend( - instance, + obj, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -24,7 +24,7 @@ def rotary_embedding_ascend( Apply rotary position embedding using Ascend NPU. Args: - instance: The calling instance (for interface consistency) + obj: The calling obj (for interface consistency) query: Query tensor [num_tokens, num_heads, rotary_dim] key: Key tensor [num_tokens, num_kv_heads, rotary_dim] cos: Cosine cache [max_seq_len, rotary_dim // 2] diff --git a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py index ef5bf67b..66f5af06 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py @@ -59,14 +59,14 @@ def is_available(self) -> bool: # ==================== Operator Implementations ==================== - def silu_and_mul(self, instance, x: torch.Tensor) -> torch.Tensor: + def silu_and_mul(self, obj, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication. Uses vLLM's native CUDA implementation. 
Args: - instance: The calling instance (for interface consistency) + obj: The calling obj (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: @@ -74,11 +74,11 @@ def silu_and_mul(self, instance, x: torch.Tensor) -> torch.Tensor: """ from .impl.activation import silu_and_mul_cuda - return silu_and_mul_cuda(instance, x) + return silu_and_mul_cuda(obj, x) def rms_norm( self, - instance, + obj, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: @@ -86,7 +86,7 @@ def rms_norm( RMS normalization using vLLM's CUDA implementation. Args: - instance: The calling instance (e.g., RMSNorm layer) + obj: The calling obj (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor @@ -95,11 +95,11 @@ def rms_norm( """ from .impl.normalization import rms_norm_cuda - return rms_norm_cuda(instance, x, residual) + return rms_norm_cuda(obj, x, residual) def rotary_embedding( self, - instance, + obj, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -112,7 +112,7 @@ def rotary_embedding( Apply rotary position embedding using vLLM's CUDA implementation. 
Args: - instance: The calling instance (for interface consistency) + obj: The calling obj (for interface consistency) query: Query tensor key: Key tensor cos: Cosine cache @@ -127,7 +127,7 @@ def rotary_embedding( from .impl.rotary import rotary_embedding_cuda return rotary_embedding_cuda( - instance, + obj, query, key, cos, diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py index 09e35fcd..4f545cc7 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py @@ -9,14 +9,14 @@ import torch -def silu_and_mul_cuda(instance, x: torch.Tensor) -> torch.Tensor: +def silu_and_mul_cuda(obj, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication using CUDA. Uses vLLM's optimized CUDA kernel when available. Args: - instance: The calling instance (for interface consistency) + obj: The calling obj (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py index f6e0f651..c43a7cc8 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py @@ -12,7 +12,7 @@ def rms_norm_cuda( - instance, + obj, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: @@ -22,7 +22,7 @@ def rms_norm_cuda( Uses vLLM's optimized CUDA kernel when available. 
Args: - instance: The calling instance (e.g., RMSNorm layer) + obj: The calling obj (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor to add before normalization @@ -32,9 +32,9 @@ def rms_norm_cuda( from vllm._custom_ops import rms_norm as vllm_rms_norm from vllm._custom_ops import fused_add_rms_norm as vllm_fused_add_rms_norm - # Get weight and epsilon from instance - weight = instance.weight - epsilon = instance.variance_epsilon + # Get weight and epsilon from obj + weight = obj.weight + epsilon = obj.variance_epsilon if residual is not None: vllm_fused_add_rms_norm(x, residual, weight, epsilon) diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py index f816a213..73db40aa 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py @@ -10,7 +10,7 @@ def rotary_embedding_cuda( - instance, + obj, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -25,7 +25,7 @@ def rotary_embedding_cuda( Uses vLLM's optimized CUDA kernel when available. 
Args: - instance: The calling instance (for interface consistency) + obj: The calling obj (for interface consistency) query: Query tensor key: Key tensor cos: Cosine cache diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/activation.py b/vllm_fl/dispatch/backends/vendor/metax/impl/activation.py index 154c8e2b..e22eb753 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/impl/activation.py +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/activation.py @@ -9,5 +9,5 @@ import torch -def silu_and_mul_metax(instance, x: torch.Tensor) -> torch.Tensor: +def silu_and_mul_metax(obj, x: torch.Tensor) -> torch.Tensor: diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/normalization.py b/vllm_fl/dispatch/backends/vendor/metax/impl/normalization.py index 0db662a6..60a60d98 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/impl/normalization.py +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/normalization.py @@ -12,7 +12,7 @@ def rms_norm_metax( - instance, + obj, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/rotary.py b/vllm_fl/dispatch/backends/vendor/metax/impl/rotary.py index 68aecabc..78d09bf7 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/impl/rotary.py +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/rotary.py @@ -10,7 +10,7 @@ def rotary_embedding_metax( - instance, + obj, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, diff --git a/vllm_fl/dispatch/backends/vendor/metax/metax.py b/vllm_fl/dispatch/backends/vendor/metax/metax.py index 87eb0385..7631e0fd 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/metax.py +++ b/vllm_fl/dispatch/backends/vendor/metax/metax.py @@ -59,12 +59,12 @@ def is_available(self) -> bool: # ==================== Operator Implementations ==================== - def silu_and_mul(self, instance, x: torch.Tensor) -> torch.Tensor: + def silu_and_mul(self, obj, x: torch.Tensor) -> 
torch.Tensor: """ SiLU activation followed by element-wise multiplication. Args: - instance: The calling instance (for interface consistency) + obj: The calling obj (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: @@ -72,11 +72,11 @@ def silu_and_mul(self, instance, x: torch.Tensor) -> torch.Tensor: """ from .impl.activation import silu_and_mul_metax - return silu_and_mul_metax(instance, x) + return silu_and_mul_metax(obj, x) def rms_norm( self, - instance, + obj, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: @@ -84,7 +84,7 @@ def rms_norm( RMS normalization. Args: - instance: The calling instance (e.g., RMSNorm layer) + obj: The calling obj (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor @@ -93,11 +93,11 @@ def rms_norm( """ from .impl.normalization import rms_norm_metax - return rms_norm_metax(instance, x, residual) + return rms_norm_metax(obj, x, residual) def rotary_embedding( self, - instance, + obj, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -110,7 +110,7 @@ def rotary_embedding( Apply rotary position embedding. 
Args: - instance: The calling instance (for interface consistency) + obj: The calling obj (for interface consistency) query: Query tensor key: Key tensor cos: Cosine cache @@ -125,7 +125,7 @@ def rotary_embedding( from .impl.rotary import rotary_embedding_metax return rotary_embedding_metax( - instance, + obj, query, key, cos, From 45da20ce99d80ffac7fc73f314e98e1b6d3db7dd Mon Sep 17 00:00:00 2001 From: xin2an Date: Wed, 11 Feb 2026 12:04:38 +0800 Subject: [PATCH 31/34] Revised according to feedback --- vllm_fl/dispatch/backends/vendor/metax/metax.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm_fl/dispatch/backends/vendor/metax/metax.py b/vllm_fl/dispatch/backends/vendor/metax/metax.py index 7631e0fd..cfccb78b 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/metax.py +++ b/vllm_fl/dispatch/backends/vendor/metax/metax.py @@ -44,10 +44,6 @@ def is_available(self) -> bool: """ if MetaxBackend._available is None: try: - if not torch.cuda.is_available() or torch.cuda.device_count() == 0: - MetaxBackend._available = False - return False - from vllm.platforms import current_platform if hasattr(current_platform, 'vendor_name') and current_platform.vendor_name == "metax": MetaxBackend._available = True From 23137a367285027eff8c31660d04386328e60632 Mon Sep 17 00:00:00 2001 From: xin2an Date: Wed, 11 Feb 2026 16:35:53 +0800 Subject: [PATCH 32/34] The VLLM_FL_PLATFORM environment variable was deleted. The default configuration file is now selected based on the current_platform. 
--- vllm_fl/dispatch/README.md | 15 ++++-------- .../config/{cuda.yaml => nvidia.yaml} | 4 ++-- vllm_fl/dispatch/config/utils.py | 23 ++++++++----------- 3 files changed, 16 insertions(+), 26 deletions(-) rename vllm_fl/dispatch/config/{cuda.yaml => nvidia.yaml} (91%) diff --git a/vllm_fl/dispatch/README.md b/vllm_fl/dispatch/README.md index 6a382901..f252ce38 100644 --- a/vllm_fl/dispatch/README.md +++ b/vllm_fl/dispatch/README.md @@ -227,14 +227,11 @@ The system automatically detects hardware and loads the corresponding configurat | Platform | Config File | Auto-Detection | |----------|-------------|----------------| -| Ascend NPU | `config/ascend.yaml` | `torch.npu.is_available()` | -| NVIDIA GPU | `config/cuda.yaml` | `torch.cuda.is_available()` | +| Ascend NPU | `config/ascend.yaml` | `platform.vendor_name == 'ascend'` | +| NVIDIA GPU | `config/nvidia.yaml` | `platform.vendor_name == 'nvidia'` | +| METAX GPU | `config/metax.yaml` | `platform.vendor_name == 'metax'` | -You can force a specific platform using `VLLM_FL_PLATFORM` environment variable: -```bash -export VLLM_FL_PLATFORM=ascend # Force Ascend config -export VLLM_FL_PLATFORM=cuda # Force CUDA config -``` +Platform detection is automatic based on `current_platform.vendor_name`. ### User-Specified Configuration File (YAML) @@ -314,7 +311,6 @@ Environment variables can override specific items from platform config. If not s |----------|---------|-------------| | `VLLM_FL_PREFER_ENABLED` | `true` | Global switch. 
Set `false` to disable all dispatch features | | `VLLM_FL_CONFIG` | (none) | Path to YAML config file (complete override) | -| `VLLM_FL_PLATFORM` | (auto) | Force platform: `ascend`, `cuda` | #### Backend Selection @@ -388,9 +384,6 @@ export VLLM_FL_PER_OP="rms_norm=vendor|flagos|reference" # Use completely custom config file export VLLM_FL_CONFIG=/path/to/my_config.yaml -# Force specific platform -export VLLM_FL_PLATFORM=ascend - # Enable debug logging export VLLM_FL_LOG_LEVEL=DEBUG ``` diff --git a/vllm_fl/dispatch/config/cuda.yaml b/vllm_fl/dispatch/config/nvidia.yaml similarity index 91% rename from vllm_fl/dispatch/config/cuda.yaml rename to vllm_fl/dispatch/config/nvidia.yaml index 3c8676ff..2a192b14 100644 --- a/vllm_fl/dispatch/config/cuda.yaml +++ b/vllm_fl/dispatch/config/nvidia.yaml @@ -1,5 +1,5 @@ -# vLLM-FL Dispatch Configuration for CUDA -# Auto-loaded when running on NVIDIA GPU hardware +# vLLM-FL Dispatch Configuration for NVIDIA GPU +# Auto-loaded when running on NVIDIA GPU hardware (vendor_name: nvidia) # Preferred default backend type: flaggems, vendor, reference prefer: flagos diff --git a/vllm_fl/dispatch/config/utils.py b/vllm_fl/dispatch/config/utils.py index a0e41db6..999a4324 100644 --- a/vllm_fl/dispatch/config/utils.py +++ b/vllm_fl/dispatch/config/utils.py @@ -41,24 +41,21 @@ def get_platform_name() -> str: """ - Detect the current hardware platform. + Detect the current hardware platform using platform vendor_name. + + This function uses current_platform.vendor_name to accurately distinguish + between different hardware vendors (NVIDIA, METAX, Ascend, etc.). Returns: - Platform name string: 'ascend', 'cuda', or 'unknown' + Platform name string based on vendor_name: 'nvidia', 'metax', 'ascend', etc. 
""" try: - import torch - if hasattr(torch, 'npu') and torch.npu.is_available(): - return 'ascend' - if torch.cuda.is_available(): - return 'cuda' - except ImportError: - pass + from vllm.platforms import current_platform - # Check environment variable override - platform_override = os.environ.get('VLLM_FL_PLATFORM', '').strip().lower() - if platform_override: - return platform_override + if hasattr(current_platform, 'vendor_name'): + return current_platform.vendor_name + except Exception: + pass return 'unknown' From a7fb4c48cd14dee46a101ea451523ead05970fde Mon Sep 17 00:00:00 2001 From: xin2an Date: Wed, 11 Feb 2026 20:13:51 +0800 Subject: [PATCH 33/34] Decouple op implementations from Backend classes and add dispatch_method descriptor for instance-aware dispatch --- vllm_fl/dispatch/__init__.py | 14 +++ .../dispatch/backends/flaggems/flaggems.py | 78 +------------ .../backends/flaggems/impl/activation.py | 4 +- .../backends/flaggems/impl/normalization.py | 10 +- .../dispatch/backends/flaggems/impl/rotary.py | 4 +- .../backends/flaggems/register_ops.py | 11 +- .../backends/reference/impl/activation.py | 4 +- .../backends/reference/impl/normalization.py | 10 +- .../backends/reference/impl/rotary.py | 4 +- .../dispatch/backends/reference/reference.py | 80 +------------ .../backends/reference/register_ops.py | 11 +- .../dispatch/backends/vendor/ascend/ascend.py | 78 +------------ .../backends/vendor/ascend/impl/activation.py | 4 +- .../vendor/ascend/impl/normalization.py | 10 +- .../backends/vendor/ascend/impl/rotary.py | 4 +- .../backends/vendor/ascend/register_ops.py | 11 +- vllm_fl/dispatch/backends/vendor/cuda/cuda.py | 80 +------------ .../backends/vendor/cuda/impl/activation.py | 4 +- .../vendor/cuda/impl/normalization.py | 10 +- .../backends/vendor/cuda/impl/rotary.py | 4 +- .../backends/vendor/cuda/register_ops.py | 11 +- vllm_fl/dispatch/manager.py | 110 ++++++++++++++++++ vllm_fl/dispatch/method_dispatch.py | 44 +++++++ vllm_fl/ops/activation.py | 6 +- 
vllm_fl/ops/layernorm.py | 11 +- vllm_fl/ops/rotary_embedding.py | 4 +- 26 files changed, 243 insertions(+), 378 deletions(-) create mode 100644 vllm_fl/dispatch/method_dispatch.py diff --git a/vllm_fl/dispatch/__init__.py b/vllm_fl/dispatch/__init__.py index 078f387f..d2916dcb 100644 --- a/vllm_fl/dispatch/__init__.py +++ b/vllm_fl/dispatch/__init__.py @@ -96,6 +96,7 @@ ) from .manager import OpManager, get_default_manager, reset_default_manager from .ops import VLLMFLBackendBase +from .method_dispatch import dispatch_method from .discovery import ( discover_plugins, get_discovered_plugins, @@ -106,6 +107,16 @@ from .logger_manager import get_logger, set_log_level +def call_method_op(op_name: str, instance, *args, **kwargs): + """ + Call an operator as a bound method on *instance*. + + The resolved backend function receives *instance* as ``self``, + allowing it to freely access instance attributes. + """ + return get_default_manager().call_as_method(op_name, instance, *args, **kwargs) + + def call_op(op_name: str, *args, **kwargs): """ Convenience function to call an operator through the default manager. @@ -163,6 +174,9 @@ def resolve_op(op_name: str): "reset_default_manager", # Backend base "VLLMFLBackendBase", + # Method dispatch + "dispatch_method", + "call_method_op", # Plugin discovery "discover_plugins", "get_discovered_plugins", diff --git a/vllm_fl/dispatch/backends/flaggems/flaggems.py b/vllm_fl/dispatch/backends/flaggems/flaggems.py index 952892f4..37c54c0d 100644 --- a/vllm_fl/dispatch/backends/flaggems/flaggems.py +++ b/vllm_fl/dispatch/backends/flaggems/flaggems.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Optional, Union +from typing import Optional import torch @@ -42,82 +42,6 @@ def is_available(self) -> bool: # ==================== Operator Implementations ==================== - def silu_and_mul(self, obj, x: torch.Tensor) -> torch.Tensor: - """ - SiLU activation followed by element-wise multiplication. 
- - Args: - obj: The calling obj (for interface consistency) - x: Input tensor of shape [..., 2*d] - - Returns: - Output tensor of shape [..., d] - """ - from .impl.activation import silu_and_mul_flaggems - - return silu_and_mul_flaggems(obj, x) - - def rms_norm( - self, - obj, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - """ - RMS normalization. - - Args: - obj: The calling obj (e.g., RMSNorm layer) - x: Input tensor - residual: Optional residual tensor - - Returns: - Normalized tensor, or tuple of (normalized, residual) if residual is provided - """ - from .impl.normalization import rms_norm_flaggems - - return rms_norm_flaggems(obj, x, residual) - - def rotary_embedding( - self, - obj, - query: torch.Tensor, - key: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - position_ids: torch.Tensor, - rotary_interleaved: bool = False, - inplace: bool = True, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary position embedding. - - Args: - obj: The calling obj (for interface consistency) - query: Query tensor - key: Key tensor - cos: Cosine cache - sin: Sine cache - position_ids: Position indices - rotary_interleaved: Whether to use interleaved rotary - inplace: Whether to modify tensors in-place - - Returns: - Tuple of (embedded_query, embedded_key) - """ - from .impl.rotary import rotary_embedding_flaggems - - return rotary_embedding_flaggems( - obj, - query, - key, - cos, - sin, - position_ids, - rotary_interleaved=rotary_interleaved, - inplace=inplace, - ) - def attention_backend(self, use_mla: bool = False) -> str: """ Get the attention backend class path for FlagGems. 
diff --git a/vllm_fl/dispatch/backends/flaggems/impl/activation.py b/vllm_fl/dispatch/backends/flaggems/impl/activation.py index 08886fe1..146446de 100644 --- a/vllm_fl/dispatch/backends/flaggems/impl/activation.py +++ b/vllm_fl/dispatch/backends/flaggems/impl/activation.py @@ -9,12 +9,12 @@ import torch -def silu_and_mul_flaggems(obj, x: torch.Tensor) -> torch.Tensor: +def silu_and_mul_flaggems(self, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication using FlagGems. Args: - obj: The calling obj (for interface consistency) + self: The calling instance (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: diff --git a/vllm_fl/dispatch/backends/flaggems/impl/normalization.py b/vllm_fl/dispatch/backends/flaggems/impl/normalization.py index c2e69050..71115ef9 100644 --- a/vllm_fl/dispatch/backends/flaggems/impl/normalization.py +++ b/vllm_fl/dispatch/backends/flaggems/impl/normalization.py @@ -12,7 +12,7 @@ def rms_norm_flaggems( - obj, + self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: @@ -20,7 +20,7 @@ def rms_norm_flaggems( RMS normalization using FlagGems. 
Args: - obj: The calling obj (e.g., RMSNorm layer) + self: The calling instance (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor @@ -29,8 +29,8 @@ def rms_norm_flaggems( """ from flag_gems.modules.normalization import gems_rms_forward - # Get weight and epsilon from obj - weight = obj.weight - epsilon = obj.variance_epsilon + # Get weight and epsilon from self + weight = self.weight + epsilon = self.variance_epsilon return gems_rms_forward(x, residual, weight, epsilon) diff --git a/vllm_fl/dispatch/backends/flaggems/impl/rotary.py b/vllm_fl/dispatch/backends/flaggems/impl/rotary.py index b4cb5c30..35420b1e 100644 --- a/vllm_fl/dispatch/backends/flaggems/impl/rotary.py +++ b/vllm_fl/dispatch/backends/flaggems/impl/rotary.py @@ -10,7 +10,7 @@ def rotary_embedding_flaggems( - obj, + self, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -23,7 +23,7 @@ def rotary_embedding_flaggems( Apply rotary position embedding using FlagGems. Args: - obj: The calling obj (for interface consistency) + self: The calling instance (for interface consistency) query: Query tensor key: Key tensor cos: Cosine cache diff --git a/vllm_fl/dispatch/backends/flaggems/register_ops.py b/vllm_fl/dispatch/backends/flaggems/register_ops.py index bc5595b3..a2e98b26 100644 --- a/vllm_fl/dispatch/backends/flaggems/register_ops.py +++ b/vllm_fl/dispatch/backends/flaggems/register_ops.py @@ -34,6 +34,9 @@ def register_builtins(registry) -> None: registry: Registry to register into """ from .flaggems import FlagGemsBackend + from .impl.activation import silu_and_mul_flaggems + from .impl.normalization import rms_norm_flaggems + from .impl.rotary import rotary_embedding_flaggems backend = FlagGemsBackend() is_avail = backend.is_available @@ -44,7 +47,7 @@ def register_builtins(registry) -> None: op_name="silu_and_mul", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, - fn=_bind_is_available(backend.silu_and_mul, is_avail), + 
fn=_bind_is_available(silu_and_mul_flaggems, is_avail), vendor=None, priority=BackendPriority.DEFAULT, ), @@ -53,7 +56,7 @@ def register_builtins(registry) -> None: op_name="rms_norm", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, - fn=_bind_is_available(backend.rms_norm, is_avail), + fn=_bind_is_available(rms_norm_flaggems, is_avail), vendor=None, priority=BackendPriority.DEFAULT, ), @@ -62,11 +65,11 @@ def register_builtins(registry) -> None: op_name="rotary_embedding", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, - fn=_bind_is_available(backend.rotary_embedding, is_avail), + fn=_bind_is_available(rotary_embedding_flaggems, is_avail), vendor=None, priority=BackendPriority.DEFAULT, ), - # Attention Backend + # Attention Backend (no instance binding needed) OpImpl( op_name="attention_backend", impl_id="default.flagos", diff --git a/vllm_fl/dispatch/backends/reference/impl/activation.py b/vllm_fl/dispatch/backends/reference/impl/activation.py index 87f15061..ce8e9fd9 100644 --- a/vllm_fl/dispatch/backends/reference/impl/activation.py +++ b/vllm_fl/dispatch/backends/reference/impl/activation.py @@ -10,12 +10,12 @@ import torch.nn.functional as F -def silu_and_mul_torch(obj, x: torch.Tensor) -> torch.Tensor: +def silu_and_mul_torch(self, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication using PyTorch. 
Args: - obj: The calling obj (for interface consistency) + self: The calling instance (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: diff --git a/vllm_fl/dispatch/backends/reference/impl/normalization.py b/vllm_fl/dispatch/backends/reference/impl/normalization.py index 828018bc..68e17cac 100644 --- a/vllm_fl/dispatch/backends/reference/impl/normalization.py +++ b/vllm_fl/dispatch/backends/reference/impl/normalization.py @@ -12,7 +12,7 @@ def rms_norm_torch( - obj, + self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: @@ -20,16 +20,16 @@ def rms_norm_torch( RMS normalization using PyTorch. Args: - obj: The calling obj (e.g., RMSNorm layer) + self: The calling instance (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor Returns: Normalized tensor, or tuple of (normalized, residual) if residual is provided """ - # Get weight and epsilon from obj - weight = obj.weight - epsilon = obj.variance_epsilon + # Get weight and epsilon from self + weight = self.weight + epsilon = self.variance_epsilon if residual is not None: x = x + residual diff --git a/vllm_fl/dispatch/backends/reference/impl/rotary.py b/vllm_fl/dispatch/backends/reference/impl/rotary.py index 16125e08..a0c8a557 100644 --- a/vllm_fl/dispatch/backends/reference/impl/rotary.py +++ b/vllm_fl/dispatch/backends/reference/impl/rotary.py @@ -10,7 +10,7 @@ def rotary_embedding_torch( - obj, + self, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -23,7 +23,7 @@ def rotary_embedding_torch( Apply rotary position embedding using PyTorch. 
Args: - obj: The calling obj (for interface consistency) + self: The calling instance (for interface consistency) query: Query tensor [batch, num_heads, seq_len, head_dim] or [seq_len, num_heads, head_dim] key: Key tensor [batch, num_heads, seq_len, head_dim] or [seq_len, num_heads, head_dim] cos: Cosine cache [max_seq_len, rotary_dim] where rotary_dim = head_dim or head_dim // 2 diff --git a/vllm_fl/dispatch/backends/reference/reference.py b/vllm_fl/dispatch/backends/reference/reference.py index 653e905c..966b8638 100644 --- a/vllm_fl/dispatch/backends/reference/reference.py +++ b/vllm_fl/dispatch/backends/reference/reference.py @@ -10,9 +10,7 @@ from __future__ import annotations -from typing import Optional, Union - -import torch +from typing import Optional from vllm_fl.dispatch.backends.base import Backend @@ -44,82 +42,6 @@ def is_available(self) -> bool: # ==================== Operator Implementations ==================== - def silu_and_mul(self, obj, x: torch.Tensor) -> torch.Tensor: - """ - SiLU activation followed by element-wise multiplication. - - Args: - obj: The calling obj (for interface consistency) - x: Input tensor of shape [..., 2*d] - - Returns: - Output tensor of shape [..., d] - """ - from .impl.activation import silu_and_mul_torch - - return silu_and_mul_torch(obj, x) - - def rms_norm( - self, - obj, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - """ - RMS normalization. 
- - Args: - obj: The calling obj (e.g., RMSNorm layer) - x: Input tensor - residual: Optional residual tensor - - Returns: - Normalized tensor, or tuple of (normalized, residual) if residual is provided - """ - from .impl.normalization import rms_norm_torch - - return rms_norm_torch(obj, x, residual) - - def rotary_embedding( - self, - obj, - query: torch.Tensor, - key: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - position_ids: torch.Tensor, - rotary_interleaved: bool = False, - inplace: bool = True, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary position embedding. - - Args: - obj: The calling obj (for interface consistency) - query: Query tensor - key: Key tensor - cos: Cosine cache - sin: Sine cache - position_ids: Position indices - rotary_interleaved: Whether to use interleaved rotary - inplace: Whether to modify tensors in-place (ignored in reference impl) - - Returns: - Tuple of (embedded_query, embedded_key) - """ - from .impl.rotary import rotary_embedding_torch - - return rotary_embedding_torch( - obj, - query, - key, - cos, - sin, - position_ids, - rotary_interleaved=rotary_interleaved, - inplace=inplace, - ) - def attention_backend(self, use_mla: bool = False) -> str: """ Get the attention backend class path for reference (vLLM native). 
diff --git a/vllm_fl/dispatch/backends/reference/register_ops.py b/vllm_fl/dispatch/backends/reference/register_ops.py index 522474c3..fa017402 100644 --- a/vllm_fl/dispatch/backends/reference/register_ops.py +++ b/vllm_fl/dispatch/backends/reference/register_ops.py @@ -32,6 +32,9 @@ def register_builtins(registry) -> None: registry: Registry to register into """ from .reference import ReferenceBackend + from .impl.activation import silu_and_mul_torch + from .impl.normalization import rms_norm_torch + from .impl.rotary import rotary_embedding_torch backend = ReferenceBackend() is_avail = backend.is_available @@ -42,7 +45,7 @@ def register_builtins(registry) -> None: op_name="silu_and_mul", impl_id="reference.torch", kind=BackendImplKind.REFERENCE, - fn=_bind_is_available(backend.silu_and_mul, is_avail), + fn=_bind_is_available(silu_and_mul_torch, is_avail), vendor=None, priority=BackendPriority.REFERENCE, ), @@ -51,7 +54,7 @@ def register_builtins(registry) -> None: op_name="rms_norm", impl_id="reference.torch", kind=BackendImplKind.REFERENCE, - fn=_bind_is_available(backend.rms_norm, is_avail), + fn=_bind_is_available(rms_norm_torch, is_avail), vendor=None, priority=BackendPriority.REFERENCE, ), @@ -60,11 +63,11 @@ def register_builtins(registry) -> None: op_name="rotary_embedding", impl_id="reference.torch", kind=BackendImplKind.REFERENCE, - fn=_bind_is_available(backend.rotary_embedding, is_avail), + fn=_bind_is_available(rotary_embedding_torch, is_avail), vendor=None, priority=BackendPriority.REFERENCE, ), - # Attention Backend + # Attention Backend (no instance binding needed) OpImpl( op_name="attention_backend", impl_id="reference.torch", diff --git a/vllm_fl/dispatch/backends/vendor/ascend/ascend.py b/vllm_fl/dispatch/backends/vendor/ascend/ascend.py index 6646f589..1c407483 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/ascend.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/ascend.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing 
import Optional, Union +from typing import Optional import torch @@ -51,82 +51,6 @@ def is_available(self) -> bool: # ==================== Operator Implementations ==================== - def silu_and_mul(self, obj, x: torch.Tensor) -> torch.Tensor: - """ - SiLU activation followed by element-wise multiplication. - - Args: - obj: The calling obj (for interface consistency) - x: Input tensor of shape [..., 2*d] - - Returns: - Output tensor of shape [..., d] - """ - from .impl.activation import silu_and_mul_ascend - - return silu_and_mul_ascend(obj, x) - - def rms_norm( - self, - obj, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - """ - RMS normalization. - - Args: - obj: The calling obj (e.g., RMSNorm layer) - x: Input tensor - residual: Optional residual tensor - - Returns: - Normalized tensor, or tuple of (normalized, residual) if residual is provided - """ - from .impl.normalization import rms_norm_ascend - - return rms_norm_ascend(obj, x, residual) - - def rotary_embedding( - self, - obj, - query: torch.Tensor, - key: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - position_ids: torch.Tensor, - rotary_interleaved: bool = False, - inplace: bool = True, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary position embedding. 
- - Args: - obj: The calling obj (for interface consistency) - query: Query tensor - key: Key tensor - cos: Cosine cache - sin: Sine cache - position_ids: Position indices - rotary_interleaved: Whether to use interleaved rotary - inplace: Whether to modify tensors in-place - - Returns: - Tuple of (embedded_query, embedded_key) - """ - from .impl.rotary import rotary_embedding_ascend - - return rotary_embedding_ascend( - obj, - query, - key, - cos, - sin, - position_ids, - rotary_interleaved=rotary_interleaved, - inplace=inplace, - ) - def attention_backend(self, use_mla: bool = False) -> str: """ Get the attention backend class path for Ascend NPU. diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py index 72ad09a6..38a2fda1 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py @@ -9,12 +9,12 @@ import torch -def silu_and_mul_ascend(obj, x: torch.Tensor) -> torch.Tensor: +def silu_and_mul_ascend(self, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication using Ascend NPU. Args: - obj: The calling obj (for interface consistency) + self: The calling instance (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py index 8bcc2672..7c277205 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py @@ -12,7 +12,7 @@ def rms_norm_ascend( - obj, + self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: @@ -20,7 +20,7 @@ def rms_norm_ascend( RMS normalization using Ascend NPU. 
Args: - obj: The calling obj (e.g., RMSNorm layer) + self: The calling instance (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor to add before normalization @@ -29,9 +29,9 @@ def rms_norm_ascend( """ import torch_npu - # Get weight and epsilon from obj - weight = obj.weight - epsilon = obj.variance_epsilon + # Get weight and epsilon from self + weight = self.weight + epsilon = self.variance_epsilon if residual is not None: x, _, residual = torch_npu.npu_add_rms_norm(x, residual, weight, epsilon) diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py index 6fa6e3f9..aa9ae581 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py @@ -11,7 +11,7 @@ def rotary_embedding_ascend( - obj, + self, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -24,7 +24,7 @@ def rotary_embedding_ascend( Apply rotary position embedding using Ascend NPU. 
Args: - obj: The calling obj (for interface consistency) + self: The calling instance (for interface consistency) query: Query tensor [num_tokens, num_heads, rotary_dim] key: Key tensor [num_tokens, num_kv_heads, rotary_dim] cos: Cosine cache [max_seq_len, rotary_dim // 2] diff --git a/vllm_fl/dispatch/backends/vendor/ascend/register_ops.py b/vllm_fl/dispatch/backends/vendor/ascend/register_ops.py index 3834a215..f596bd52 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/register_ops.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/register_ops.py @@ -32,6 +32,9 @@ def register_builtins(registry) -> None: registry: Registry to register into """ from .ascend import AscendBackend + from .impl.activation import silu_and_mul_ascend + from .impl.normalization import rms_norm_ascend + from .impl.rotary import rotary_embedding_ascend backend = AscendBackend() is_avail = backend.is_available @@ -42,7 +45,7 @@ def register_builtins(registry) -> None: op_name="silu_and_mul", impl_id="vendor.ascend", kind=BackendImplKind.VENDOR, - fn=_bind_is_available(backend.silu_and_mul, is_avail), + fn=_bind_is_available(silu_and_mul_ascend, is_avail), vendor="ascend", priority=BackendPriority.VENDOR, ), @@ -51,7 +54,7 @@ def register_builtins(registry) -> None: op_name="rms_norm", impl_id="vendor.ascend", kind=BackendImplKind.VENDOR, - fn=_bind_is_available(backend.rms_norm, is_avail), + fn=_bind_is_available(rms_norm_ascend, is_avail), vendor="ascend", priority=BackendPriority.VENDOR, ), @@ -60,11 +63,11 @@ def register_builtins(registry) -> None: op_name="rotary_embedding", impl_id="vendor.ascend", kind=BackendImplKind.VENDOR, - fn=_bind_is_available(backend.rotary_embedding, is_avail), + fn=_bind_is_available(rotary_embedding_ascend, is_avail), vendor="ascend", priority=BackendPriority.VENDOR, ), - # Attention Backend + # Attention Backend (no instance binding needed) OpImpl( op_name="attention_backend", impl_id="vendor.ascend", diff --git 
a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py index 66f5af06..d628140d 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Optional, Union +from typing import Optional import torch @@ -59,84 +59,6 @@ def is_available(self) -> bool: # ==================== Operator Implementations ==================== - def silu_and_mul(self, obj, x: torch.Tensor) -> torch.Tensor: - """ - SiLU activation followed by element-wise multiplication. - - Uses vLLM's native CUDA implementation. - - Args: - obj: The calling obj (for interface consistency) - x: Input tensor of shape [..., 2*d] - - Returns: - Output tensor of shape [..., d] - """ - from .impl.activation import silu_and_mul_cuda - - return silu_and_mul_cuda(obj, x) - - def rms_norm( - self, - obj, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - """ - RMS normalization using vLLM's CUDA implementation. - - Args: - obj: The calling obj (e.g., RMSNorm layer) - x: Input tensor - residual: Optional residual tensor - - Returns: - Normalized tensor, or tuple of (normalized, residual) if residual is provided - """ - from .impl.normalization import rms_norm_cuda - - return rms_norm_cuda(obj, x, residual) - - def rotary_embedding( - self, - obj, - query: torch.Tensor, - key: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - position_ids: torch.Tensor, - rotary_interleaved: bool = False, - inplace: bool = True, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary position embedding using vLLM's CUDA implementation. 
- - Args: - obj: The calling obj (for interface consistency) - query: Query tensor - key: Key tensor - cos: Cosine cache - sin: Sine cache - position_ids: Position indices - rotary_interleaved: Whether to use interleaved rotary - inplace: Whether to modify tensors in-place - - Returns: - Tuple of (embedded_query, embedded_key) - """ - from .impl.rotary import rotary_embedding_cuda - - return rotary_embedding_cuda( - obj, - query, - key, - cos, - sin, - position_ids, - rotary_interleaved=rotary_interleaved, - inplace=inplace, - ) - def attention_backend(self, use_mla: bool = False) -> str: """ Get the attention backend class path for CUDA. diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py index 4f545cc7..0ab49aed 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py @@ -9,14 +9,14 @@ import torch -def silu_and_mul_cuda(obj, x: torch.Tensor) -> torch.Tensor: +def silu_and_mul_cuda(self, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication using CUDA. Uses vLLM's optimized CUDA kernel when available. Args: - obj: The calling obj (for interface consistency) + self: The calling instance (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py index c43a7cc8..fe2d36de 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py @@ -12,7 +12,7 @@ def rms_norm_cuda( - obj, + self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: @@ -22,7 +22,7 @@ def rms_norm_cuda( Uses vLLM's optimized CUDA kernel when available. 
Args: - obj: The calling obj (e.g., RMSNorm layer) + self: The calling instance (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor to add before normalization @@ -32,9 +32,9 @@ def rms_norm_cuda( from vllm._custom_ops import rms_norm as vllm_rms_norm from vllm._custom_ops import fused_add_rms_norm as vllm_fused_add_rms_norm - # Get weight and epsilon from obj - weight = obj.weight - epsilon = obj.variance_epsilon + # Get weight and epsilon from self + weight = self.weight + epsilon = self.variance_epsilon if residual is not None: vllm_fused_add_rms_norm(x, residual, weight, epsilon) diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py index 73db40aa..fe46e9c2 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py @@ -10,7 +10,7 @@ def rotary_embedding_cuda( - obj, + self, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -25,7 +25,7 @@ def rotary_embedding_cuda( Uses vLLM's optimized CUDA kernel when available. 
Args: - obj: The calling obj (for interface consistency) + self: The calling instance (for interface consistency) query: Query tensor key: Key tensor cos: Cosine cache diff --git a/vllm_fl/dispatch/backends/vendor/cuda/register_ops.py b/vllm_fl/dispatch/backends/vendor/cuda/register_ops.py index d0241715..41c8e8c2 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/register_ops.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/register_ops.py @@ -32,6 +32,9 @@ def register_builtins(registry) -> None: registry: Registry to register into """ from .cuda import CudaBackend + from .impl.activation import silu_and_mul_cuda + from .impl.normalization import rms_norm_cuda + from .impl.rotary import rotary_embedding_cuda backend = CudaBackend() is_avail = backend.is_available @@ -42,7 +45,7 @@ def register_builtins(registry) -> None: op_name="silu_and_mul", impl_id="vendor.cuda", kind=BackendImplKind.VENDOR, - fn=_bind_is_available(backend.silu_and_mul, is_avail), + fn=_bind_is_available(silu_and_mul_cuda, is_avail), vendor="cuda", priority=BackendPriority.VENDOR, ), @@ -51,7 +54,7 @@ def register_builtins(registry) -> None: op_name="rms_norm", impl_id="vendor.cuda", kind=BackendImplKind.VENDOR, - fn=_bind_is_available(backend.rms_norm, is_avail), + fn=_bind_is_available(rms_norm_cuda, is_avail), vendor="cuda", priority=BackendPriority.VENDOR, ), @@ -60,11 +63,11 @@ def register_builtins(registry) -> None: op_name="rotary_embedding", impl_id="vendor.cuda", kind=BackendImplKind.VENDOR, - fn=_bind_is_available(backend.rotary_embedding, is_avail), + fn=_bind_is_available(rotary_embedding_cuda, is_avail), vendor="cuda", priority=BackendPriority.VENDOR, ), - # Attention Backend + # Attention Backend (no instance binding needed) OpImpl( op_name="attention_backend", impl_id="vendor.cuda", diff --git a/vllm_fl/dispatch/manager.py b/vllm_fl/dispatch/manager.py index db75db4e..f950a11a 100644 --- a/vllm_fl/dispatch/manager.py +++ b/vllm_fl/dispatch/manager.py @@ -9,6 +9,7 @@ import 
logging import os import threading +import types as pytypes from dataclasses import dataclass from typing import Callable, Dict, Optional, Set, Tuple @@ -521,6 +522,115 @@ def call(self, op_name: str, *args, **kwargs): f"Last error: {last_error}" ) from last_error + def call_as_method(self, op_name: str, instance, *args, **kwargs): + """ + Resolve and call an operator as a bound method on *instance*. + + Behaves identically to :meth:`call` (fallback, logging, caching) + except that the resolved function is bound to *instance* via + ``types.MethodType`` before invocation, so the backend function + receives *instance* as ``self``. + """ + enable_fallback = os.getenv("VLLM_FL_STRICT", "1") != "0" + + if not enable_fallback: + fn = self.resolve(op_name) + + impl_id = self.get_selected_impl_id(op_name) + last_impl_id = self._called_ops.get(op_name) + + if last_impl_id != impl_id: + with self._lock: + if self._called_ops.get(op_name) != impl_id: + snap = self._registry.snapshot() + for impl in snap.impls_by_op.get(op_name, []): + if impl.impl_id == impl_id: + if last_impl_id is None: + logger.info( + f"Op '{op_name}' using '{impl_id}' " + f"(kind={impl.kind.value}, vendor={impl.vendor})" + ) + else: + logger.info( + f"Op '{op_name}' switched from '{last_impl_id}' to '{impl_id}' " + f"(kind={impl.kind.value}, vendor={impl.vendor})" + ) + break + self._called_ops[op_name] = impl_id + + bound = pytypes.MethodType(fn, instance) + return bound(*args, **kwargs) + + # Fallback mode: try candidates in priority order + candidates = self.resolve_candidates(op_name) + last_error = None + + failed_impl_ids = self._failed_impls.get(op_name, set()) + + available_candidates = [ + impl for impl in candidates if impl.impl_id not in failed_impl_ids + ] + + if not available_candidates: + raise RuntimeError( + f"All implementations for op='{op_name}' have failed previously. 
" + f"Failed impl_ids: {failed_impl_ids}" + ) + + for idx, impl in enumerate(available_candidates): + try: + if idx == 0: + last_impl_id = self._called_ops.get(op_name) + if last_impl_id != impl.impl_id: + with self._lock: + if self._called_ops.get(op_name) != impl.impl_id: + if last_impl_id is None: + logger.info( + f"Op '{op_name}' using '{impl.impl_id}' " + f"(kind={impl.kind.value}, vendor={impl.vendor})" + ) + else: + logger.info( + f"Op '{op_name}' switched from '{last_impl_id}' to '{impl.impl_id}' " + f"(kind={impl.kind.value}, vendor={impl.vendor})" + ) + self._called_ops[op_name] = impl.impl_id + else: + logger.info( + f"Op '{op_name}' fallback to '{impl.impl_id}' " + f"(kind={impl.kind.value}, vendor={impl.vendor})" + ) + + bound = pytypes.MethodType(impl.fn, instance) + result = bound(*args, **kwargs) + + if idx > 0: + with self._lock: + self._called_ops[op_name] = impl.impl_id + + return result + + except Exception as e: + last_error = e + with self._lock: + if op_name not in self._failed_impls: + self._failed_impls[op_name] = set() + self._failed_impls[op_name].add(impl.impl_id) + + if idx < len(available_candidates) - 1: + logger.warning( + f"Implementation '{impl.impl_id}' failed for op '{op_name}': {e}" + ) + else: + logger.error( + f"Last implementation '{impl.impl_id}' failed for op '{op_name}': {e}" + ) + + raise RuntimeError( + f"All {len(available_candidates)} implementation(s) failed for op='{op_name}'. " + f"Last error: {last_error}" + ) from last_error + def get_selected_impl_id(self, op_name: str) -> str: """ Get the impl_id of the currently selected implementation. diff --git a/vllm_fl/dispatch/method_dispatch.py b/vllm_fl/dispatch/method_dispatch.py new file mode 100644 index 00000000..bb8374e6 --- /dev/null +++ b/vllm_fl/dispatch/method_dispatch.py @@ -0,0 +1,44 @@ +# Copyright (c) 2026 BAAI. All rights reserved. + +""" +Descriptor-based method dispatch for operator implementations. 
+ +Allows operator classes to declare `forward_oot` as a descriptor that +automatically dispatches to the resolved backend implementation, with +the backend function bound as a method so `self` is naturally available. +""" + +from __future__ import annotations + + +class dispatch_method: + """ + Descriptor that dispatches to the resolved backend implementation. + + The backend function is bound as a method to the operator instance + via ``types.MethodType``, so ``self`` is naturally available — just + like vLLM's ``forward_cuda`` / ``forward_xpu`` pattern. + + Usage:: + + class RMSNormFL(RMSNorm): + forward_oot = dispatch_method("rms_norm") + """ + + def __init__(self, op_name: str) -> None: + self.op_name = op_name + + def __set_name__(self, owner, name): + self.attr_name = name + + def __get__(self, obj, objtype=None): + if obj is None: + return self + + def dispatched(*args, **kwargs): + from vllm_fl.dispatch import get_default_manager + return get_default_manager().call_as_method( + self.op_name, obj, *args, **kwargs + ) + + return dispatched diff --git a/vllm_fl/ops/activation.py b/vllm_fl/ops/activation.py index e895c5b0..8032f6cb 100644 --- a/vllm_fl/ops/activation.py +++ b/vllm_fl/ops/activation.py @@ -1,15 +1,13 @@ # Copyright (c) 2025 BAAI. All rights reserved. -import torch from vllm.model_executor.layers.activation import SiluAndMul -from vllm_fl.dispatch import call_op +from vllm_fl.dispatch.method_dispatch import dispatch_method class SiluAndMulFL(SiluAndMul): def __init__(self): super().__init__() - def forward_oot(self, x: torch.Tensor) -> torch.Tensor: - return call_op("silu_and_mul", self, x) + forward_oot = dispatch_method("silu_and_mul") __all__ = ["SiluAndMulFL"] diff --git a/vllm_fl/ops/layernorm.py b/vllm_fl/ops/layernorm.py index ea13c7c4..75c4c196 100644 --- a/vllm_fl/ops/layernorm.py +++ b/vllm_fl/ops/layernorm.py @@ -1,9 +1,9 @@ # Copyright (c) 2025 BAAI. All rights reserved. 
-from typing import Optional, Union +from typing import Optional import torch from vllm.model_executor.layers.layernorm import RMSNorm -from vllm_fl.dispatch import call_op +from vllm_fl.dispatch.method_dispatch import dispatch_method class RMSNormFL(RMSNorm): @@ -17,12 +17,7 @@ def __init__( ) -> None: super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype) - def forward_oot( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - return call_op("rms_norm", self, x, residual) + forward_oot = dispatch_method("rms_norm") __all__ = ["RMSNormFL"] diff --git a/vllm_fl/ops/rotary_embedding.py b/vllm_fl/ops/rotary_embedding.py index 98694bf9..a125f64c 100644 --- a/vllm_fl/ops/rotary_embedding.py +++ b/vllm_fl/ops/rotary_embedding.py @@ -3,7 +3,7 @@ from typing import Optional import torch from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding -from vllm_fl.dispatch import call_op +from vllm_fl.dispatch import call_method_op class RotaryEmbeddingFL(RotaryEmbedding): @@ -44,7 +44,7 @@ def forward_oot( cos, sin = self.cos_sin_cache.chunk(2, dim=-1) - q_embed, k_embed = call_op( + q_embed, k_embed = call_method_op( "rotary_embedding", self, query_rot, From 90fb24eab130c0af100ef82b04ffafcbab2218c4 Mon Sep 17 00:00:00 2001 From: xin2an Date: Sat, 28 Feb 2026 19:20:12 +0800 Subject: [PATCH 34/34] Prevent CUDA detection for 'iluvatar' vendor Add check for 'iluvatar' vendor in CUDA availability methods. 
--- vllm_fl/platform.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm_fl/platform.py b/vllm_fl/platform.py index 82ed06c8..299f282e 100644 --- a/vllm_fl/platform.py +++ b/vllm_fl/platform.py @@ -53,10 +53,14 @@ class PlatformFL(Platform): def is_cuda_alike(self) -> bool: """Stateless version of [torch.cuda.is_available][].""" + if self.vendor_name == "iluvatar": + return False return self.device_type == "cuda" def is_cuda(self) -> bool: """Stateless version of [torch.cuda.is_available][].""" + if self.vendor_name == "iluvatar": + return False return self.device_type == "cuda" @property