Decouple op implementations from Backend classes and add descriptor-based method dispatch#46
xin2an wants to merge 50 commits into flagos-ai:main
Conversation
…nfiguration file is now selected based on the current_platform.
…hod descriptor for instance-aware dispatch
Pull request overview
This pull request refactors the operator dispatch system by decoupling operator implementations from Backend classes and introducing a descriptor-based method dispatch pattern. The change removes approximately 300 lines of boilerplate code across four backend classes (FlagGems, Reference, CUDA, Ascend) by registering standalone implementation functions directly instead of bound backend methods. Operator classes can now declare dispatch in a single line using a dispatch_method descriptor, mirroring vLLM's native forward_cuda/forward_xpu pattern.
Changes:
- Introduced `dispatch_method` descriptor and `call_as_method()` manager method for automatic method-level dispatch with instance binding
- Removed wrapper methods (`silu_and_mul`, `rms_norm`, `rotary_embedding`) from all Backend classes and updated registration to use standalone impl functions
- Unified impl function signatures to use `self` instead of `obj` for consistency with `MethodType` binding semantics
- Updated platform detection to use `current_platform.vendor_name` instead of torch availability checks
- Simplified operator classes to use single-line descriptor declarations where appropriate
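A minimal, self-contained sketch of the descriptor-based dispatch pattern described above. The registry dict, `register` helper, and simplified `dispatch_method` here are illustrative stand-ins, not the actual vllm_fl API:

```python
import types

# Hypothetical stand-in for the real dispatch registry: maps op names
# to standalone impl functions whose first parameter is `self`.
_REGISTRY = {}

def register(op_name):
    def deco(fn):
        _REGISTRY[op_name] = fn
        return fn
    return deco

class dispatch_method:
    """Simplified descriptor: resolves the impl at access time and
    binds it to the owning instance via types.MethodType."""

    def __init__(self, op_name):
        self.op_name = op_name

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        fn = _REGISTRY[self.op_name]      # resolve the backend impl
        return types.MethodType(fn, obj)  # bind it so the impl sees `self`

@register("rms_norm")
def rms_norm_impl(self, x):
    # `self` is the operator instance, supplied by the MethodType binding.
    return [v / self.scale for v in x]

class RMSNormFL:
    # Single-line declaration, mirroring forward_cuda/forward_xpu style.
    forward_oot = dispatch_method("rms_norm")

    def __init__(self, scale):
        self.scale = scale

layer = RMSNormFL(scale=2.0)
result = layer.forward_oot([2.0, 4.0])
```

The caller never passes `self` explicitly; the descriptor supplies the instance when the attribute is accessed.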
Reviewed changes
Copilot reviewed 19 out of 29 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| vllm_fl/dispatch/method_dispatch.py | New descriptor class for automatic method-level dispatch |
| vllm_fl/dispatch/manager.py | Added call_as_method() for invoking operators as bound methods with fallback support |
| vllm_fl/ops/activation.py | Simplified SiluAndMulFL to use dispatch_method descriptor |
| vllm_fl/ops/layernorm.py | Simplified RMSNormFL to use dispatch_method descriptor |
| vllm_fl/ops/rotary_embedding.py | Updated to use call_method_op() for manual dispatch with preprocessing logic |
| vllm_fl/dispatch/__init__.py | Added exports for dispatch_method and call_method_op |
| vllm_fl/dispatch/config/utils.py | Updated platform detection to use current_platform.vendor_name |
| vllm_fl/dispatch/config/nvidia.yaml | Updated comments to reflect vendor_name-based detection |
| vllm_fl/dispatch/README.md | Updated documentation to remove obsolete VLLM_FL_PLATFORM environment variable |
| vllm_fl/dispatch/backends/flaggems/flaggems.py | Removed operator wrapper methods (silu_and_mul, rms_norm, rotary_embedding) |
| vllm_fl/dispatch/backends/flaggems/register_ops.py | Updated to register standalone impl functions directly |
| vllm_fl/dispatch/backends/flaggems/impl/*.py | Changed first parameter from obj to self in all impl functions |
| vllm_fl/dispatch/backends/reference/reference.py | Removed operator wrapper methods |
| vllm_fl/dispatch/backends/reference/register_ops.py | Updated to register standalone impl functions directly |
| vllm_fl/dispatch/backends/reference/impl/*.py | Changed first parameter from obj to self in all impl functions |
| vllm_fl/dispatch/backends/vendor/cuda/cuda.py | Removed operator wrapper methods |
| vllm_fl/dispatch/backends/vendor/cuda/register_ops.py | Updated to register standalone impl functions directly |
| vllm_fl/dispatch/backends/vendor/cuda/impl/*.py | Changed first parameter from obj to self in all impl functions |
| vllm_fl/dispatch/backends/vendor/ascend/ascend.py | Removed operator wrapper methods |
| vllm_fl/dispatch/backends/vendor/ascend/register_ops.py | Updated to register standalone impl functions directly |
| vllm_fl/dispatch/backends/vendor/ascend/impl/*.py | Changed first parameter from obj to self in all impl functions |
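The `obj` → `self` renames listed in the impl rows above follow from how `types.MethodType` works: binding a standalone function to an instance makes that instance the function's first argument, so naming that parameter `self` matches the binding semantics. A minimal illustration (the class and impl function here are simplified stand-ins):

```python
import types

class RMSNormFL:
    def __init__(self, eps):
        self.eps = eps

# Standalone impl function: the first parameter is named `self` because
# MethodType binding will supply the operator instance there.
def rms_norm_impl(self, x):
    return x + self.eps

layer = RMSNormFL(eps=0.5)
bound = types.MethodType(rms_norm_impl, layer)  # bind fn to the instance
out = bound(1.0)  # `layer` arrives as self; only x is passed explicitly
```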
```python
def call_as_method(self, op_name: str, instance, *args, **kwargs):
    """
    Resolve and call an operator as a bound method on *instance*.

    Behaves identically to :meth:`call` (fallback, logging, caching)
    except that the resolved function is bound to *instance* via
    ``types.MethodType`` before invocation, so the backend function
    receives *instance* as ``self``.
    """
    enable_fallback = os.getenv("VLLM_FL_STRICT", "1") != "0"

    if not enable_fallback:
        fn = self.resolve(op_name)

        impl_id = self.get_selected_impl_id(op_name)
        last_impl_id = self._called_ops.get(op_name)

        if last_impl_id != impl_id:
            with self._lock:
                if self._called_ops.get(op_name) != impl_id:
                    snap = self._registry.snapshot()
                    for impl in snap.impls_by_op.get(op_name, []):
                        if impl.impl_id == impl_id:
                            if last_impl_id is None:
                                logger.info(
                                    f"Op '{op_name}' using '{impl_id}' "
                                    f"(kind={impl.kind.value}, vendor={impl.vendor})"
                                )
                            else:
                                logger.info(
                                    f"Op '{op_name}' switched from '{last_impl_id}' to '{impl_id}' "
                                    f"(kind={impl.kind.value}, vendor={impl.vendor})"
                                )
                            break
                    self._called_ops[op_name] = impl_id

        bound = pytypes.MethodType(fn, instance)
        return bound(*args, **kwargs)

    # Fallback mode: try candidates in priority order
    candidates = self.resolve_candidates(op_name)
    last_error = None

    failed_impl_ids = self._failed_impls.get(op_name, set())

    available_candidates = [
        impl for impl in candidates if impl.impl_id not in failed_impl_ids
    ]

    if not available_candidates:
        raise RuntimeError(
            f"All implementations for op='{op_name}' have failed previously. "
            f"Failed impl_ids: {failed_impl_ids}"
        )

    for idx, impl in enumerate(available_candidates):
        try:
            if idx == 0:
                last_impl_id = self._called_ops.get(op_name)
                if last_impl_id != impl.impl_id:
                    with self._lock:
                        if self._called_ops.get(op_name) != impl.impl_id:
                            if last_impl_id is None:
                                logger.info(
                                    f"Op '{op_name}' using '{impl.impl_id}' "
                                    f"(kind={impl.kind.value}, vendor={impl.vendor})"
                                )
                            else:
                                logger.info(
                                    f"Op '{op_name}' switched from '{last_impl_id}' to '{impl.impl_id}' "
                                    f"(kind={impl.kind.value}, vendor={impl.vendor})"
                                )
                            self._called_ops[op_name] = impl.impl_id
            else:
                logger.info(
                    f"Op '{op_name}' fallback to '{impl.impl_id}' "
                    f"(kind={impl.kind.value}, vendor={impl.vendor})"
                )

            bound = pytypes.MethodType(impl.fn, instance)
            result = bound(*args, **kwargs)

            if idx > 0:
                with self._lock:
                    self._called_ops[op_name] = impl.impl_id

            return result

        except Exception as e:
            last_error = e
            with self._lock:
                if op_name not in self._failed_impls:
                    self._failed_impls[op_name] = set()
                self._failed_impls[op_name].add(impl.impl_id)

            if idx < len(available_candidates) - 1:
                logger.warning(
                    f"Implementation '{impl.impl_id}' failed for op '{op_name}': {e}"
                )
            else:
                logger.error(
                    f"Last implementation '{impl.impl_id}' failed for op '{op_name}': {e}"
                )

    raise RuntimeError(
        f"All {len(available_candidates)} implementation(s) failed for op='{op_name}'. "
        f"Last error: {last_error}"
    ) from last_error
```
The call_as_method method contains significant code duplication with the call method. The only difference is binding the function using pytypes.MethodType before invocation (lines 561, 604). This duplicates approximately 108 lines of complex fallback, logging, and error handling logic.
Consider refactoring to avoid this duplication. One approach would be to extract a common method that handles the fallback logic and accepts a callback for function invocation, or add a parameter to the existing call method to enable method binding.
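One way to realize the suggested refactoring, sketched on a toy manager (not the real OpManager; logging, locking, and failure caching are omitted): a shared `_call_with` driver owns the fallback loop and takes an invocation strategy, so `call` and `call_as_method` differ only in how the resolved function is invoked.

```python
import types

class OpManagerSketch:
    """Illustrative sketch: one internal driver, two invocation strategies."""

    def __init__(self, registry):
        # registry: op_name -> list of candidate functions, priority order
        self._registry = registry

    def _call_with(self, op_name, invoke, *args, **kwargs):
        # Shared fallback loop: try each candidate until one succeeds.
        last_error = None
        for fn in self._registry[op_name]:
            try:
                return invoke(fn, *args, **kwargs)
            except Exception as e:
                last_error = e
        raise RuntimeError(
            f"All implementations failed for op='{op_name}'"
        ) from last_error

    def call(self, op_name, *args, **kwargs):
        # Direct invocation strategy.
        return self._call_with(
            op_name, lambda fn, *a, **kw: fn(*a, **kw), *args, **kwargs
        )

    def call_as_method(self, op_name, instance, *args, **kwargs):
        # Bound invocation strategy: only this lambda differs from `call`.
        return self._call_with(
            op_name,
            lambda fn, *a, **kw: types.MethodType(fn, instance)(*a, **kw),
            *args,
            **kwargs,
        )

# Demo: first candidate fails, fallback succeeds.
def broken(x):
    raise ValueError("boom")

def working(x):
    return x + 1

direct = OpManagerSketch({"inc": [broken, working]}).call("inc", 1)

# Demo: bound invocation, the impl receives the instance as `self`.
class Owner:
    base = 10

def add_impl(self, x):
    return self.base + x

bound_result = OpManagerSketch({"add": [add_impl]}).call_as_method(
    "add", Owner(), 5
)
```

With this shape, future fixes to the fallback logic land in one place instead of two.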
```python
if platform_override:
    return platform_override
if hasattr(current_platform, 'vendor_name'):
    return current_platform.vendor_name
```
The platform detection now relies on current_platform.vendor_name without case normalization. If vLLM's vendor_name returns values with different casing (e.g., "NVIDIA" instead of "nvidia"), the config file lookup at line 77 (_CONFIG_DIR / f"{platform}.yaml") will fail since the config file is named nvidia.yaml (lowercase).
Consider adding case normalization to ensure robust platform detection, for example:
```python
if hasattr(current_platform, 'vendor_name'):
    return current_platform.vendor_name.lower()
```

Suggested change:

```diff
-    return current_platform.vendor_name
+    vendor = current_platform.vendor_name
+    # Normalize to lowercase to match config filenames (e.g., nvidia.yaml).
+    return vendor.lower() if isinstance(vendor, str) else str(vendor).lower()
```
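A runnable illustration of the failure mode the reviewer describes. The config directory and file here are temporary stand-ins created for the demo, not the real `_CONFIG_DIR`:

```python
import tempfile
from pathlib import Path

# Stand-in for _CONFIG_DIR; the real code resolves
# _CONFIG_DIR / f"{platform}.yaml" with lowercase filenames on disk.
config_dir = Path(tempfile.mkdtemp())
(config_dir / "nvidia.yaml").write_text("# vendor config\n")

def config_path(vendor_name: str) -> Path:
    # Normalize before the lookup so "NVIDIA" and "nvidia" both resolve
    # to the lowercase on-disk filename.
    return config_dir / f"{vendor_name.lower()}.yaml"

resolved = config_path("NVIDIA")
```

Without the `.lower()`, a vendor string like `"NVIDIA"` would produce a lookup for `NVIDIA.yaml`, which does not exist on case-sensitive filesystems.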
Add check for 'iluvatar' vendor in CUDA availability methods.
…lugin-FL into dispatch_add_v0130
Pull request overview
Copilot reviewed 19 out of 29 changed files in this pull request and generated 6 comments.
```diff
         dtype: Optional[torch.dtype] = None,
     ) -> None:
         super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)

-    def forward_oot(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        return call_op("rms_norm", self, x, residual)
+    forward_oot = dispatch_method("rms_norm")
```
Unit tests currently patch vllm_fl.ops.layernorm.call_op, but this module no longer imports/uses call_op after switching forward_oot to dispatch_method. Please update the tests to mock the new dispatch path (e.g., OpManager.call_as_method) and verify arguments reflect method binding (no explicit self passed to the backend function).
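A sketch of how such a test could target the new dispatch path with `unittest.mock`. Every class below is a simplified stand-in defined inline (not the real vllm_fl modules), so the example only demonstrates the patching technique:

```python
from unittest import mock

# Stand-in manager: in real tests you would patch the actual
# OpManager.call_as_method (or call_method_op) instead.
class OpManager:
    def call_as_method(self, op_name, instance, *args, **kwargs):
        raise NotImplementedError("real backend unavailable in tests")

_manager = OpManager()

class dispatch_method:
    def __init__(self, op_name):
        self.op_name = op_name

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        return lambda *a, **kw: _manager.call_as_method(
            self.op_name, obj, *a, **kw
        )

class RMSNormFL:
    forward_oot = dispatch_method("rms_norm")

layer = RMSNormFL()

# Patch the dispatch entry point rather than the removed call_op symbol.
with mock.patch.object(OpManager, "call_as_method",
                       return_value="mocked") as m:
    out = layer.forward_oot("x", None)

# The instance is the binding target: it reaches call_as_method as the
# second argument, not as an explicit argument to the backend function.
m.assert_called_once_with("rms_norm", layer, "x", None)
```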
```diff
         cos, sin = self.cos_sin_cache.chunk(2, dim=-1)

-        q_embed, k_embed = call_op(
+        q_embed, k_embed = call_method_op(
             "rotary_embedding",
             self,
             query_rot,
```
Unit tests currently patch vllm_fl.ops.rotary_embedding.call_op, but this implementation now calls call_method_op. Please update the tests to patch/mock call_method_op (or OpManager.call_as_method) and adjust the expected call signature accordingly.
```python
enable_fallback = os.getenv("VLLM_FL_STRICT", "1") != "0"
```
call_as_method gates fallback behavior using VLLM_FL_STRICT in a way that appears inverted relative to the documented/configured meaning of “strict” (YAML/policy strict=true => no fallback). As written, the default VLLM_FL_STRICT=1 enables fallback unconditionally and ignores the SelectionPolicy.strict value. Please align call_as_method (and ideally call) with get_policy().strict so strict mode truly disables fallback, and make the env/config semantics consistent.
```diff
-        enable_fallback = os.getenv("VLLM_FL_STRICT", "1") != "0"
+        # Determine strictness: env var overrides policy; strict=True => no fallback.
+        strict_env = os.getenv("VLLM_FL_STRICT")
+        if strict_env is not None:
+            # Treat "0"/"false"/"False" as non-strict; anything else as strict.
+            strict = strict_env not in ("0", "false", "False")
+        else:
+            policy = get_policy()
+            strict = bool(getattr(policy, "strict", False)) if policy is not None else False
+        enable_fallback = not strict
```
```python
def call_as_method(self, op_name: str, instance, *args, **kwargs):
    """
    Resolve and call an operator as a bound method on *instance*.

    Behaves identically to :meth:`call` (fallback, logging, caching)
    except that the resolved function is bound to *instance* via
    ``types.MethodType`` before invocation, so the backend function
    receives *instance* as ``self``.
    """
```
call_as_method largely duplicates the implementation of call (logging, fallback loop, failure tracking). This increases maintenance risk (future fixes can diverge between the two). Consider refactoring to share a single internal helper that performs resolve/fallback and accepts an invocation strategy (direct call vs MethodType binding).
```python
def __set_name__(self, owner, name):
    self.attr_name = name

def __get__(self, obj, objtype=None):
    if obj is None:
        return self

    def dispatched(*args, **kwargs):
        from vllm_fl.dispatch import get_default_manager
        return get_default_manager().call_as_method(
            self.op_name, obj, *args, **kwargs
        )
```
dispatch_method.__set_name__ stores attr_name, but it isn’t used. If you intend to cache the generated callable on the instance to avoid allocating a new closure each attribute access (and to reduce per-call overhead), use attr_name for that; otherwise remove __set_name__/attr_name to avoid dead code.
Suggested change:

```diff
     def __set_name__(self, owner, name):
+        # Store the attribute name so we can cache the dispatched callable
+        # on the instance after first access.
         self.attr_name = name

     def __get__(self, obj, objtype=None):
         if obj is None:
             return self
+        # If the instance has a __dict__, try to reuse a cached dispatched
+        # callable instead of allocating a new closure on every access.
+        if hasattr(obj, "__dict__"):
+            cached = obj.__dict__.get(self.attr_name)
+            if cached is not None:
+                return cached

         def dispatched(*args, **kwargs):
             from vllm_fl.dispatch import get_default_manager
             return get_default_manager().call_as_method(
                 self.op_name, obj, *args, **kwargs
             )

+        # Cache the dispatched callable on the instance (when possible) so
+        # subsequent lookups bypass the descriptor and reuse the same closure.
+        if hasattr(obj, "__dict__"):
+            obj.__dict__[self.attr_name] = dispatched
+        return dispatched
```
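The caching trick above relies on descriptor-precedence rules: a descriptor that defines only `__get__` is a non-data descriptor, so an instance `__dict__` entry of the same name shadows it on later lookups. A self-contained demonstration (independent of vllm_fl):

```python
class counting_descriptor:
    """Non-data descriptor (only __get__): an instance __dict__ entry
    with the same name takes precedence on subsequent lookups."""

    def __init__(self):
        self.gets = 0  # counts how often __get__ actually runs

    def __set_name__(self, owner, name):
        self.attr_name = name

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        self.gets += 1

        def dispatched():
            return "result"

        # Cache on the instance: future lookups bypass the descriptor.
        obj.__dict__[self.attr_name] = dispatched
        return dispatched

class Op:
    forward = counting_descriptor()

op = Op()
first = op.forward()   # goes through the descriptor, caches the closure
second = op.forward()  # served from op.__dict__; descriptor not invoked
desc = Op.__dict__["forward"]
```

If `dispatch_method` ever gained a `__set__`, it would become a data descriptor and this caching scheme would stop working, so the two design choices are coupled.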
```diff
 class SiluAndMulFL(SiluAndMul):
     def __init__(self):
         super().__init__()

-    def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
-        return call_op("silu_and_mul", self, x)
+    forward_oot = dispatch_method("silu_and_mul")
```
Unit tests currently patch vllm_fl.ops.activation.call_op, but this module no longer imports/uses call_op after switching forward_oot to dispatch_method. Please update the tests to mock the new dispatch path (e.g., OpManager.call_as_method / call_method_op) and assert the new call signature (instance is bound rather than passed explicitly).
```diff
@@ -44,82 +42,6 @@ def is_available(self) -> bool:

-# ==================== Operator Implementations ====================
```
Summary

Extracts operator implementations (`silu_and_mul`, `rms_norm`, `rotary_embedding`) from Backend classes into standalone functions and introduces a `dispatch_method` descriptor for automatic method-level dispatch. Operator classes can now declare dispatch binding in a single line, mirroring vLLM's native `forward_cuda`/`forward_xpu` pattern.

Motivation

Previously, every Backend class (FlagGems, Reference, CUDA, Ascend) had to define full operator methods that were simple pass-through wrappers around standalone functions in `impl/`. This resulted in significant boilerplate duplication across backends. Additionally, operator classes (e.g., `RMSNormFL`) had to manually call `call_op` and explicitly pass `self`, which was unnatural and verbose.

Changes

New modules
- `vllm_fl/dispatch/method_dispatch.py`: a `dispatch_method` descriptor that uses the Python descriptor protocol to automatically dispatch `forward_oot` calls to the resolved backend implementation, binding the backend function as a method via `MethodType` so `self` is naturally available.
- `vllm_fl/dispatch/manager.py`: added `OpManager.call_as_method()` to invoke an operator as a bound method on an instance.

Backend class cleanup
- Removed the `silu_and_mul`, `rms_norm`, and `rotary_embedding` methods from `FlagGemsBackend`, `ReferenceBackend`, `CudaBackend`, and `AscendBackend` (~300 lines of boilerplate removed).

Registration changes
- `register_ops.py` now registers standalone `impl/` functions directly instead of bound backend instance methods.

Unified impl function signatures
- Renamed the first parameter from `obj` to `self`, consistent with `MethodType` binding semantics.

Simplified operator classes
- `SiluAndMulFL`, `RMSNormFL`, and `RotaryEmbeddingFL` now declare dispatch with a single-line descriptor assignment, e.g. `forward_oot = dispatch_method("rms_norm")`.

Public API
- Added `dispatch_method` and `call_method_op` to `vllm_fl.dispatch` exports.

Impact
- To add a new operator implementation, define a standalone function in `impl/` and register it in `register_ops.py`.
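The two-step extension flow can be sketched as follows. This is a hypothetical sketch: the registry dict and `register_impl` helper are illustrative, and the real helper names in `register_ops.py` may differ.

```python
# Stand-in for the dispatch registry (illustrative only).
REGISTRY = {}

def register_impl(op_name, vendor, fn):
    """Hypothetical registration helper: record a standalone impl."""
    REGISTRY.setdefault(op_name, []).append({"vendor": vendor, "fn": fn})

# Step 1 - impl/my_op.py: standalone function, first parameter is `self`
# (the operator instance, supplied later by MethodType binding).
def my_op_impl(self, x):
    return x * 2

# Step 2 - register_ops.py: register the standalone function directly;
# no Backend wrapper method is needed anymore.
register_impl("my_op", vendor="reference", fn=my_op_impl)

registered = REGISTRY["my_op"][0]["fn"](None, 3)
```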