-
Notifications
You must be signed in to change notification settings - Fork 40
Tsingmicro add txda backend to vllm-fl plugin. #52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
96a23a1
befc101
c716683
56e2f3f
6cec619
ac17633
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
# Copyright (c) 2026 BAAI. All rights reserved.

"""
Txda (Tsingmicro) backend for vllm-plugin-FL dispatch.
"""

from .txda import TxdaBackend

# Fix: __all__ previously exported the misspelled name "TxadBackend",
# which is never defined in this package (flagged by CodeQL:
# "Explicit export is not defined").
__all__ = ["TxdaBackend"]
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,78 @@ | ||||||
| # Copyright (c) 2026 BAAI. All rights reserved. | ||||||
|
|
||||||
| """ | ||||||
| METAX backend operator registrations. | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
|
||||||
| This module registers all VENDOR (METAX) implementations. | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| """ | ||||||
|
|
||||||
| from __future__ import annotations | ||||||
|
|
||||||
| import functools | ||||||
|
|
||||||
| from vllm_fl.dispatch.types import OpImpl, BackendImplKind, BackendPriority | ||||||
|
|
||||||
|
|
||||||
def _bind_is_available(fn, is_available_fn):
    """Wrap *fn* and attach ``_is_available`` for OpImpl.is_available() checks.

    The returned callable forwards every call to *fn* unchanged; the
    availability probe is exposed on it as the ``_is_available`` attribute
    that ``OpImpl`` looks for.
    """

    @functools.wraps(fn)
    def forwarder(*call_args, **call_kwargs):
        return fn(*call_args, **call_kwargs)

    forwarder._is_available = is_available_fn
    return forwarder
|
|
||||||
|
|
||||||
def register_builtins(registry) -> None:
    """
    Register all Txda (VENDOR) operator implementations.

    Only the attention-backend dispatch is registered for now; the
    element-wise operator registrations (silu_and_mul, rms_norm,
    rotary_embedding) are kept below as commented-out templates until
    their Txda kernel implementations land.

    Args:
        registry: Registry to register the Txda OpImpl entries into.
    """
    # Imported lazily so merely importing this module does not pull in the
    # Txda runtime.
    from .txda import TxdaBackend

    backend = TxdaBackend()
    is_avail = backend.is_available

    impls = [
        # # Activation
        # OpImpl(
        #     op_name="silu_and_mul",
        #     impl_id="vendor.txda",
        #     kind=BackendImplKind.VENDOR,
        #     fn=_bind_is_available(backend.silu_and_mul, is_avail),
        #     vendor="txda",
        #     priority=BackendPriority.VENDOR,
        # ),
        # # Normalization
        # OpImpl(
        #     op_name="rms_norm",
        #     impl_id="vendor.txda",
        #     kind=BackendImplKind.VENDOR,
        #     fn=_bind_is_available(backend.rms_norm, is_avail),
        #     vendor="txda",
        #     priority=BackendPriority.VENDOR,
        # ),
        # # Rotary Embedding
        # OpImpl(
        #     op_name="rotary_embedding",
        #     impl_id="vendor.txda",
        #     kind=BackendImplKind.VENDOR,
        #     fn=_bind_is_available(backend.rotary_embedding, is_avail),
        #     vendor="txda",
        #     priority=BackendPriority.VENDOR,
        # ),
        # Attention Backend
        OpImpl(
            op_name="attention_backend",
            impl_id="vendor.txda",
            kind=BackendImplKind.VENDOR,
            fn=_bind_is_available(backend.attention_backend, is_avail),
            vendor="txda",
            priority=BackendPriority.VENDOR,
        ),
    ]

    registry.register_many(impls)
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,150 @@ | ||||||
| # Copyright (c) 2026 BAAI. All rights reserved. | ||||||
|
|
||||||
| """ | ||||||
| Txda backend implementation. | ||||||
|
|
||||||
| This backend provides operator implementations for Tsingmiocro Txda NPUs. | ||||||
| """ | ||||||
|
|
||||||
| from __future__ import annotations | ||||||
|
|
||||||
| from typing import Optional, Union | ||||||
Check noticeCode scanning / CodeQL Unused import Note library
Import of 'Union' is not used.
|
||||||
|
|
||||||
| import torch | ||||||
|
|
||||||
| # from vllm_fl.dispatch.backends.flaggems import FlagGemsBackend | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove this line. |
||||||
| from vllm_fl.dispatch.backends.base import Backend | ||||||
|
|
||||||
|
|
||||||
class TxdaBackend(Backend):
    """
    Txda backend for operator implementations.

    Provides operator implementations for Tsingmicro Txda NPUs.

    NOTE(review): the original docstring said "Txda CANN libraries" —
    CANN is Ascend terminology; confirm the actual Txda runtime/library
    stack and document it here.
    """

    # Cached result of the hardware/library probe. Stored on the class so
    # the (potentially expensive) probe runs once and every instance
    # shares the answer. None means "not probed yet".
    _available: Optional[bool] = None

    @property
    def name(self) -> str:
        return "txda"

    @property
    def vendor(self) -> Optional[str]:
        return "txda"

    def is_available(self) -> bool:
        """Check if Txda hardware and libraries are available.

        The probe result is cached class-wide, so only the first call
        does the actual import/device check.
        """
        if TxdaBackend._available is None:
            try:
                # Check for torch_txda (the Txda PyTorch extension).
                import torch_txda  # noqa: F401

                # Check that at least one Txda device is present.
                # NOTE(review): assumes importing torch_txda registers the
                # torch.txda device module — confirm against the extension.
                if torch.txda.is_available() and torch.txda.device_count() > 0:
                    TxdaBackend._available = True
                else:
                    TxdaBackend._available = False
            except (ImportError, AttributeError):
                # Extension missing, or torch.txda not registered.
                TxdaBackend._available = False
        return TxdaBackend._available

    # ==================== Operator Implementations ====================

    # def silu_and_mul(self, obj, x: torch.Tensor) -> torch.Tensor:
    #     """
    #     SiLU activation followed by element-wise multiplication.
    #
    #     Args:
    #         obj: The calling obj (for interface consistency)
    #         x: Input tensor of shape [..., 2*d]
    #
    #     Returns:
    #         Output tensor of shape [..., d]
    #     """
    #     from .impl.activation import silu_and_mul_Txda
    #
    #     return silu_and_mul_Txda(obj, x)

    # def rms_norm(
    #     self,
    #     obj,
    #     x: torch.Tensor,
    #     residual: Optional[torch.Tensor] = None,
    # ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    #     """
    #     RMS normalization.
    #
    #     Args:
    #         obj: The calling obj (e.g., RMSNorm layer)
    #         x: Input tensor
    #         residual: Optional residual tensor
    #
    #     Returns:
    #         Normalized tensor, or tuple of (normalized, residual) if residual is provided
    #     """
    #     from .impl.normalization import rms_norm_Txda
    #
    #     return rms_norm_Txda(obj, x, residual)

    # def rotary_embedding(
    #     self,
    #     obj,
    #     query: torch.Tensor,
    #     key: torch.Tensor,
    #     cos: torch.Tensor,
    #     sin: torch.Tensor,
    #     position_ids: torch.Tensor,
    #     rotary_interleaved: bool = False,
    #     inplace: bool = True,
    # ) -> tuple[torch.Tensor, torch.Tensor]:
    #     """
    #     Apply rotary position embedding.
    #
    #     Args:
    #         obj: The calling obj (for interface consistency)
    #         query: Query tensor
    #         key: Key tensor
    #         cos: Cosine cache
    #         sin: Sine cache
    #         position_ids: Position indices
    #         rotary_interleaved: Whether to use interleaved rotary
    #         inplace: Whether to modify tensors in-place
    #
    #     Returns:
    #         Tuple of (embedded_query, embedded_key)
    #     """
    #     from .impl.rotary import rotary_embedding_Txda
    #
    #     return rotary_embedding_Txda(
    #         obj,
    #         query,
    #         key,
    #         cos,
    #         sin,
    #         position_ids,
    #         rotary_interleaved=rotary_interleaved,
    #         inplace=inplace,
    #     )

    def attention_backend(self, use_mla: bool = False) -> str:
        """
        Get the attention backend class path for Txda NPU.

        Currently this returns vllm_fl's flag_gems attention backend
        class paths; no Txda-specific attention implementation is wired
        up yet. (The original docstring claimed a native torch_npu-based
        backend was used "instead of flag_gems", which contradicted the
        returned paths — corrected to match the code.)

        Args:
            use_mla: Whether to use Multi-head Latent Attention (MLA)

        Returns:
            Fully qualified class path string
        """
        if use_mla:
            return "vllm_fl.dispatch.backends.flaggems.impl.mla.MLAFLBackend"
        return "vllm_fl.dispatch.backends.flaggems.impl.attention.AttentionFLBackend"
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -72,7 +72,8 @@ | |
| ### TODO(lms): simplify it | ||
| if library_path is None: | ||
| flagcx_path = os.getenv('FLAGCX_PATH') | ||
| library_path=os.path.join(flagcx_path, "build/lib/libflagcx.so") | ||
| #library_path=os.path.join(flagcx_path, "libflagcx.so") # rcy fix | ||
| library_path= "/usr/local/kuiper/lib/libflagcx.so" | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are using hard-coded path rather than using the environment variable? |
||
| self.flagcx = FLAGCXLibrary(library_path) | ||
| else: | ||
| self.flagcx = FLAGCXLibrary(library_path) | ||
|
|
@@ -113,7 +114,8 @@ | |
| # nccl communicator and stream will use this device | ||
| # `torch.cuda.device` is a context manager that changes the | ||
| # current cuda device to the specified one | ||
| with torch.cuda.device(device): | ||
| # | ||
| with torch.txda.device(device): | ||
| self.comm = self.flagcx.flagcxCommInitRank( | ||
| self.world_size, ctypes.byref(self.unique_id), self.rank) | ||
|
|
||
|
|
@@ -144,13 +146,23 @@ | |
| if stream is None: | ||
| stream = current_stream() | ||
| flagcx_stream = self.flagcx.adaptor_stream_copy(stream) | ||
| change_type = False | ||
| if in_tensor.dtype == torch.bfloat16: | ||
| in_tensor = in_tensor.to(torch.float32) | ||
| out_tensor = out_tensor.to(torch.float32) | ||
| change_type = True | ||
|
|
||
| self.flagcx.flagcxAllReduce(buffer_type(in_tensor.data_ptr()), | ||
| buffer_type(out_tensor.data_ptr()), | ||
| in_tensor.numel(), | ||
| flagcxDataTypeEnum.from_torch(in_tensor.dtype), | ||
| flagcxRedOpTypeEnum.from_torch(op), self.comm, | ||
| flagcx_stream) | ||
| self.flagcx.adaptor_stream_free(flagcx_stream) | ||
| if change_type: | ||
| in_tensor = in_tensor.to(torch.bfloat16) | ||
Check noticeCode scanning / CodeQL Unused local variable Note
Variable in_tensor is not used.
|
||
| out_tensor = out_tensor.to(torch.bfloat16) | ||
|
|
||
| return out_tensor | ||
|
|
||
| def all_gather(self, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TxdaBackend typo