diff --git a/vllm_fl/dispatch/README.md b/vllm_fl/dispatch/README.md index 6a382901..f252ce38 100644 --- a/vllm_fl/dispatch/README.md +++ b/vllm_fl/dispatch/README.md @@ -227,14 +227,11 @@ The system automatically detects hardware and loads the corresponding configurat | Platform | Config File | Auto-Detection | |----------|-------------|----------------| -| Ascend NPU | `config/ascend.yaml` | `torch.npu.is_available()` | -| NVIDIA GPU | `config/cuda.yaml` | `torch.cuda.is_available()` | +| Ascend NPU | `config/ascend.yaml` | `platform.vendor_name == 'ascend'` | +| NVIDIA GPU | `config/nvidia.yaml` | `platform.vendor_name == 'nvidia'` | +| METAX GPU | `config/metax.yaml` | `platform.vendor_name == 'metax'` | -You can force a specific platform using `VLLM_FL_PLATFORM` environment variable: -```bash -export VLLM_FL_PLATFORM=ascend # Force Ascend config -export VLLM_FL_PLATFORM=cuda # Force CUDA config -``` +Platform detection is automatic based on `current_platform.vendor_name`. ### User-Specified Configuration File (YAML) @@ -314,7 +311,6 @@ Environment variables can override specific items from platform config. If not s |----------|---------|-------------| | `VLLM_FL_PREFER_ENABLED` | `true` | Global switch. 
Set `false` to disable all dispatch features | | `VLLM_FL_CONFIG` | (none) | Path to YAML config file (complete override) | -| `VLLM_FL_PLATFORM` | (auto) | Force platform: `ascend`, `cuda` | #### Backend Selection @@ -388,9 +384,6 @@ export VLLM_FL_PER_OP="rms_norm=vendor|flagos|reference" # Use completely custom config file export VLLM_FL_CONFIG=/path/to/my_config.yaml -# Force specific platform -export VLLM_FL_PLATFORM=ascend - # Enable debug logging export VLLM_FL_LOG_LEVEL=DEBUG ``` diff --git a/vllm_fl/dispatch/__init__.py b/vllm_fl/dispatch/__init__.py index 078f387f..d2916dcb 100644 --- a/vllm_fl/dispatch/__init__.py +++ b/vllm_fl/dispatch/__init__.py @@ -96,6 +96,7 @@ ) from .manager import OpManager, get_default_manager, reset_default_manager from .ops import VLLMFLBackendBase +from .method_dispatch import dispatch_method from .discovery import ( discover_plugins, get_discovered_plugins, @@ -106,6 +107,16 @@ from .logger_manager import get_logger, set_log_level +def call_method_op(op_name: str, instance, *args, **kwargs): + """ + Call an operator as a bound method on *instance*. + + The resolved backend function receives *instance* as ``self``, + allowing it to freely access instance attributes. + """ + return get_default_manager().call_as_method(op_name, instance, *args, **kwargs) + + def call_op(op_name: str, *args, **kwargs): """ Convenience function to call an operator through the default manager. 
@@ -163,6 +174,9 @@ def resolve_op(op_name: str): "reset_default_manager", # Backend base "VLLMFLBackendBase", + # Method dispatch + "dispatch_method", + "call_method_op", # Plugin discovery "discover_plugins", "get_discovered_plugins", diff --git a/vllm_fl/dispatch/backends/flaggems/flaggems.py b/vllm_fl/dispatch/backends/flaggems/flaggems.py index 952892f4..37c54c0d 100644 --- a/vllm_fl/dispatch/backends/flaggems/flaggems.py +++ b/vllm_fl/dispatch/backends/flaggems/flaggems.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Optional, Union +from typing import Optional import torch @@ -42,82 +42,6 @@ def is_available(self) -> bool: # ==================== Operator Implementations ==================== - def silu_and_mul(self, obj, x: torch.Tensor) -> torch.Tensor: - """ - SiLU activation followed by element-wise multiplication. - - Args: - obj: The calling obj (for interface consistency) - x: Input tensor of shape [..., 2*d] - - Returns: - Output tensor of shape [..., d] - """ - from .impl.activation import silu_and_mul_flaggems - - return silu_and_mul_flaggems(obj, x) - - def rms_norm( - self, - obj, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - """ - RMS normalization. - - Args: - obj: The calling obj (e.g., RMSNorm layer) - x: Input tensor - residual: Optional residual tensor - - Returns: - Normalized tensor, or tuple of (normalized, residual) if residual is provided - """ - from .impl.normalization import rms_norm_flaggems - - return rms_norm_flaggems(obj, x, residual) - - def rotary_embedding( - self, - obj, - query: torch.Tensor, - key: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - position_ids: torch.Tensor, - rotary_interleaved: bool = False, - inplace: bool = True, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary position embedding. 
- - Args: - obj: The calling obj (for interface consistency) - query: Query tensor - key: Key tensor - cos: Cosine cache - sin: Sine cache - position_ids: Position indices - rotary_interleaved: Whether to use interleaved rotary - inplace: Whether to modify tensors in-place - - Returns: - Tuple of (embedded_query, embedded_key) - """ - from .impl.rotary import rotary_embedding_flaggems - - return rotary_embedding_flaggems( - obj, - query, - key, - cos, - sin, - position_ids, - rotary_interleaved=rotary_interleaved, - inplace=inplace, - ) - def attention_backend(self, use_mla: bool = False) -> str: """ Get the attention backend class path for FlagGems. diff --git a/vllm_fl/dispatch/backends/flaggems/impl/activation.py b/vllm_fl/dispatch/backends/flaggems/impl/activation.py index 08886fe1..146446de 100644 --- a/vllm_fl/dispatch/backends/flaggems/impl/activation.py +++ b/vllm_fl/dispatch/backends/flaggems/impl/activation.py @@ -9,12 +9,12 @@ import torch -def silu_and_mul_flaggems(obj, x: torch.Tensor) -> torch.Tensor: +def silu_and_mul_flaggems(self, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication using FlagGems. Args: - obj: The calling obj (for interface consistency) + self: The calling instance (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: diff --git a/vllm_fl/dispatch/backends/flaggems/impl/normalization.py b/vllm_fl/dispatch/backends/flaggems/impl/normalization.py index c2e69050..71115ef9 100644 --- a/vllm_fl/dispatch/backends/flaggems/impl/normalization.py +++ b/vllm_fl/dispatch/backends/flaggems/impl/normalization.py @@ -12,7 +12,7 @@ def rms_norm_flaggems( - obj, + self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: @@ -20,7 +20,7 @@ def rms_norm_flaggems( RMS normalization using FlagGems. 
Args: - obj: The calling obj (e.g., RMSNorm layer) + self: The calling instance (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor @@ -29,8 +29,8 @@ def rms_norm_flaggems( """ from flag_gems.modules.normalization import gems_rms_forward - # Get weight and epsilon from obj - weight = obj.weight - epsilon = obj.variance_epsilon + # Get weight and epsilon from self + weight = self.weight + epsilon = self.variance_epsilon return gems_rms_forward(x, residual, weight, epsilon) diff --git a/vllm_fl/dispatch/backends/flaggems/impl/rotary.py b/vllm_fl/dispatch/backends/flaggems/impl/rotary.py index b4cb5c30..35420b1e 100644 --- a/vllm_fl/dispatch/backends/flaggems/impl/rotary.py +++ b/vllm_fl/dispatch/backends/flaggems/impl/rotary.py @@ -10,7 +10,7 @@ def rotary_embedding_flaggems( - obj, + self, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -23,7 +23,7 @@ def rotary_embedding_flaggems( Apply rotary position embedding using FlagGems. Args: - obj: The calling obj (for interface consistency) + self: The calling instance (for interface consistency) query: Query tensor key: Key tensor cos: Cosine cache diff --git a/vllm_fl/dispatch/backends/flaggems/register_ops.py b/vllm_fl/dispatch/backends/flaggems/register_ops.py index bc5595b3..a2e98b26 100644 --- a/vllm_fl/dispatch/backends/flaggems/register_ops.py +++ b/vllm_fl/dispatch/backends/flaggems/register_ops.py @@ -34,6 +34,9 @@ def register_builtins(registry) -> None: registry: Registry to register into """ from .flaggems import FlagGemsBackend + from .impl.activation import silu_and_mul_flaggems + from .impl.normalization import rms_norm_flaggems + from .impl.rotary import rotary_embedding_flaggems backend = FlagGemsBackend() is_avail = backend.is_available @@ -44,7 +47,7 @@ def register_builtins(registry) -> None: op_name="silu_and_mul", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, - fn=_bind_is_available(backend.silu_and_mul, is_avail), + 
fn=_bind_is_available(silu_and_mul_flaggems, is_avail), vendor=None, priority=BackendPriority.DEFAULT, ), @@ -53,7 +56,7 @@ def register_builtins(registry) -> None: op_name="rms_norm", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, - fn=_bind_is_available(backend.rms_norm, is_avail), + fn=_bind_is_available(rms_norm_flaggems, is_avail), vendor=None, priority=BackendPriority.DEFAULT, ), @@ -62,11 +65,11 @@ def register_builtins(registry) -> None: op_name="rotary_embedding", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, - fn=_bind_is_available(backend.rotary_embedding, is_avail), + fn=_bind_is_available(rotary_embedding_flaggems, is_avail), vendor=None, priority=BackendPriority.DEFAULT, ), - # Attention Backend + # Attention Backend (no instance binding needed) OpImpl( op_name="attention_backend", impl_id="default.flagos", diff --git a/vllm_fl/dispatch/backends/reference/impl/activation.py b/vllm_fl/dispatch/backends/reference/impl/activation.py index 87f15061..ce8e9fd9 100644 --- a/vllm_fl/dispatch/backends/reference/impl/activation.py +++ b/vllm_fl/dispatch/backends/reference/impl/activation.py @@ -10,12 +10,12 @@ import torch.nn.functional as F -def silu_and_mul_torch(obj, x: torch.Tensor) -> torch.Tensor: +def silu_and_mul_torch(self, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication using PyTorch. 
Args: - obj: The calling obj (for interface consistency) + self: The calling instance (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: diff --git a/vllm_fl/dispatch/backends/reference/impl/normalization.py b/vllm_fl/dispatch/backends/reference/impl/normalization.py index 828018bc..68e17cac 100644 --- a/vllm_fl/dispatch/backends/reference/impl/normalization.py +++ b/vllm_fl/dispatch/backends/reference/impl/normalization.py @@ -12,7 +12,7 @@ def rms_norm_torch( - obj, + self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: @@ -20,16 +20,16 @@ def rms_norm_torch( RMS normalization using PyTorch. Args: - obj: The calling obj (e.g., RMSNorm layer) + self: The calling instance (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor Returns: Normalized tensor, or tuple of (normalized, residual) if residual is provided """ - # Get weight and epsilon from obj - weight = obj.weight - epsilon = obj.variance_epsilon + # Get weight and epsilon from self + weight = self.weight + epsilon = self.variance_epsilon if residual is not None: x = x + residual diff --git a/vllm_fl/dispatch/backends/reference/impl/rotary.py b/vllm_fl/dispatch/backends/reference/impl/rotary.py index 16125e08..a0c8a557 100644 --- a/vllm_fl/dispatch/backends/reference/impl/rotary.py +++ b/vllm_fl/dispatch/backends/reference/impl/rotary.py @@ -10,7 +10,7 @@ def rotary_embedding_torch( - obj, + self, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -23,7 +23,7 @@ def rotary_embedding_torch( Apply rotary position embedding using PyTorch. 
Args: - obj: The calling obj (for interface consistency) + self: The calling instance (for interface consistency) query: Query tensor [batch, num_heads, seq_len, head_dim] or [seq_len, num_heads, head_dim] key: Key tensor [batch, num_heads, seq_len, head_dim] or [seq_len, num_heads, head_dim] cos: Cosine cache [max_seq_len, rotary_dim] where rotary_dim = head_dim or head_dim // 2 diff --git a/vllm_fl/dispatch/backends/reference/reference.py b/vllm_fl/dispatch/backends/reference/reference.py index 653e905c..966b8638 100644 --- a/vllm_fl/dispatch/backends/reference/reference.py +++ b/vllm_fl/dispatch/backends/reference/reference.py @@ -10,9 +10,7 @@ from __future__ import annotations -from typing import Optional, Union - -import torch +from typing import Optional from vllm_fl.dispatch.backends.base import Backend @@ -44,82 +42,6 @@ def is_available(self) -> bool: # ==================== Operator Implementations ==================== - def silu_and_mul(self, obj, x: torch.Tensor) -> torch.Tensor: - """ - SiLU activation followed by element-wise multiplication. - - Args: - obj: The calling obj (for interface consistency) - x: Input tensor of shape [..., 2*d] - - Returns: - Output tensor of shape [..., d] - """ - from .impl.activation import silu_and_mul_torch - - return silu_and_mul_torch(obj, x) - - def rms_norm( - self, - obj, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - """ - RMS normalization. 
- - Args: - obj: The calling obj (e.g., RMSNorm layer) - x: Input tensor - residual: Optional residual tensor - - Returns: - Normalized tensor, or tuple of (normalized, residual) if residual is provided - """ - from .impl.normalization import rms_norm_torch - - return rms_norm_torch(obj, x, residual) - - def rotary_embedding( - self, - obj, - query: torch.Tensor, - key: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - position_ids: torch.Tensor, - rotary_interleaved: bool = False, - inplace: bool = True, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary position embedding. - - Args: - obj: The calling obj (for interface consistency) - query: Query tensor - key: Key tensor - cos: Cosine cache - sin: Sine cache - position_ids: Position indices - rotary_interleaved: Whether to use interleaved rotary - inplace: Whether to modify tensors in-place (ignored in reference impl) - - Returns: - Tuple of (embedded_query, embedded_key) - """ - from .impl.rotary import rotary_embedding_torch - - return rotary_embedding_torch( - obj, - query, - key, - cos, - sin, - position_ids, - rotary_interleaved=rotary_interleaved, - inplace=inplace, - ) - def attention_backend(self, use_mla: bool = False) -> str: """ Get the attention backend class path for reference (vLLM native). 
diff --git a/vllm_fl/dispatch/backends/reference/register_ops.py b/vllm_fl/dispatch/backends/reference/register_ops.py index 522474c3..fa017402 100644 --- a/vllm_fl/dispatch/backends/reference/register_ops.py +++ b/vllm_fl/dispatch/backends/reference/register_ops.py @@ -32,6 +32,9 @@ def register_builtins(registry) -> None: registry: Registry to register into """ from .reference import ReferenceBackend + from .impl.activation import silu_and_mul_torch + from .impl.normalization import rms_norm_torch + from .impl.rotary import rotary_embedding_torch backend = ReferenceBackend() is_avail = backend.is_available @@ -42,7 +45,7 @@ def register_builtins(registry) -> None: op_name="silu_and_mul", impl_id="reference.torch", kind=BackendImplKind.REFERENCE, - fn=_bind_is_available(backend.silu_and_mul, is_avail), + fn=_bind_is_available(silu_and_mul_torch, is_avail), vendor=None, priority=BackendPriority.REFERENCE, ), @@ -51,7 +54,7 @@ def register_builtins(registry) -> None: op_name="rms_norm", impl_id="reference.torch", kind=BackendImplKind.REFERENCE, - fn=_bind_is_available(backend.rms_norm, is_avail), + fn=_bind_is_available(rms_norm_torch, is_avail), vendor=None, priority=BackendPriority.REFERENCE, ), @@ -60,11 +63,11 @@ def register_builtins(registry) -> None: op_name="rotary_embedding", impl_id="reference.torch", kind=BackendImplKind.REFERENCE, - fn=_bind_is_available(backend.rotary_embedding, is_avail), + fn=_bind_is_available(rotary_embedding_torch, is_avail), vendor=None, priority=BackendPriority.REFERENCE, ), - # Attention Backend + # Attention Backend (no instance binding needed) OpImpl( op_name="attention_backend", impl_id="reference.torch", diff --git a/vllm_fl/dispatch/backends/vendor/ascend/ascend.py b/vllm_fl/dispatch/backends/vendor/ascend/ascend.py index 6646f589..1c407483 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/ascend.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/ascend.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing 
import Optional, Union +from typing import Optional import torch @@ -51,82 +51,6 @@ def is_available(self) -> bool: # ==================== Operator Implementations ==================== - def silu_and_mul(self, obj, x: torch.Tensor) -> torch.Tensor: - """ - SiLU activation followed by element-wise multiplication. - - Args: - obj: The calling obj (for interface consistency) - x: Input tensor of shape [..., 2*d] - - Returns: - Output tensor of shape [..., d] - """ - from .impl.activation import silu_and_mul_ascend - - return silu_and_mul_ascend(obj, x) - - def rms_norm( - self, - obj, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - """ - RMS normalization. - - Args: - obj: The calling obj (e.g., RMSNorm layer) - x: Input tensor - residual: Optional residual tensor - - Returns: - Normalized tensor, or tuple of (normalized, residual) if residual is provided - """ - from .impl.normalization import rms_norm_ascend - - return rms_norm_ascend(obj, x, residual) - - def rotary_embedding( - self, - obj, - query: torch.Tensor, - key: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - position_ids: torch.Tensor, - rotary_interleaved: bool = False, - inplace: bool = True, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary position embedding. 
- - Args: - obj: The calling obj (for interface consistency) - query: Query tensor - key: Key tensor - cos: Cosine cache - sin: Sine cache - position_ids: Position indices - rotary_interleaved: Whether to use interleaved rotary - inplace: Whether to modify tensors in-place - - Returns: - Tuple of (embedded_query, embedded_key) - """ - from .impl.rotary import rotary_embedding_ascend - - return rotary_embedding_ascend( - obj, - query, - key, - cos, - sin, - position_ids, - rotary_interleaved=rotary_interleaved, - inplace=inplace, - ) - def attention_backend(self, use_mla: bool = False) -> str: """ Get the attention backend class path for Ascend NPU. diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py index 72ad09a6..38a2fda1 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/activation.py @@ -9,12 +9,12 @@ import torch -def silu_and_mul_ascend(obj, x: torch.Tensor) -> torch.Tensor: +def silu_and_mul_ascend(self, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication using Ascend NPU. Args: - obj: The calling obj (for interface consistency) + self: The calling instance (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py index 8bcc2672..7c277205 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py @@ -12,7 +12,7 @@ def rms_norm_ascend( - obj, + self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: @@ -20,7 +20,7 @@ def rms_norm_ascend( RMS normalization using Ascend NPU. 
Args: - obj: The calling obj (e.g., RMSNorm layer) + self: The calling instance (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor to add before normalization @@ -29,9 +29,9 @@ def rms_norm_ascend( """ import torch_npu - # Get weight and epsilon from obj - weight = obj.weight - epsilon = obj.variance_epsilon + # Get weight and epsilon from self + weight = self.weight + epsilon = self.variance_epsilon if residual is not None: x, _, residual = torch_npu.npu_add_rms_norm(x, residual, weight, epsilon) diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py index 6fa6e3f9..aa9ae581 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/rotary.py @@ -11,7 +11,7 @@ def rotary_embedding_ascend( - obj, + self, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -24,7 +24,7 @@ def rotary_embedding_ascend( Apply rotary position embedding using Ascend NPU. 
Args: - obj: The calling obj (for interface consistency) + self: The calling instance (for interface consistency) query: Query tensor [num_tokens, num_heads, rotary_dim] key: Key tensor [num_tokens, num_kv_heads, rotary_dim] cos: Cosine cache [max_seq_len, rotary_dim // 2] diff --git a/vllm_fl/dispatch/backends/vendor/ascend/register_ops.py b/vllm_fl/dispatch/backends/vendor/ascend/register_ops.py index 3834a215..f596bd52 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/register_ops.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/register_ops.py @@ -32,6 +32,9 @@ def register_builtins(registry) -> None: registry: Registry to register into """ from .ascend import AscendBackend + from .impl.activation import silu_and_mul_ascend + from .impl.normalization import rms_norm_ascend + from .impl.rotary import rotary_embedding_ascend backend = AscendBackend() is_avail = backend.is_available @@ -42,7 +45,7 @@ def register_builtins(registry) -> None: op_name="silu_and_mul", impl_id="vendor.ascend", kind=BackendImplKind.VENDOR, - fn=_bind_is_available(backend.silu_and_mul, is_avail), + fn=_bind_is_available(silu_and_mul_ascend, is_avail), vendor="ascend", priority=BackendPriority.VENDOR, ), @@ -51,7 +54,7 @@ def register_builtins(registry) -> None: op_name="rms_norm", impl_id="vendor.ascend", kind=BackendImplKind.VENDOR, - fn=_bind_is_available(backend.rms_norm, is_avail), + fn=_bind_is_available(rms_norm_ascend, is_avail), vendor="ascend", priority=BackendPriority.VENDOR, ), @@ -60,11 +63,11 @@ def register_builtins(registry) -> None: op_name="rotary_embedding", impl_id="vendor.ascend", kind=BackendImplKind.VENDOR, - fn=_bind_is_available(backend.rotary_embedding, is_avail), + fn=_bind_is_available(rotary_embedding_ascend, is_avail), vendor="ascend", priority=BackendPriority.VENDOR, ), - # Attention Backend + # Attention Backend (no instance binding needed) OpImpl( op_name="attention_backend", impl_id="vendor.ascend", diff --git 
a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py index 66f5af06..d628140d 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Optional, Union +from typing import Optional import torch @@ -59,84 +59,6 @@ def is_available(self) -> bool: # ==================== Operator Implementations ==================== - def silu_and_mul(self, obj, x: torch.Tensor) -> torch.Tensor: - """ - SiLU activation followed by element-wise multiplication. - - Uses vLLM's native CUDA implementation. - - Args: - obj: The calling obj (for interface consistency) - x: Input tensor of shape [..., 2*d] - - Returns: - Output tensor of shape [..., d] - """ - from .impl.activation import silu_and_mul_cuda - - return silu_and_mul_cuda(obj, x) - - def rms_norm( - self, - obj, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - """ - RMS normalization using vLLM's CUDA implementation. - - Args: - obj: The calling obj (e.g., RMSNorm layer) - x: Input tensor - residual: Optional residual tensor - - Returns: - Normalized tensor, or tuple of (normalized, residual) if residual is provided - """ - from .impl.normalization import rms_norm_cuda - - return rms_norm_cuda(obj, x, residual) - - def rotary_embedding( - self, - obj, - query: torch.Tensor, - key: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - position_ids: torch.Tensor, - rotary_interleaved: bool = False, - inplace: bool = True, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary position embedding using vLLM's CUDA implementation. 
- - Args: - obj: The calling obj (for interface consistency) - query: Query tensor - key: Key tensor - cos: Cosine cache - sin: Sine cache - position_ids: Position indices - rotary_interleaved: Whether to use interleaved rotary - inplace: Whether to modify tensors in-place - - Returns: - Tuple of (embedded_query, embedded_key) - """ - from .impl.rotary import rotary_embedding_cuda - - return rotary_embedding_cuda( - obj, - query, - key, - cos, - sin, - position_ids, - rotary_interleaved=rotary_interleaved, - inplace=inplace, - ) - def attention_backend(self, use_mla: bool = False) -> str: """ Get the attention backend class path for CUDA. diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py index 4f545cc7..0ab49aed 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py @@ -9,14 +9,14 @@ import torch -def silu_and_mul_cuda(obj, x: torch.Tensor) -> torch.Tensor: +def silu_and_mul_cuda(self, x: torch.Tensor) -> torch.Tensor: """ SiLU activation followed by element-wise multiplication using CUDA. Uses vLLM's optimized CUDA kernel when available. Args: - obj: The calling obj (for interface consistency) + self: The calling instance (for interface consistency) x: Input tensor of shape [..., 2*d] Returns: diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py index c43a7cc8..fe2d36de 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/normalization.py @@ -12,7 +12,7 @@ def rms_norm_cuda( - obj, + self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: @@ -22,7 +22,7 @@ def rms_norm_cuda( Uses vLLM's optimized CUDA kernel when available. 
Args: - obj: The calling obj (e.g., RMSNorm layer) + self: The calling instance (e.g., RMSNorm layer) x: Input tensor residual: Optional residual tensor to add before normalization @@ -32,9 +32,9 @@ def rms_norm_cuda( from vllm._custom_ops import rms_norm as vllm_rms_norm from vllm._custom_ops import fused_add_rms_norm as vllm_fused_add_rms_norm - # Get weight and epsilon from obj - weight = obj.weight - epsilon = obj.variance_epsilon + # Get weight and epsilon from self + weight = self.weight + epsilon = self.variance_epsilon if residual is not None: vllm_fused_add_rms_norm(x, residual, weight, epsilon) diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py index 73db40aa..fe46e9c2 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/rotary.py @@ -10,7 +10,7 @@ def rotary_embedding_cuda( - obj, + self, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, @@ -25,7 +25,7 @@ def rotary_embedding_cuda( Uses vLLM's optimized CUDA kernel when available. 
Args: - obj: The calling obj (for interface consistency) + self: The calling instance (for interface consistency) query: Query tensor key: Key tensor cos: Cosine cache diff --git a/vllm_fl/dispatch/backends/vendor/cuda/register_ops.py b/vllm_fl/dispatch/backends/vendor/cuda/register_ops.py index d0241715..41c8e8c2 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/register_ops.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/register_ops.py @@ -32,6 +32,9 @@ def register_builtins(registry) -> None: registry: Registry to register into """ from .cuda import CudaBackend + from .impl.activation import silu_and_mul_cuda + from .impl.normalization import rms_norm_cuda + from .impl.rotary import rotary_embedding_cuda backend = CudaBackend() is_avail = backend.is_available @@ -42,7 +45,7 @@ def register_builtins(registry) -> None: op_name="silu_and_mul", impl_id="vendor.cuda", kind=BackendImplKind.VENDOR, - fn=_bind_is_available(backend.silu_and_mul, is_avail), + fn=_bind_is_available(silu_and_mul_cuda, is_avail), vendor="cuda", priority=BackendPriority.VENDOR, ), @@ -51,7 +54,7 @@ def register_builtins(registry) -> None: op_name="rms_norm", impl_id="vendor.cuda", kind=BackendImplKind.VENDOR, - fn=_bind_is_available(backend.rms_norm, is_avail), + fn=_bind_is_available(rms_norm_cuda, is_avail), vendor="cuda", priority=BackendPriority.VENDOR, ), @@ -60,11 +63,11 @@ def register_builtins(registry) -> None: op_name="rotary_embedding", impl_id="vendor.cuda", kind=BackendImplKind.VENDOR, - fn=_bind_is_available(backend.rotary_embedding, is_avail), + fn=_bind_is_available(rotary_embedding_cuda, is_avail), vendor="cuda", priority=BackendPriority.VENDOR, ), - # Attention Backend + # Attention Backend (no instance binding needed) OpImpl( op_name="attention_backend", impl_id="vendor.cuda", diff --git a/vllm_fl/dispatch/config/cuda.yaml b/vllm_fl/dispatch/config/nvidia.yaml similarity index 91% rename from vllm_fl/dispatch/config/cuda.yaml rename to vllm_fl/dispatch/config/nvidia.yaml 
index 3c8676ff..2a192b14 100644 --- a/vllm_fl/dispatch/config/cuda.yaml +++ b/vllm_fl/dispatch/config/nvidia.yaml @@ -1,5 +1,5 @@ -# vLLM-FL Dispatch Configuration for CUDA -# Auto-loaded when running on NVIDIA GPU hardware +# vLLM-FL Dispatch Configuration for NVIDIA GPU +# Auto-loaded when running on NVIDIA GPU hardware (vendor_name: nvidia) # Preferred default backend type: flaggems, vendor, reference prefer: flagos diff --git a/vllm_fl/dispatch/config/utils.py b/vllm_fl/dispatch/config/utils.py index d6597c0c..999a4324 100644 --- a/vllm_fl/dispatch/config/utils.py +++ b/vllm_fl/dispatch/config/utils.py @@ -41,26 +41,21 @@ def get_platform_name() -> str: """ - Detect the current hardware platform. + Detect the current hardware platform using platform vendor_name. + + This function uses current_platform.vendor_name to accurately distinguish + between different hardware vendors (NVIDIA, METAX, Ascend, etc.). Returns: - Platform name string: 'ascend', 'iluvatar', 'cuda', or 'unknown' + Platform name string based on vendor_name: 'nvidia', 'metax', 'ascend', etc. 
def call_as_method(self, op_name: str, instance, *args, **kwargs):
    """
    Resolve an operator and invoke it as a bound method of *instance*.

    The selected backend function is bound to *instance* with
    ``types.MethodType`` before it is called, so the backend receives
    *instance* as its first positional argument (``self``).

    Two modes are selected by the ``VLLM_FL_STRICT`` environment variable:

    * ``"0"``: only the implementation returned by ``self.resolve`` is
      called; any exception it raises propagates unchanged.
    * anything else (default ``"1"``): candidates from
      ``self.resolve_candidates`` are tried in order, skipping those already
      recorded in ``self._failed_impls``. A candidate that raises is recorded
      as failed and the next one is tried.

    Args:
        op_name: Name of the operator to dispatch.
        instance: Object the resolved function is bound to.
        *args: Positional arguments forwarded to the bound function.
        **kwargs: Keyword arguments forwarded to the bound function.

    Returns:
        Whatever the selected implementation returns.

    Raises:
        RuntimeError: In fallback mode, when every candidate has failed
            previously or every remaining candidate fails during this call.
            The last failure is chained as ``__cause__``.
    """
    enable_fallback = os.getenv("VLLM_FL_STRICT", "1") != "0"

    if not enable_fallback:
        fn = self.resolve(op_name)

        impl_id = self.get_selected_impl_id(op_name)
        last_impl_id = self._called_ops.get(op_name)
        # Cheap unlocked check keeps the steady-state path lock-free.
        if last_impl_id != impl_id:
            self._log_impl_selection(op_name, impl_id, last_impl_id)

        bound = pytypes.MethodType(fn, instance)
        return bound(*args, **kwargs)

    # Fallback mode: try candidates in priority order.
    candidates = self.resolve_candidates(op_name)
    failed_impl_ids = self._failed_impls.get(op_name, set())
    available_candidates = [
        impl for impl in candidates if impl.impl_id not in failed_impl_ids
    ]

    if not available_candidates:
        raise RuntimeError(
            f"All implementations for op='{op_name}' have failed previously. "
            f"Failed impl_ids: {failed_impl_ids}"
        )

    last_error = None
    for idx, impl in enumerate(available_candidates):
        try:
            if idx == 0:
                last_impl_id = self._called_ops.get(op_name)
                if last_impl_id != impl.impl_id:
                    self._log_impl_selection(
                        op_name, impl.impl_id, last_impl_id, impl
                    )
            else:
                logger.info(
                    f"Op '{op_name}' fallback to '{impl.impl_id}' "
                    f"(kind={impl.kind.value}, vendor={impl.vendor})"
                )

            bound = pytypes.MethodType(impl.fn, instance)
            result = bound(*args, **kwargs)

            # A fallback candidate is recorded as the active implementation
            # only after it has actually succeeded.
            if idx > 0:
                with self._lock:
                    self._called_ops[op_name] = impl.impl_id

            return result

        except Exception as e:
            last_error = e
            with self._lock:
                self._failed_impls.setdefault(op_name, set()).add(impl.impl_id)

            if idx < len(available_candidates) - 1:
                logger.warning(
                    f"Implementation '{impl.impl_id}' failed for op '{op_name}': {e}"
                )
            else:
                logger.error(
                    f"Last implementation '{impl.impl_id}' failed for op '{op_name}': {e}"
                )

    raise RuntimeError(
        f"All {len(available_candidates)} implementation(s) failed for op='{op_name}'. "
        f"Last error: {last_error}"
    ) from last_error

def _log_impl_selection(self, op_name, impl_id, last_impl_id, impl=None):
    """
    Log that *impl_id* is now used for *op_name* and record it.

    Args:
        op_name: Operator whose selection changed.
        impl_id: Identifier of the implementation now in use.
        last_impl_id: Value of ``self._called_ops.get(op_name)`` observed by
            the caller before taking the lock; ``None`` selects the
            "using" message, anything else the "switched from" message.
        impl: The implementation object, when the caller already has it.
            When ``None`` it is looked up in a registry snapshot; if it is
            not found there, nothing is logged but the selection is still
            recorded.
    """
    with self._lock:
        # Re-check under the lock: another thread may already have
        # recorded (and logged) this selection.
        if self._called_ops.get(op_name) == impl_id:
            return

        if impl is None:
            snap = self._registry.snapshot()
            impl = next(
                (
                    candidate
                    for candidate in snap.impls_by_op.get(op_name, [])
                    if candidate.impl_id == impl_id
                ),
                None,
            )

        if impl is not None:
            if last_impl_id is None:
                logger.info(
                    f"Op '{op_name}' using '{impl_id}' "
                    f"(kind={impl.kind.value}, vendor={impl.vendor})"
                )
            else:
                logger.info(
                    f"Op '{op_name}' switched from '{last_impl_id}' to '{impl_id}' "
                    f"(kind={impl.kind.value}, vendor={impl.vendor})"
                )

        self._called_ops[op_name] = impl_id
class dispatch_method:
    """
    Descriptor that dispatches to the resolved backend implementation.

    Accessing the attribute on an instance returns a callable that forwards
    to ``get_default_manager().call_as_method(op_name, instance, ...)``.
    Accessing it on the class returns the descriptor itself.

    Usage::

        class RMSNormFL(RMSNorm):
            forward_oot = dispatch_method("rms_norm")
    """

    def __init__(self, op_name: str) -> None:
        self.op_name = op_name
        # __set_name__ only runs when the descriptor is part of a class body.
        # Default to None so the attribute always exists, including when the
        # descriptor is attached with setattr after class creation.
        self.attr_name = None

    def __set_name__(self, owner, name):
        # Remember the attribute name this descriptor was assigned to.
        self.attr_name = name

    def __get__(self, obj, objtype=None):
        if obj is None:
            # Class-level access: expose the descriptor for introspection.
            return self

        def dispatched(*args, **kwargs):
            # Imported lazily, at call time rather than at module import.
            from vllm_fl.dispatch import get_default_manager
            return get_default_manager().call_as_method(
                self.op_name, obj, *args, **kwargs
            )

        # Give the wrapper a recognisable name in tracebacks and profilers
        # instead of the generic "dispatched".
        dispatched.__name__ = self.attr_name or self.op_name
        return dispatched

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(op_name={self.op_name!r}, "
            f"attr_name={self.attr_name!r})"
        )
-from typing import Optional, Union +from typing import Optional import torch from vllm.model_executor.layers.layernorm import RMSNorm -from vllm_fl.dispatch import call_op +from vllm_fl.dispatch.method_dispatch import dispatch_method class RMSNormFL(RMSNorm): @@ -17,12 +17,7 @@ def __init__( ) -> None: super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype) - def forward_oot( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - return call_op("rms_norm", self, x, residual) + forward_oot = dispatch_method("rms_norm") __all__ = ["RMSNormFL"] diff --git a/vllm_fl/ops/rotary_embedding.py b/vllm_fl/ops/rotary_embedding.py index 98694bf9..a125f64c 100644 --- a/vllm_fl/ops/rotary_embedding.py +++ b/vllm_fl/ops/rotary_embedding.py @@ -3,7 +3,7 @@ from typing import Optional import torch from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding -from vllm_fl.dispatch import call_op +from vllm_fl.dispatch import call_method_op class RotaryEmbeddingFL(RotaryEmbedding): @@ -44,7 +44,7 @@ def forward_oot( cos, sin = self.cos_sin_cache.chunk(2, dim=-1) - q_embed, k_embed = call_op( + q_embed, k_embed = call_method_op( "rotary_embedding", self, query_rot,