
Decouple op implementations from Backend classes and add descriptor-based method dispatch #46

Open
xin2an wants to merge 50 commits into flagos-ai:main from xin2an:dispatch_add_v0130

Conversation

Contributor

@xin2an xin2an commented Feb 11, 2026

Summary

Extracts operator implementations (silu_and_mul, rms_norm, rotary_embedding) from Backend classes into standalone functions and introduces a dispatch_method descriptor for automatic method-level dispatch. Operator classes can now declare dispatch binding in a single line, mirroring vLLM's native forward_cuda / forward_xpu pattern.

Motivation

Previously, every Backend class (FlagGems, Reference, CUDA, Ascend) had to define full operator methods that were simple pass-through wrappers around standalone functions in impl/. This resulted in significant boilerplate duplication across backends. Additionally, operator classes (e.g., RMSNormFL) had to manually call call_op and explicitly pass self, which was unnatural and verbose.

Changes

New modules

  • vllm_fl/dispatch/method_dispatch.py: A dispatch_method descriptor that uses the Python descriptor protocol to automatically dispatch forward_oot calls to the resolved backend implementation, binding the backend function as a method via MethodType so self is naturally available.
  • vllm_fl/dispatch/manager.py: Added OpManager.call_as_method() to invoke an operator as a bound method on an instance.
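A minimal, self-contained sketch of the descriptor idea: the flat `BACKEND_IMPLS` dict and `register()` helper below are hypothetical stand-ins for the real `OpManager`/registry machinery, but the descriptor mechanics mirror what the PR describes.

```python
import types

# Hypothetical registry mapping op names to standalone impl functions;
# the real code resolves these through OpManager.call_as_method().
BACKEND_IMPLS = {}

def register(op_name):
    def deco(fn):
        BACKEND_IMPLS[op_name] = fn
        return fn
    return deco

class dispatch_method:
    """Descriptor that resolves an op by name and binds the resolved
    standalone function to the instance, so it receives `self`."""

    def __init__(self, op_name):
        self.op_name = op_name

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        # MethodType turns the standalone function into a bound method.
        return types.MethodType(BACKEND_IMPLS[self.op_name], obj)

@register("rms_norm")
def rms_norm_impl(self, x):
    # A real impl would normalize x using self.weight, self.eps, etc.
    return f"rms_norm({x}) on {self.name}"

class RMSNormFL:
    def __init__(self, name):
        self.name = name

    # Single-line dispatch declaration, as in the PR.
    forward_oot = dispatch_method("rms_norm")

layer = RMSNormFL("layer0")
print(layer.forward_oot("hidden"))  # rms_norm(hidden) on layer0
```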

Backend class cleanup

  • Removed silu_and_mul, rms_norm, and rotary_embedding methods from FlagGemsBackend, ReferenceBackend, CudaBackend, and AscendBackend (~300 lines of boilerplate removed).

Registration changes

  • Each backend's register_ops.py now registers standalone impl/ functions directly instead of bound backend instance methods.

Unified impl function signatures

  • Renamed the first parameter of all impl functions from obj to self, consistent with MethodType binding semantics.
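The rename can be seen in a tiny sketch (function and attribute names below are illustrative): once a standalone function is bound with `types.MethodType`, its first positional parameter *is* the instance, so calling it `self` rather than `obj` matches how it is actually used.

```python
import types

# Illustrative standalone impl function with a `self` first parameter.
def silu_and_mul_impl(self, x):
    return (self.scale, x)

class Layer:
    scale = 2.0

# Binding makes the instance arrive implicitly as `self`.
bound = types.MethodType(silu_and_mul_impl, Layer())
assert bound("x") == (2.0, "x")
```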

Simplified operator classes

  • SiluAndMulFL, RMSNormFL, and RotaryEmbeddingFL now use a single-line descriptor declaration for dispatch:
    forward_oot = dispatch_method("rms_norm")

Public API

  • Added dispatch_method and call_method_op to vllm_fl.dispatch exports.

Impact

  • Removes ~300 lines of duplicated boilerplate code
  • Adding new operators no longer requires modifying Backend classes — just write a standalone function in impl/ and register it in register_ops.py
  • Operator class dispatch declaration reduced from multi-line manual calls to a single-line descriptor

ceci3 and others added 30 commits December 31, 2025 20:36
@xin2an xin2an changed the title from "Delete the VLLM_FL_PLATFORM environment variable" to "Decouple op implementations from Backend classes and add descriptor-based method dispatch" on Feb 12, 2026
Copilot AI review requested due to automatic review settings February 24, 2026 01:40

Copilot AI left a comment


Pull request overview

This pull request refactors the operator dispatch system by decoupling operator implementations from Backend classes and introducing a descriptor-based method dispatch pattern. The change removes approximately 300 lines of boilerplate code across four backend classes (FlagGems, Reference, CUDA, Ascend) by registering standalone implementation functions directly instead of bound backend methods. Operator classes can now declare dispatch in a single line using a dispatch_method descriptor, mirroring vLLM's native forward_cuda/forward_xpu pattern.

Changes:

  • Introduced dispatch_method descriptor and call_as_method() manager method for automatic method-level dispatch with instance binding
  • Removed wrapper methods (silu_and_mul, rms_norm, rotary_embedding) from all Backend classes and updated registration to use standalone impl functions
  • Unified impl function signatures to use self instead of obj for consistency with MethodType binding semantics
  • Updated platform detection to use current_platform.vendor_name instead of torch availability checks
  • Simplified operator classes to use single-line descriptor declarations where appropriate

Reviewed changes

Copilot reviewed 19 out of 29 changed files in this pull request and generated 2 comments.

Show a summary per file
  • vllm_fl/dispatch/method_dispatch.py: New descriptor class for automatic method-level dispatch
  • vllm_fl/dispatch/manager.py: Added call_as_method() for invoking operators as bound methods with fallback support
  • vllm_fl/ops/activation.py: Simplified SiluAndMulFL to use the dispatch_method descriptor
  • vllm_fl/ops/layernorm.py: Simplified RMSNormFL to use the dispatch_method descriptor
  • vllm_fl/ops/rotary_embedding.py: Updated to use call_method_op() for manual dispatch with preprocessing logic
  • vllm_fl/dispatch/__init__.py: Added exports for dispatch_method and call_method_op
  • vllm_fl/dispatch/config/utils.py: Updated platform detection to use current_platform.vendor_name
  • vllm_fl/dispatch/config/nvidia.yaml: Updated comments to reflect vendor_name-based detection
  • vllm_fl/dispatch/README.md: Updated documentation to remove the obsolete VLLM_FL_PLATFORM environment variable
  • vllm_fl/dispatch/backends/flaggems/flaggems.py: Removed operator wrapper methods (silu_and_mul, rms_norm, rotary_embedding)
  • vllm_fl/dispatch/backends/flaggems/register_ops.py: Updated to register standalone impl functions directly
  • vllm_fl/dispatch/backends/flaggems/impl/*.py: Changed the first parameter from obj to self in all impl functions
  • vllm_fl/dispatch/backends/reference/reference.py: Removed operator wrapper methods
  • vllm_fl/dispatch/backends/reference/register_ops.py: Updated to register standalone impl functions directly
  • vllm_fl/dispatch/backends/reference/impl/*.py: Changed the first parameter from obj to self in all impl functions
  • vllm_fl/dispatch/backends/vendor/cuda/cuda.py: Removed operator wrapper methods
  • vllm_fl/dispatch/backends/vendor/cuda/register_ops.py: Updated to register standalone impl functions directly
  • vllm_fl/dispatch/backends/vendor/cuda/impl/*.py: Changed the first parameter from obj to self in all impl functions
  • vllm_fl/dispatch/backends/vendor/ascend/ascend.py: Removed operator wrapper methods
  • vllm_fl/dispatch/backends/vendor/ascend/register_ops.py: Updated to register standalone impl functions directly
  • vllm_fl/dispatch/backends/vendor/ascend/impl/*.py: Changed the first parameter from obj to self in all impl functions


Comment on lines +525 to +632
def call_as_method(self, op_name: str, instance, *args, **kwargs):
    """
    Resolve and call an operator as a bound method on *instance*.

    Behaves identically to :meth:`call` (fallback, logging, caching)
    except that the resolved function is bound to *instance* via
    ``types.MethodType`` before invocation, so the backend function
    receives *instance* as ``self``.
    """
    enable_fallback = os.getenv("VLLM_FL_STRICT", "1") != "0"

    if not enable_fallback:
        fn = self.resolve(op_name)

        impl_id = self.get_selected_impl_id(op_name)
        last_impl_id = self._called_ops.get(op_name)

        if last_impl_id != impl_id:
            with self._lock:
                if self._called_ops.get(op_name) != impl_id:
                    snap = self._registry.snapshot()
                    for impl in snap.impls_by_op.get(op_name, []):
                        if impl.impl_id == impl_id:
                            if last_impl_id is None:
                                logger.info(
                                    f"Op '{op_name}' using '{impl_id}' "
                                    f"(kind={impl.kind.value}, vendor={impl.vendor})"
                                )
                            else:
                                logger.info(
                                    f"Op '{op_name}' switched from '{last_impl_id}' to '{impl_id}' "
                                    f"(kind={impl.kind.value}, vendor={impl.vendor})"
                                )
                            break
                    self._called_ops[op_name] = impl_id

        bound = pytypes.MethodType(fn, instance)
        return bound(*args, **kwargs)

    # Fallback mode: try candidates in priority order
    candidates = self.resolve_candidates(op_name)
    last_error = None

    failed_impl_ids = self._failed_impls.get(op_name, set())

    available_candidates = [
        impl for impl in candidates if impl.impl_id not in failed_impl_ids
    ]

    if not available_candidates:
        raise RuntimeError(
            f"All implementations for op='{op_name}' have failed previously. "
            f"Failed impl_ids: {failed_impl_ids}"
        )

    for idx, impl in enumerate(available_candidates):
        try:
            if idx == 0:
                last_impl_id = self._called_ops.get(op_name)
                if last_impl_id != impl.impl_id:
                    with self._lock:
                        if self._called_ops.get(op_name) != impl.impl_id:
                            if last_impl_id is None:
                                logger.info(
                                    f"Op '{op_name}' using '{impl.impl_id}' "
                                    f"(kind={impl.kind.value}, vendor={impl.vendor})"
                                )
                            else:
                                logger.info(
                                    f"Op '{op_name}' switched from '{last_impl_id}' to '{impl.impl_id}' "
                                    f"(kind={impl.kind.value}, vendor={impl.vendor})"
                                )
                            self._called_ops[op_name] = impl.impl_id
            else:
                logger.info(
                    f"Op '{op_name}' fallback to '{impl.impl_id}' "
                    f"(kind={impl.kind.value}, vendor={impl.vendor})"
                )

            bound = pytypes.MethodType(impl.fn, instance)
            result = bound(*args, **kwargs)

            if idx > 0:
                with self._lock:
                    self._called_ops[op_name] = impl.impl_id

            return result

        except Exception as e:
            last_error = e
            with self._lock:
                if op_name not in self._failed_impls:
                    self._failed_impls[op_name] = set()
                self._failed_impls[op_name].add(impl.impl_id)

            if idx < len(available_candidates) - 1:
                logger.warning(
                    f"Implementation '{impl.impl_id}' failed for op '{op_name}': {e}"
                )
            else:
                logger.error(
                    f"Last implementation '{impl.impl_id}' failed for op '{op_name}': {e}"
                )

    raise RuntimeError(
        f"All {len(available_candidates)} implementation(s) failed for op='{op_name}'. "
        f"Last error: {last_error}"
    ) from last_error

Copilot AI Feb 24, 2026


The call_as_method method contains significant code duplication with the call method. The only difference is binding the function using pytypes.MethodType before invocation (lines 561, 604). This duplicates approximately 108 lines of complex fallback, logging, and error handling logic.

Consider refactoring to avoid this duplication. One approach would be to extract a common method that handles the fallback logic and accepts a callback for function invocation, or add a parameter to the existing call method to enable method binding.
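The "common method with an invocation callback" idea can be sketched as below. Class and method names are illustrative (a stand-in for the real `OpManager` API), and the real helper would also carry the logging, caching, and failure-tracking shown in the excerpt above.

```python
import types

class OpManagerSketch:
    def _call_with(self, op_name, invoke, *args, **kwargs):
        # Shared candidate/fallback loop; `invoke` decides how the
        # resolved function is called (free function vs bound method).
        last_error = None
        for fn in self.resolve_candidates(op_name):  # assumed to exist
            try:
                return invoke(fn, *args, **kwargs)
            except Exception as e:
                last_error = e  # a real manager would log and track failures
        raise RuntimeError(
            f"All implementations failed for op='{op_name}'"
        ) from last_error

    def call(self, op_name, *args, **kwargs):
        # Plain invocation: the impl is called as a free function.
        return self._call_with(
            op_name, lambda fn, *a, **kw: fn(*a, **kw), *args, **kwargs
        )

    def call_as_method(self, op_name, instance, *args, **kwargs):
        # Only difference: bind the impl to `instance` before invoking.
        def invoke(fn, *a, **kw):
            return types.MethodType(fn, instance)(*a, **kw)
        return self._call_with(op_name, invoke, *args, **kwargs)
```

With this shape, fixes to the fallback loop land in one place and both entry points inherit them.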

if platform_override:
    return platform_override
if hasattr(current_platform, 'vendor_name'):
    return current_platform.vendor_name

Copilot AI Feb 24, 2026


The platform detection now relies on current_platform.vendor_name without case normalization. If vLLM's vendor_name returns values with different casing (e.g., "NVIDIA" instead of "nvidia"), the config file lookup at line 77 (_CONFIG_DIR / f"{platform}.yaml") will fail since the config file is named nvidia.yaml (lowercase).

Consider adding case normalization to ensure robust platform detection, for example:

if hasattr(current_platform, 'vendor_name'):
    return current_platform.vendor_name.lower()
Suggested change:
-    return current_platform.vendor_name
+    vendor = current_platform.vendor_name
+    # Normalize to lowercase to match config filenames (e.g., nvidia.yaml).
+    return vendor.lower() if isinstance(vendor, str) else str(vendor).lower()

Copilot AI review requested due to automatic review settings February 28, 2026 11:33

Copilot AI left a comment


Pull request overview

Copilot reviewed 19 out of 29 changed files in this pull request and generated 6 comments.



Comment on lines 16 to +20
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)

-    def forward_oot(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        return call_op("rms_norm", self, x, residual)
+    forward_oot = dispatch_method("rms_norm")

Copilot AI Feb 28, 2026


Unit tests currently patch vllm_fl.ops.layernorm.call_op, but this module no longer imports/uses call_op after switching forward_oot to dispatch_method. Please update the tests to mock the new dispatch path (e.g., OpManager.call_as_method) and verify arguments reflect method binding (no explicit self passed to the backend function).

Comment on lines 45 to 50
        cos, sin = self.cos_sin_cache.chunk(2, dim=-1)

-        q_embed, k_embed = call_op(
+        q_embed, k_embed = call_method_op(
            "rotary_embedding",
            self,
            query_rot,

Copilot AI Feb 28, 2026


Unit tests currently patch vllm_fl.ops.rotary_embedding.call_op, but this implementation now calls call_method_op. Please update the tests to patch/mock call_method_op (or OpManager.call_as_method) and adjust the expected call signature accordingly.

Comment on lines +534 to +535
enable_fallback = os.getenv("VLLM_FL_STRICT", "1") != "0"


Copilot AI Feb 28, 2026


call_as_method gates fallback behavior using VLLM_FL_STRICT in a way that appears inverted relative to the documented/configured meaning of “strict” (YAML/policy strict=true => no fallback). As written, the default VLLM_FL_STRICT=1 enables fallback unconditionally and ignores the SelectionPolicy.strict value. Please align call_as_method (and ideally call) with get_policy().strict so strict mode truly disables fallback, and make the env/config semantics consistent.

Suggested change:
-    enable_fallback = os.getenv("VLLM_FL_STRICT", "1") != "0"
+    # Determine strictness: env var overrides policy; strict=True => no fallback.
+    strict_env = os.getenv("VLLM_FL_STRICT")
+    if strict_env is not None:
+        # Treat "0"/"false"/"False" as non-strict; anything else as strict.
+        strict = strict_env not in ("0", "false", "False")
+    else:
+        policy = get_policy()
+        strict = bool(getattr(policy, "strict", False)) if policy is not None else False
+    enable_fallback = not strict

Comment on lines +525 to +533
def call_as_method(self, op_name: str, instance, *args, **kwargs):
    """
    Resolve and call an operator as a bound method on *instance*.

    Behaves identically to :meth:`call` (fallback, logging, caching)
    except that the resolved function is bound to *instance* via
    ``types.MethodType`` before invocation, so the backend function
    receives *instance* as ``self``.
    """

Copilot AI Feb 28, 2026


call_as_method largely duplicates the implementation of call (logging, fallback loop, failure tracking). This increases maintenance risk (future fixes can diverge between the two). Consider refactoring to share a single internal helper that performs resolve/fallback and accepts an invocation strategy (direct call vs MethodType binding).

Comment on lines +31 to +43
def __set_name__(self, owner, name):
    self.attr_name = name

def __get__(self, obj, objtype=None):
    if obj is None:
        return self

    def dispatched(*args, **kwargs):
        from vllm_fl.dispatch import get_default_manager
        return get_default_manager().call_as_method(
            self.op_name, obj, *args, **kwargs
        )

Copilot AI Feb 28, 2026


dispatch_method.__set_name__ stores attr_name, but it isn’t used. If you intend to cache the generated callable on the instance to avoid allocating a new closure each attribute access (and to reduce per-call overhead), use attr_name for that; otherwise remove __set_name__/attr_name to avoid dead code.

Suggested change:
-    def __set_name__(self, owner, name):
-        self.attr_name = name
-
-    def __get__(self, obj, objtype=None):
-        if obj is None:
-            return self
-
-        def dispatched(*args, **kwargs):
-            from vllm_fl.dispatch import get_default_manager
-            return get_default_manager().call_as_method(
-                self.op_name, obj, *args, **kwargs
-            )
+    def __set_name__(self, owner, name):
+        # Store the attribute name so we can cache the dispatched callable
+        # on the instance after first access.
+        self.attr_name = name
+
+    def __get__(self, obj, objtype=None):
+        if obj is None:
+            return self
+        # If the instance has a __dict__, try to reuse a cached dispatched
+        # callable instead of allocating a new closure on every access.
+        if hasattr(obj, "__dict__"):
+            cached = obj.__dict__.get(self.attr_name)
+            if cached is not None:
+                return cached
+
+        def dispatched(*args, **kwargs):
+            from vllm_fl.dispatch import get_default_manager
+            return get_default_manager().call_as_method(
+                self.op_name, obj, *args, **kwargs
+            )
+
+        # Cache the dispatched callable on the instance (when possible) so
+        # subsequent lookups bypass the descriptor and reuse the same closure.
+        if hasattr(obj, "__dict__"):
+            obj.__dict__[self.attr_name] = dispatched
Comment on lines 7 to +11
class SiluAndMulFL(SiluAndMul):
    def __init__(self):
        super().__init__()

-    def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
-        return call_op("silu_and_mul", self, x)
+    forward_oot = dispatch_method("silu_and_mul")

Copilot AI Feb 28, 2026


Unit tests currently patch vllm_fl.ops.activation.call_op, but this module no longer imports/uses call_op after switching forward_oot to dispatch_method. Please update the tests to mock the new dispatch path (e.g., OpManager.call_as_method / call_method_op) and assert the new call signature (instance is bound rather than passed explicitly).

@@ -44,82 +42,6 @@ def is_available(self) -> bool:

# ==================== Operator Implementations ====================

Collaborator


why delete here?

            return platform_override
        if hasattr(current_platform, 'vendor_name'):
            return current_platform.vendor_name
    except Exception:
Check notice: Code scanning / CodeQL: Empty except (Note)

'except' clause does nothing but pass and there is no explanatory comment.
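One common way to resolve this class of CodeQL note is to keep the broad guard around best-effort platform probing but narrow it, comment it, and log the swallowed error. The `detect_platform` wrapper and the `"reference"` default below are illustrative sketches, not the project's actual function or fallback value.

```python
import logging

logger = logging.getLogger(__name__)

def detect_platform(current_platform, platform_override=None):
    # `current_platform` stands in for vLLM's platform object here.
    if platform_override:
        return platform_override
    try:
        if hasattr(current_platform, 'vendor_name'):
            # Lowercase to match config filenames such as nvidia.yaml.
            return str(current_platform.vendor_name).lower()
    except Exception as e:
        # Platform probing is best-effort; fall through to the default.
        logger.debug("vendor_name probe failed: %s", e)
    return "reference"
```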

4 participants