Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion vllm_fl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def register():
def register_model():
"""Register the FL model."""
from vllm import ModelRegistry
import vllm.model_executor.models.qwen3_next as qwen3_next_module

# Register Qwen3.5 MoE config
try:
Expand Down
3 changes: 2 additions & 1 deletion vllm_fl/compilation/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ class Graph:
elif current_platform.device_type == "npu":
graph = torch.npu.NPUGraph
else:
raise NotImplementedError("not support graph")
pass
# raise NotImplementedError("not support graph")

@dataclasses.dataclass
class GraphEntry:
Expand Down
9 changes: 9 additions & 0 deletions vllm_fl/dispatch/backends/vendor/txda/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (c) 2026 BAAI. All rights reserved.

"""
Txda (Tsingmicro) backend for vllm-plugin-FL dispatch.
"""

from .txda import TxdaBackend

# Fix: __all__ previously exported the misspelled name "TxadBackend", which is
# not defined anywhere in this package (flagged by CodeQL: "Explicit export is
# not defined"). Export the real class name so `from ... import *` works.
__all__ = ["TxdaBackend"]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TxdaBackend typo

Check failure

Code scanning / CodeQL

Explicit export is not defined Error library

The name 'TxadBackend' is exported by __all__ but is not defined.
78 changes: 78 additions & 0 deletions vllm_fl/dispatch/backends/vendor/txda/register_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) 2026 BAAI. All rights reserved.

"""
TsingMicro backend operator registrations.
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
METAX backend operator registrations.
TsingMicro backend operator registrations.


This module registers all VENDOR (TsingMicro) implementations.
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
This module registers all VENDOR (METAX) implementations.
This module registers all VENDOR (TsingMicro) implementations.

"""

from __future__ import annotations

import functools

from vllm_fl.dispatch.types import OpImpl, BackendImplKind, BackendPriority


def _bind_is_available(fn, is_available_fn):
"""Wrap a function and bind _is_available attribute for OpImpl.is_available() check."""

@functools.wraps(fn)
def wrapper(*args, **kwargs):
return fn(*args, **kwargs)

wrapper._is_available = is_available_fn
return wrapper


def register_builtins(registry) -> None:
    """
    Register all TsingMicro (VENDOR) operator implementations.

    Currently only the attention backend is registered; the element-wise
    operator implementations are kept commented out below until the
    corresponding txda kernels are validated.

    Args:
        registry: Registry to register the OpImpl entries into.
    """
    # Imported lazily so merely importing this module does not require the
    # vendor runtime to be installed.
    from .txda import TxdaBackend

    backend = TxdaBackend()
    # One shared availability probe, bound to every OpImpl below.
    is_avail = backend.is_available

    impls = [
        # TODO(review): re-enable these implementations once the txda kernels
        # are validated — they were disabled without an explanation in review.
        # # Activation
        # OpImpl(
        #     op_name="silu_and_mul",
        #     impl_id="vendor.txda",
        #     kind=BackendImplKind.VENDOR,
        #     fn=_bind_is_available(backend.silu_and_mul, is_avail),
        #     vendor="txda",
        #     priority=BackendPriority.VENDOR,
        # ),
        # # Normalization
        # OpImpl(
        #     op_name="rms_norm",
        #     impl_id="vendor.txda",
        #     kind=BackendImplKind.VENDOR,
        #     fn=_bind_is_available(backend.rms_norm, is_avail),
        #     vendor="txda",
        #     priority=BackendPriority.VENDOR,
        # ),
        # # Rotary Embedding
        # OpImpl(
        #     op_name="rotary_embedding",
        #     impl_id="vendor.txda",
        #     kind=BackendImplKind.VENDOR,
        #     fn=_bind_is_available(backend.rotary_embedding, is_avail),
        #     vendor="txda",
        #     priority=BackendPriority.VENDOR,
        # ),
        # Attention Backend
        OpImpl(
            op_name="attention_backend",
            impl_id="vendor.txda",
            kind=BackendImplKind.VENDOR,
            fn=_bind_is_available(backend.attention_backend, is_avail),
            vendor="txda",
            priority=BackendPriority.VENDOR,
        ),
    ]

    registry.register_many(impls)
150 changes: 150 additions & 0 deletions vllm_fl/dispatch/backends/vendor/txda/txda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright (c) 2026 BAAI. All rights reserved.

"""
Txda backend implementation.

This backend provides operator implementations for Tsingmicro Txda NPUs.
"""

from __future__ import annotations

from typing import Optional, Union

Check notice

Code scanning / CodeQL

Unused import Note library

Import of 'Union' is not used.

import torch

# from vllm_fl.dispatch.backends.flaggems import FlagGemsBackend
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this line.

from vllm_fl.dispatch.backends.base import Backend


class TxdaBackend(Backend):
    """
    Txda backend for operator implementations.

    Provides operator implementations for Tsingmicro Txda NPUs via the
    ``torch_txda`` PyTorch extension.
    """

    # Cached availability probe, shared by all instances.
    # None = not probed yet; True/False after the first check.
    _available: Optional[bool] = None

    @property
    def name(self) -> str:
        return "txda"

    @property
    def vendor(self) -> Optional[str]:
        return "txda"

    @classmethod
    def is_available(cls) -> bool:
        """Check whether Txda hardware and libraries are available.

        The result is cached on the class so the (potentially costly)
        import/device probe runs at most once per process. Declared as a
        classmethod because it reads and writes class state only; bound
        access via an instance (``backend.is_available``) keeps working.
        """
        if cls._available is None:
            try:
                # Check for torch_txda (Txda PyTorch extension). Importing it
                # is what registers the ``torch.txda`` device module used
                # below — presumably via the PrivateUse1 mechanism; confirm.
                import torch_txda  # noqa: F401

                # Check that a Txda device is actually present.
                cls._available = bool(
                    torch.txda.is_available() and torch.txda.device_count() > 0
                )
            except (ImportError, AttributeError):
                # AttributeError covers torch builds where torch.txda was
                # never registered despite torch_txda importing.
                cls._available = False
        return cls._available

    # ==================== Operator Implementations ====================

    # TODO(review): the element-wise operator implementations below were
    # disabled without explanation; re-enable them once the txda kernels
    # under .impl are validated.

    # def silu_and_mul(self, obj, x: torch.Tensor) -> torch.Tensor:
    #     """
    #     SiLU activation followed by element-wise multiplication.
    #
    #     Args:
    #         obj: The calling obj (for interface consistency)
    #         x: Input tensor of shape [..., 2*d]
    #
    #     Returns:
    #         Output tensor of shape [..., d]
    #     """
    #     from .impl.activation import silu_and_mul_Txda
    #
    #     return silu_and_mul_Txda(obj, x)

    # def rms_norm(
    #     self,
    #     obj,
    #     x: torch.Tensor,
    #     residual: Optional[torch.Tensor] = None,
    # ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    #     """
    #     RMS normalization.
    #
    #     Args:
    #         obj: The calling obj (e.g., RMSNorm layer)
    #         x: Input tensor
    #         residual: Optional residual tensor
    #
    #     Returns:
    #         Normalized tensor, or tuple of (normalized, residual) if residual is provided
    #     """
    #     from .impl.normalization import rms_norm_Txda
    #
    #     return rms_norm_Txda(obj, x, residual)

    # def rotary_embedding(
    #     self,
    #     obj,
    #     query: torch.Tensor,
    #     key: torch.Tensor,
    #     cos: torch.Tensor,
    #     sin: torch.Tensor,
    #     position_ids: torch.Tensor,
    #     rotary_interleaved: bool = False,
    #     inplace: bool = True,
    # ) -> tuple[torch.Tensor, torch.Tensor]:
    #     """
    #     Apply rotary position embedding.
    #
    #     Args:
    #         obj: The calling obj (for interface consistency)
    #         query: Query tensor
    #         key: Key tensor
    #         cos: Cosine cache
    #         sin: Sine cache
    #         position_ids: Position indices
    #         rotary_interleaved: Whether to use interleaved rotary
    #         inplace: Whether to modify tensors in-place
    #
    #     Returns:
    #         Tuple of (embedded_query, embedded_key)
    #     """
    #     from .impl.rotary import rotary_embedding_Txda
    #
    #     return rotary_embedding_Txda(
    #         obj,
    #         query,
    #         key,
    #         cos,
    #         sin,
    #         position_ids,
    #         rotary_interleaved=rotary_interleaved,
    #         inplace=inplace,
    #     )

    def attention_backend(self, use_mla: bool = False) -> str:
        """
        Get the attention backend class path for Txda NPUs.

        NOTE(review): this currently returns the generic FlagGems-based
        attention backends, not a txda-native implementation. The original
        docstring claimed a "native Txda backend using torch_npu operators",
        which does not match the paths actually returned below — confirm
        whether a txda-specific backend is intended here.

        Args:
            use_mla: Whether to use Multi-head Latent Attention (MLA).

        Returns:
            Fully qualified class path string of the attention backend.
        """
        if use_mla:
            return "vllm_fl.dispatch.backends.flaggems.impl.mla.MLAFLBackend"
        return "vllm_fl.dispatch.backends.flaggems.impl.attention.AttentionFLBackend"
16 changes: 14 additions & 2 deletions vllm_fl/distributed/device_communicators/flagcx.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@
### TODO(lms): simplify it
if library_path is None:
flagcx_path = os.getenv('FLAGCX_PATH')
library_path=os.path.join(flagcx_path, "build/lib/libflagcx.so")
#library_path=os.path.join(flagcx_path, "libflagcx.so") # rcy fix
library_path= "/usr/local/kuiper/lib/libflagcx.so"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are using hard-coded path rather than using the environment variable?

self.flagcx = FLAGCXLibrary(library_path)
else:
self.flagcx = FLAGCXLibrary(library_path)
Expand Down Expand Up @@ -113,7 +114,8 @@
# nccl communicator and stream will use this device
# `torch.cuda.device` is a context manager that changes the
# current cuda device to the specified one
with torch.cuda.device(device):
#
with torch.txda.device(device):
self.comm = self.flagcx.flagcxCommInitRank(
self.world_size, ctypes.byref(self.unique_id), self.rank)

Expand Down Expand Up @@ -144,13 +146,23 @@
if stream is None:
stream = current_stream()
flagcx_stream = self.flagcx.adaptor_stream_copy(stream)
change_type = False
if in_tensor.dtype == torch.bfloat16:
in_tensor = in_tensor.to(torch.float32)
out_tensor = out_tensor.to(torch.float32)
change_type = True

self.flagcx.flagcxAllReduce(buffer_type(in_tensor.data_ptr()),
buffer_type(out_tensor.data_ptr()),
in_tensor.numel(),
flagcxDataTypeEnum.from_torch(in_tensor.dtype),
flagcxRedOpTypeEnum.from_torch(op), self.comm,
flagcx_stream)
self.flagcx.adaptor_stream_free(flagcx_stream)
if change_type:
in_tensor = in_tensor.to(torch.bfloat16)

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable in_tensor is not used.
out_tensor = out_tensor.to(torch.bfloat16)

return out_tensor

def all_gather(self,
Expand Down
Loading
Loading