diff --git a/transformer_engine/plugin/core/backends/vendor/tsingmicro/__init__.py b/transformer_engine/plugin/core/backends/vendor/tsingmicro/__init__.py
new file mode 100644
index 0000000000..52c998c87f
--- /dev/null
+++ b/transformer_engine/plugin/core/backends/vendor/tsingmicro/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2025, BAAI. All rights reserved.
+#
+# See LICENSE for license information.
+
+from .tsingmicro import TXDABackend
+
+__all__ = ["TXDABackend"]
diff --git a/transformer_engine/plugin/core/backends/vendor/tsingmicro/register_ops.py b/transformer_engine/plugin/core/backends/vendor/tsingmicro/register_ops.py
new file mode 100644
index 0000000000..a9ba99bc4c
--- /dev/null
+++ b/transformer_engine/plugin/core/backends/vendor/tsingmicro/register_ops.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2025, BAAI. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""
+TXDA backend operator registrations.
+
+This module registers all TXDA PyTorch implementations.
+"""
+
+from __future__ import annotations
+
+import functools
+
+from transformer_engine.plugin.core.types import OpImpl, BackendImplKind
+
+
+def _bind_is_available(fn, is_available_fn):
+    """Wrap a function and bind _is_available attribute for OpImpl.is_available() check."""
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        return fn(*args, **kwargs)
+
+    wrapper._is_available = is_available_fn
+    return wrapper
+
+
+def register_builtins(registry) -> None:
+    """
+    Register all TXDA PyTorch operator implementations.
+
+    Args:
+        registry: Registry to register into
+    """
+    from .tsingmicro import TXDABackend
+
+    # Create a backend instance to access the methods
+    backend = TXDABackend()
+
+    if not backend.is_available():
+        return
+
+    # Bind is_available to all methods
+    is_avail = backend.is_available
+
+    impls = [
+        # FlashAttention class getter
+        OpImpl(
+            op_name="get_flash_attention_class",
+            impl_id="vendor.txda",
+            kind=BackendImplKind.VENDOR,
+            fn=_bind_is_available(backend.get_flash_attention_class, is_avail),
+            vendor="txda",
+            priority=100,
+        ),
+    ]
+
+    registry.register_many(impls)
diff --git a/transformer_engine/plugin/core/backends/vendor/tsingmicro/tsingmicro.py b/transformer_engine/plugin/core/backends/vendor/tsingmicro/tsingmicro.py
new file mode 100644
index 0000000000..041247939b
--- /dev/null
+++ b/transformer_engine/plugin/core/backends/vendor/tsingmicro/tsingmicro.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2025, BAAI. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""TXDA (Tsingmicro) backend implementation."""
+
+from ....ops import *
+
+# Cached outcome of the torch_txda import probe (None = not yet probed).
+_txda_available = None
+
+
+def _check_txda_available() -> bool:
+    """Return True if the ``torch_txda`` extension can be imported.
+
+    The import is attempted at most once per process; the outcome is
+    cached in the module-level ``_txda_available`` flag.
+    """
+    global _txda_available
+    if _txda_available is None:
+        try:
+            import torch_txda  # noqa: F401
+
+            _txda_available = True
+        except Exception:
+            _txda_available = False
+    return _txda_available
+
+
+class TXDABackend(TEFLBackendBase):
+    """Transformer Engine backend for Tsingmicro TXDA devices."""
+
+    @staticmethod
+    def check_available() -> bool:
+        """Static probe used before instantiating the backend."""
+        return _check_txda_available()
+
+    def is_available(self) -> bool:
+        """Instance probe; mirrors ``check_available``."""
+        return _check_txda_available()
+
+    def get_flash_attention_class(self):
+        """Return the FlashAttention class (not yet implemented for TXDA)."""
+        raise NotImplementedError(
+            "get_flash_attention_class - not implemented in txda backend"
+        )