Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) 2025, BAAI. All rights reserved.
#
# See LICENSE for license information.

from .tsingmicro import TXDABackend

__all__ = ["TXDABackend"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) 2025, BAAI. All rights reserved.
#
# See LICENSE for license information.

"""
TXDA backend operator registrations.

This module registers all TXDA PyTorch implementations.
"""

from __future__ import annotations

import functools

from transformer_engine.plugin.core.types import OpImpl, BackendImplKind


def _bind_is_available(fn, is_available_fn):
"""Wrap a function and bind _is_available attribute for OpImpl.is_available() check."""

@functools.wraps(fn)
def wrapper(*args, **kwargs):
return fn(*args, **kwargs)

wrapper._is_available = is_available_fn
return wrapper


def register_builtins(registry) -> None:
    """Register every TXDA PyTorch operator implementation with *registry*.

    Args:
        registry: Registry that receives the ``OpImpl`` entries.
    """
    from .tsingmicro import TXDABackend

    # Instantiate once so the bound methods below share a single backend.
    backend = TXDABackend()

    # Nothing to register when the TXDA runtime is not importable.
    if not backend.is_available():
        return

    availability_probe = backend.is_available

    registry.register_many(
        [
            # FlashAttention class getter
            OpImpl(
                op_name="get_flash_attention_class",
                impl_id="vendor.txda",
                kind=BackendImplKind.VENDOR,
                fn=_bind_is_available(
                    backend.get_flash_attention_class, availability_probe
                ),
                vendor="txda",
                priority=100,
            ),
        ]
    )
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) 2025, BAAI. All rights reserved.
#
# See LICENSE for license information.

import os
import subprocess
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from ....ops import *


def _ensure_txda_available():
global _txda_available
try:
import torch_txda

return True
except Exception as e:
return False


def _check_txda_available() -> bool:
    """Report whether the TXDA runtime (``torch_txda``) is importable."""
    # The original `if ...: return True / else: return False` was redundant:
    # the probe already returns a bool, so delegate directly.
    return _ensure_txda_available()


class TXDABackend(TEFLBackendBase):
    """TEFL backend for the Tsingmicro TXDA runtime.

    Availability of the backend is determined entirely by whether the
    ``torch_txda`` extension module can be imported.
    """

    @staticmethod
    def check_available() -> bool:
        """Class-level availability probe (no instance required)."""
        return _check_txda_available()

    def is_available(self) -> bool:
        """Instance-level availability probe; same check as ``check_available``."""
        return _check_txda_available()

    def get_flash_attention_class(self):
        """FlashAttention support is not implemented for TXDA; always raises."""
        raise NotImplementedError("get_flash_attention_class - not implemented in txda backend")
Loading