From 506b94777bd62f0fba4119287c1c67f1e7255ff4 Mon Sep 17 00:00:00 2001 From: zhangpeiyang <583932508@qq.com> Date: Fri, 3 Apr 2026 15:20:22 +0800 Subject: [PATCH 1/3] fix_muti_backend_bugs --- src/flag_gems/runtime/backend/__init__.py | 229 ++++++++++---------- src/flag_gems/runtime/backend/device.py | 144 ++++-------- src/flag_gems/runtime/common.py | 35 +++ src/flag_gems/runtime/configloader.py | 92 ++++---- src/flag_gems/utils/codegen_config_utils.py | 6 +- src/flag_gems/utils/libentry.py | 8 +- 6 files changed, 235 insertions(+), 279 deletions(-) mode change 100644 => 100755 src/flag_gems/runtime/backend/device.py mode change 100644 => 100755 src/flag_gems/runtime/common.py mode change 100644 => 100755 src/flag_gems/runtime/configloader.py mode change 100644 => 100755 src/flag_gems/utils/codegen_config_utils.py mode change 100644 => 100755 src/flag_gems/utils/libentry.py diff --git a/src/flag_gems/runtime/backend/__init__.py b/src/flag_gems/runtime/backend/__init__.py index 0fd80207d1..0ce06ff7f6 100644 --- a/src/flag_gems/runtime/backend/__init__.py +++ b/src/flag_gems/runtime/backend/__init__.py @@ -9,17 +9,37 @@ from ..common import vendors from . import backend_utils -vendor_module = None -device_name = None -torch_device_object = None -torch_device_fn_device = None -tl_extra_backend_module = None -ops_module = None -fused_module = None -heuristic_config_module = None -vendor_extra_lib_imported = False -device_fn_cache = {} -customized_ops = None + +class BackendState: + """Singleton class to manage backend state variables.""" + + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + if self._initialized: + return + self._initialized = True + self.vendor_module = None + self.device_name = None + self.torch_device_object = None + self.torch_device_fn_device = None + self.tl_extra_backend_module = None + self.ops_module = None + self.fused_module = None + self.heuristic_config_module = None + self.vendor_extra_lib_imported = False + self.device_fn_cache = {} + self.customized_ops = None + + +# Global singleton instance +_state = BackendState() class BackendArchEvent: @@ -51,32 +71,29 @@ def get_functions_from_module(self, module): return inspect.getmembers(module, inspect.isfunction) if module else [] def get_heuristics_configs(self): - heuristic_module = None try: heuristic_module = self.arch_module - except Exception: # noqa E722 + except Exception: sys.path.insert(0, str(self.current_arch_path)) heuristic_module = importlib.import_module("heuristics_config_utils") sys.path.remove(str(self.current_arch_path)) - if hasattr(heuristic_module, "HEURISTICS_CONFIGS"): - return heuristic_module.HEURISTICS_CONFIGS - return None + return getattr(heuristic_module, "HEURISTICS_CONFIGS", None) def get_autotune_configs(self): path = self.current_arch_path return backend_utils.get_tune_config(file_path=path) def get_arch(self, device=0): - if not hasattr(vendor_module, "ARCH_MAP"): + if not hasattr(_state.vendor_module, "ARCH_MAP"): return - arch_map = vendor_module.ARCH_MAP + arch_map = _state.vendor_module.ARCH_MAP arch_string = os.environ.get("ARCH", "") arch_string_num = arch_string.split("_")[-1][0] if arch_string else arch_string if not arch_string_num: try: - if not torch_device_object.is_available(): + if not _state.torch_device_object.is_available(): return False - props = torch_device_object.get_device_properties(device) + props = _state.torch_device_object.get_device_properties(device) arch_string_num = str(props.major) except Exception: self.has_arch = False @@ -89,16 +106,14 @@ def get_arch(self, device=0): return arch_map[arch_string_num] def _get_supported_archs(self, path=None): - path = path or vendor_module.__path__[0] - excluded = ("ops", "fused") - path = Path(path) + path = Path(path or _state.vendor_module.__path__[0]) path = path.parent if path.is_file() else path - archs = {} - for p in path.iterdir(): - name = str(p).split("/")[-1] - if p.is_dir() and name not in excluded and not name.startswith("_"): - archs.update({name: str(p)}) - return archs + excluded = ("ops", "fused") + return { + p.name: str(p) + for p in path.iterdir() + if p.is_dir() and p.name not in excluded and not p.name.startswith("_") + } def get_supported_archs(self): return list(self.supported_archs.keys()) @@ -113,51 +128,37 @@ def get_arch_module(self): def get_arch_ops(self): arch_specialized_ops = [] - modules = [] sys.path.append(self.current_arch_path) - ops_module = importlib.import_module(f"{self.arch}.ops") try: - ops_module = self.arch_module.ops - modules.append(ops_module) - except Exception: - try: - sys.path.append(self.current_arch_path) + ops_module = getattr(self.arch_module, "ops", None) + if ops_module is None: ops_module = importlib.import_module(f"{self.arch}.ops") - modules.append(ops_module) - except Exception as err_msg: - self.error_msgs.append(err_msg) - - for mod in modules: - arch_specialized_ops.extend(self.get_functions_from_module(mod)) - + arch_specialized_ops.extend(self.get_functions_from_module(ops_module)) + except Exception as err_msg: + self.error_msgs.append(err_msg) return arch_specialized_ops -def import_vendor_extra_lib(vendor_name=None): - global vendor_extra_lib_imported - if vendor_extra_lib_imported is True: - return - global ops_module, fused_module +def _import_module_safe(module_name, vendor_name, module_type): + """Helper to import a module with proper error handling.""" try: - ops_module = importlib.import_module(f"_{vendor_name}.ops") + return importlib.import_module(module_name) except ModuleNotFoundError: print( - f"[Note] No specialized common operators were found in" - f"the {vendor_name} implementation, and general common operators are used by default." + f"[Note] No specialized {module_type} operators were found in " + f"the {vendor_name} implementation, and general {module_type} operators are used by default." ) + return None except Exception as e: raise RuntimeError(f"Import vendor extra lib failed: {e}") - try: - fused_module = importlib.import_module(f"_{vendor_name}.fused") - except ModuleNotFoundError: - print( - f"[Note] No specialized fused operators were found in" - f"the {vendor_name} implementation, and general fused operators are used by default." - ) - except Exception as e: - raise RuntimeError(f"Import vendor extra lib failed: {e}") - vendor_extra_lib_imported = True + +def import_vendor_extra_lib(vendor_name=None): + if _state.vendor_extra_lib_imported: + return + _state.ops_module = _import_module_safe(f"_{vendor_name}.ops", vendor_name, "common") + _state.fused_module = _import_module_safe(f"_{vendor_name}.fused", vendor_name, "fused") + _state.vendor_extra_lib_imported = True def get_codegen_result(code, result_key): @@ -172,8 +173,7 @@ def get_codegen_result(code, result_key): @functools.lru_cache(maxsize=32) def gen_torch_tensor_attr_res(tensor, attr_name): - global device_name - device_name = device_name or get_vendor_info().device_name + _state.device_name = _state.device_name or get_vendor_info().device_name code = f""" import torch res = {tensor}.{attr_name} @@ -182,43 +182,40 @@ def gen_torch_tensor_attr_res(tensor, attr_name): def set_tl_extra_backend_module(vendor_name=None): - global device_name, tl_extra_backend_module vendor_info = get_vendor_info(vendor_name) - device_name = device_name or vendor_info.device_name - extra_name = vendor_info.triton_extra_name or device_name + _state.device_name = _state.device_name or vendor_info.device_name + extra_name = vendor_info.triton_extra_name or _state.device_name module_str = f"triton.language.extra.{extra_name}.libdevice" - tl_extra_backend_module = importlib.import_module(module_str) + _state.tl_extra_backend_module = importlib.import_module(module_str) def get_tl_extra_backend_module(): - return tl_extra_backend_module + return _state.tl_extra_backend_module def set_torch_backend_device_fn(vendor_name=None): - global device_name, torch_device_fn_device - device_name = device_name or get_vendor_info(vendor_name).device_name - module_str = f"torch.backends.{device_name}" - if device_name in ("musa", "aipu", "npu", "txda", "ptpu", "gcu"): - torch_device_fn_device = None + _state.device_name = _state.device_name or get_vendor_info(vendor_name).device_name + module_str = f"torch.backends.{_state.device_name}" + if _state.device_name in ("musa", "aipu", "npu", "txda", "ptpu", "gcu"): + _state.torch_device_fn_device = None else: - torch_device_fn_device = importlib.import_module(module_str) + _state.torch_device_fn_device = importlib.import_module(module_str) def get_torch_backend_device_fn(): - return torch_device_fn_device + return _state.torch_device_fn_device def gen_torch_device_object(vendor_name=None): - global device_name, torch_device_object - if torch_device_object is not None: - return torch_device_object - device_name = device_name or get_vendor_info(vendor_name).device_name + if _state.torch_device_object is not None: + return _state.torch_device_object + _state.device_name = _state.device_name or get_vendor_info(vendor_name).device_name code = f""" import torch -fn = torch.{device_name} +fn = torch.{_state.device_name} """ - torch_device_object = get_codegen_result(code, "fn") - return torch_device_object + _state.torch_device_object = get_codegen_result(code, "fn") + return _state.torch_device_object def get_vendor_module(vendor_name, query=False): @@ -233,73 +230,67 @@ def get_module(vendor_name): ): # The purpose of a query is to provide the user with the instance that he wants to import return get_module(vendor_name) - global vendor_module - if vendor_module is None: - vendor_module = get_module("_" + vendor_name) - return vendor_module + if _state.vendor_module is None: + _state.vendor_module = get_module("_" + vendor_name) + return _state.vendor_module def get_vendor_info(vendor_name=None, query=False): if query: return get_vendor_module(vendor_name, query).vendor_info - global vendor_module # noqa: F824 get_vendor_module(vendor_name) - return vendor_module.vendor_info + return _state.vendor_module.vendor_info def get_vendor_infos(): infos = [] for vendor_name in vendors.get_all_vendors(): - vendor_name = "_" + vendor_name try: - single_info = get_vendor_info(vendor_name, query=True) - infos.append(single_info) + infos.append(get_vendor_info(f"_{vendor_name}", query=True)) except Exception: - pass - + continue return infos def get_current_device_extend_op(vendor_name=None): import_vendor_extra_lib(vendor_name) - global customized_ops - if customized_ops is not None: - return customized_ops - customized_ops = [] - if ops_module is not None: - ops = inspect.getmembers(ops_module, inspect.isfunction) - customized_ops += ops - if fused_module is not None: - fused_ops = inspect.getmembers(fused_module, inspect.isfunction) - customized_ops += fused_ops - return customized_ops + if _state.customized_ops is not None: + return _state.customized_ops + _state.customized_ops = [] + if _state.ops_module is not None: + ops = inspect.getmembers(_state.ops_module, inspect.isfunction) + _state.customized_ops += ops + if _state.fused_module is not None: + fused_ops = inspect.getmembers(_state.fused_module, inspect.isfunction) + _state.customized_ops += fused_ops + return _state.customized_ops def get_curent_device_unused_op(vendor_name=None): - global vendor_module # noqa: F824 get_vendor_module(vendor_name) - return list(vendor_module.CUSTOMIZED_UNUSED_OPS) + return list(_state.vendor_module.CUSTOMIZED_UNUSED_OPS) def get_heuristic_config(vendor_name=None): - global heuristic_config_module - try: - heuristic_config_module = importlib.import_module( - f"_{vendor_name}.heuristics_config_utils" - ) - except: # noqa E722 - heuristic_config_module = importlib.import_module( - "_nvidia.heuristics_config_utils" - ) - if hasattr(heuristic_config_module, "HEURISTICS_CONFIGS"): - return heuristic_config_module.HEURISTICS_CONFIGS - return None + config_name = "heuristics_config_utils" + default_backend = "nvidia" + for backend in (vendor_name, default_backend): + mod_name = f"_{backend}.{config_name}" + try: + _state.heuristic_config_module = importlib.import_module(mod_name) + except Exception: + continue + return getattr(_state.heuristic_config_module, "HEURISTICS_CONFIGS", None) def get_tune_config(vendor_name=None): - global vendor_module # noqa: F824 get_vendor_module(vendor_name) return backend_utils.get_tune_config(vendor_name) +def get_backend_state() -> BackendState: + """Get the global BackendState singleton instance.""" + return _state + + __all__ = ["*"] diff --git a/src/flag_gems/runtime/backend/device.py b/src/flag_gems/runtime/backend/device.py old mode 100644 new mode 100755 index e186554e94..730cdd90d8 --- a/src/flag_gems/runtime/backend/device.py +++ b/src/flag_gems/runtime/backend/device.py @@ -1,148 +1,92 @@ import os import shlex import subprocess -import threading -from queue import Queue +from concurrent.futures import ThreadPoolExecutor, as_completed import torch # noqa: F401 from .. import backend, error -from ..common import vendors +from ..common import vendors, UNSUPPORT_FP64, UNSUPPORT_BF16, UNSUPPORT_INT64, _VENDOR_TORCH_ATTR -UNSUPPORT_FP64 = [ - vendors.CAMBRICON, - vendors.ILUVATAR, - vendors.KUNLUNXIN, - vendors.MTHREADS, - vendors.AIPU, - vendors.ASCEND, - vendors.TSINGMICRO, - vendors.SUNRISE, - vendors.ENFLAME, -] -UNSUPPORT_BF16 = [ - vendors.AIPU, - vendors.SUNRISE, -] -UNSUPPORT_INT64 = [ - vendors.AIPU, - vendors.TSINGMICRO, - vendors.SUNRISE, - vendors.ENFLAME, -] +class DeviceDetector: + """Singleton class to manage device context.""" - -# A singleton class to manage device context. -class DeviceDetector(object): _instance = None def __new__(cls, *args, **kargs): if cls._instance is None: - cls._instance = super(DeviceDetector, cls).__new__(cls) + cls._instance = super().__new__(cls) return cls._instance def __init__(self, vendor_name=None): - if not hasattr(self, "initialized"): - self.initialized = True - # A list of all available vendor names. - self.vendor_list = vendors.get_all_vendors().keys() - - # A dataclass instance, get the vendor information based on the provided or default vendor name. - self.info = self.get_vendor(vendor_name) - - # vendor_name is like 'nvidia', device_name is like 'cuda'. - self.vendor_name = self.info.vendor_name - self.name = self.info.device_name - self.vendor = vendors.get_all_vendors()[self.vendor_name] - self.dispatch_key = ( - self.name.upper() - if self.info.dispatch_key is None - else self.info.dispatch_key - ) - self.device_count = backend.gen_torch_device_object( - self.vendor_name - ).device_count() - self.support_fp64 = self.vendor not in UNSUPPORT_FP64 - self.support_bf16 = self.vendor not in UNSUPPORT_BF16 - self.support_int64 = self.vendor not in UNSUPPORT_INT64 - - def get_vendor(self, vendor_name=None) -> tuple: - # Try to get the vendor name from a quick special command like 'torch.mlu'. + if hasattr(self, "initialized"): + return + self.initialized = True + self.vendor_list = vendors.get_all_vendors().keys() + self.info = self.get_vendor(vendor_name) + self.vendor_name = self.info.vendor_name + self.name = self.info.device_name + self.vendor = vendors.get_all_vendors()[self.vendor_name] + self.dispatch_key = self.info.dispatch_key or self.name.upper() + self.device_count = backend.gen_torch_device_object(self.vendor_name).device_count() + self.support_fp64 = self.vendor not in UNSUPPORT_FP64 + self.support_bf16 = self.vendor not in UNSUPPORT_BF16 + self.support_int64 = self.vendor not in UNSUPPORT_INT64 + + def get_vendor(self, vendor_name=None): + # Try environment variable first vendor_from_env = self._get_vendor_from_env() - if vendor_from_env is not None: + if vendor_from_env: return backend.get_vendor_info(vendor_from_env) - + # Try quick torch attribute detection vendor_name = self._get_vendor_from_quick_cmd() - if vendor_name is not None: + if vendor_name: return backend.get_vendor_info(vendor_name) + # Fall back to system command detection try: - # Obtaining a vendor_info from the methods provided by torch or triton, but is not currently implemented. return self._get_vendor_from_lib() except Exception: return self._get_vendor_from_sys() def _get_vendor_from_quick_cmd(self): - cmd = { - "cambricon": "mlu", - "mthreads": "musa", - "iluvatar": "corex", - "ascend": "npu", - "sunrise": "ptpu", - "enflame": "gcu", - } - for vendor_name, flag in cmd.items(): - if hasattr(torch, flag): + for vendor_name, attr in _VENDOR_TORCH_ATTR.items(): + if hasattr(torch, attr): return vendor_name try: import torch_npu - - for vendor_name, flag in cmd.items(): - if hasattr(torch_npu, flag): + for vendor_name, attr in _VENDOR_TORCH_ATTR.items(): + if hasattr(torch_npu, attr): return vendor_name - except: # noqa: E722 + except ImportError: pass return None def _get_vendor_from_env(self): - device_from_evn = os.environ.get("GEMS_VENDOR") - return None if device_from_evn not in self.vendor_list else device_from_evn + vendor = os.environ.get("GEMS_VENDOR") + return vendor if vendor in self.vendor_list else None def _get_vendor_from_sys(self): vendor_infos = backend.get_vendor_infos() - result_single_info = Queue() - def runcmd(single_info): - device_query_cmd = single_info.device_query_cmd + def check_vendor(info): try: - cmd_args = shlex.split(device_query_cmd) + cmd_args = shlex.split(info.device_query_cmd) result = subprocess.run(cmd_args, capture_output=True, text=True) - if result.returncode == 0: - result_single_info.put(single_info) - except: # noqa: E722 - pass + return info if result.returncode == 0 else None + except Exception: + return None - threads = [] - for single_info in vendor_infos: - # Get the vendor information by running system commands. - thread = threading.Thread(target=runcmd, args=(single_info,)) - threads.append(thread) - thread.start() + with ThreadPoolExecutor() as executor: + futures = {executor.submit(check_vendor, info): info for info in vendor_infos} + for future in as_completed(futures): + result = future.result() + if result: + return result - for thread in threads: - thread.join() - if result_single_info.empty(): - error.device_not_found() - else: - return result_single_info.get() + error.device_not_found() def get_vendor_name(self): return self.vendor_name def _get_vendor_from_lib(self): - # Reserve the associated interface for triton or torch - # although they are not implemented yet. - # try: - # return triton.get_vendor_info() - # except Exception: - # return torch.get_vendor_info() raise RuntimeError("The method is not implemented") diff --git a/src/flag_gems/runtime/common.py b/src/flag_gems/runtime/common.py old mode 100644 new mode 100755 index 384b4fb19b..e57f715c39 --- a/src/flag_gems/runtime/common.py +++ b/src/flag_gems/runtime/common.py @@ -22,3 +22,38 @@ def get_all_vendors(cls) -> dict: for member in cls: vendorDict[member.name.lower()] = member return vendorDict + +UNSUPPORT_FP64 = frozenset({ + vendors.CAMBRICON, + vendors.ILUVATAR, + vendors.KUNLUNXIN, + vendors.MTHREADS, + vendors.AIPU, + vendors.ASCEND, + vendors.TSINGMICRO, + vendors.SUNRISE, + vendors.ENFLAME, +}) + +UNSUPPORT_BF16 = frozenset({ + vendors.AIPU, + vendors.SUNRISE, +}) + +UNSUPPORT_INT64 = frozenset({ + vendors.AIPU, + vendors.TSINGMICRO, + vendors.SUNRISE, + vendors.ENFLAME, +}) + + +# Mapping from vendor name to torch attribute for quick detection +_VENDOR_TORCH_ATTR = { + "cambricon": "mlu", + "mthreads": "musa", + "iluvatar": "corex", + "ascend": "npu", + "sunrise": "ptpu", + "enflame": "gcu", +} \ No newline at end of file diff --git a/src/flag_gems/runtime/configloader.py b/src/flag_gems/runtime/configloader.py old mode 100644 new mode 100755 index 338ddad9fc..e69d262142 --- a/src/flag_gems/runtime/configloader.py +++ b/src/flag_gems/runtime/configloader.py @@ -27,16 +27,7 @@ def __init__(self): self.default_primitive_yaml_config = self.get_default_tune_config() self.vendor_heuristics_config = self.get_vendor_heuristics_config() self.default_heuristics_config = self.get_default_heuristics_config() - try: - if backend.BackendArchEvent().has_arch: - self.arch_specialized_yaml_config = ( - backend.BackendArchEvent().autotune_configs - ) - self.arch_heuristics_config = ( - backend.BackendArchEvent().heuristics_configs - ) - except Exception as err: - print(f"[INFO] : {err}") + self.update_config_from_arch() if self.vendor_heuristics_config is None: vendorname = self.device.vendor_name @@ -47,20 +38,20 @@ def __init__(self): self.gen_key = "gen" # loaded_triton_config is wrapped in triton.Config according to primitive_yaml_config self.loaded_triton_config = {} - self.triton_config_default = { - "num_stages": 2, - "num_warps": 4, - "num_ctas": 1, - } - if self.device.vendor_name in ["hygon"]: - self.triton_config_default = { - "num_stages": 2, - "num_warps": 4, - "num_ctas": 1, - "num_ldmatrixes": 0, - } + self.triton_config_default = {"num_stages": 2, "num_warps": 4, "num_ctas": 1} + if self.device.vendor_name == "hygon": + self.triton_config_default["num_ldmatrixes"] = 0 self.load_all() + def update_config_from_arch(self): + try: + archEvent = backend.BackendArchEvent() + if archEvent.has_arch: + self.arch_specialized_yaml_config = (archEvent.autotune_configs) + self.arch_heuristics_config = (archEvent.heuristics_configs) + except Exception as err: + print(f"[INFO] : {err}") + def load_all(self): for key in self.vendor_primitive_yaml_config: self.loaded_triton_config[key] = self.get_tuned_config(key) @@ -172,24 +163,37 @@ def to_gen_config(self, gen_config): current_config, ) + def _get_op_configs(self, op_name): + """Get config for op_name from available config sources.""" + for config in ( + self.arch_specialized_yaml_config, + self.vendor_primitive_yaml_config, + self.default_primitive_yaml_config, + ): + if config and op_name in config: + return config[op_name] + return [] + + def _create_triton_config(self, single_config, current_config): + """Create a triton.Config with appropriate parameters.""" + kwargs = { + "num_warps": current_config["num_warps"], + "num_stages": current_config["num_stages"], + "num_ctas": current_config["num_ctas"], + } + if self.device.vendor_name == "hygon": + kwargs["num_ldmatrixes"] = current_config["num_ldmatrixes"] + return triton.Config(single_config["META"], **kwargs) + def get_tuned_config(self, op_name): if op_name in self.loaded_triton_config: return self.loaded_triton_config[op_name] - if ( - self.arch_specialized_yaml_config - and op_name in self.arch_specialized_yaml_config - ): - current_op_configs = self.arch_specialized_yaml_config[op_name] - elif op_name in self.vendor_primitive_yaml_config: - current_op_configs = self.vendor_primitive_yaml_config[op_name] - else: - current_op_configs = self.default_primitive_yaml_config[op_name] + current_op_configs = self._get_op_configs(op_name) + if not current_op_configs: + return [] configs = [] - if len(current_op_configs) == 0: - return configs - for single_config in current_op_configs: if self.gen_key in single_config: configs.extend(self.to_gen_config(single_config)) @@ -200,23 +204,5 @@ def get_tuned_config(self, op_name): if default_param in single_config: current_config[default_param] = single_config[default_param] - if self.device.vendor_name in ["hygon"]: - configs.append( - triton.Config( - single_config["META"], - num_warps=current_config["num_warps"], - num_stages=current_config["num_stages"], - num_ctas=current_config["num_ctas"], - num_ldmatrixes=current_config["num_ldmatrixes"], - ) - ) - else: - configs.append( - triton.Config( - single_config["META"], - num_warps=current_config["num_warps"], - num_stages=current_config["num_stages"], - num_ctas=current_config["num_ctas"], - ) - ) + configs.append(self._create_triton_config(single_config, current_config)) return configs diff --git a/src/flag_gems/utils/codegen_config_utils.py b/src/flag_gems/utils/codegen_config_utils.py old mode 100644 new mode 100755 index 2ac5858813..f9345486d2 --- a/src/flag_gems/utils/codegen_config_utils.py +++ b/src/flag_gems/utils/codegen_config_utils.py @@ -4,7 +4,7 @@ import triton from flag_gems.runtime import device -from flag_gems.runtime.backend import vendor_module +from flag_gems.runtime.backend import _state from flag_gems.runtime.common import vendors @@ -85,12 +85,12 @@ def __post_init__(self): vendors.CAMBRICON: ( CodeGenConfig( 8192, - tuple([vendor_module.TOTAL_CORE_NUM, 1, 1]), + tuple([_state.vendor_module.TOTAL_CORE_NUM, 1, 1]), 32, True, prefer_1d_tile=int(triton.__version__[0]) < 3, ) - if vendor_module.vendor_info.vendor_name == "cambricon" + if _state.vendor_module.vendor_info.vendor_name == "cambricon" else None ), vendors.METAX: CodeGenConfig( diff --git a/src/flag_gems/utils/libentry.py b/src/flag_gems/utils/libentry.py old mode 100644 new mode 100755 index 8c6d8360a1..efd02a50c7 --- a/src/flag_gems/utils/libentry.py +++ b/src/flag_gems/utils/libentry.py @@ -30,7 +30,7 @@ from flag_gems import runtime from flag_gems.runtime import torch_device_fn -from flag_gems.runtime.backend import vendor_module +from flag_gems.runtime.backend import _state from flag_gems.utils.code_cache import config_cache_dir from flag_gems.utils.models import PersistantModel, SQLPersistantModel @@ -161,11 +161,11 @@ def __init__(self, db_url: Optional[str] = None): try: device_name: str = torch_device_fn.get_device_name().replace(" ", "_") except AttributeError: - device_name: str = vendor_module.vendor_info.device_name + device_name: str = _state.vendor_module.vendor_info.device_name cache_file_name: str = ( f"TunedConfig_{device_name}_triton_{major_version}_{minor_version}.db" - if vendor_module.vendor_info.vendor_name == "nvidia" - else f"TunedConfig_{vendor_module.vendor_info.vendor_name}_triton_{major_version}_{minor_version}.db" + if _state.vendor_module.vendor_info.vendor_name == "nvidia" + else f"TunedConfig_{_state.vendor_module.vendor_info.vendor_name}_triton_{major_version}_{minor_version}.db" ) cache_path: Path = config_cache_dir() / cache_file_name self.db_url: str = f"sqlite:///{cache_path}" From 225403477b50ec356977412a1f8e7a83f4c53e30 Mon Sep 17 00:00:00 2001 From: Galaxy1458 <55453380+Galaxy1458@users.noreply.github.com> Date: Fri, 3 Apr 2026 16:06:41 +0800 Subject: [PATCH 2/3] Improve error handling and messages in init.py Refactor error messages for clarity and consistency. Signed-off-by: Galaxy1458 <55453380+Galaxy1458@users.noreply.github.com> --- src/flag_gems/runtime/backend/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/flag_gems/runtime/backend/__init__.py b/src/flag_gems/runtime/backend/__init__.py index 0ce06ff7f6..8d1f927c20 100644 --- a/src/flag_gems/runtime/backend/__init__.py +++ b/src/flag_gems/runtime/backend/__init__.py @@ -129,8 +129,8 @@ def get_arch_module(self): def get_arch_ops(self): arch_specialized_ops = [] sys.path.append(self.current_arch_path) + ops_module = getattr(self.arch_module, "ops", None) try: - ops_module = getattr(self.arch_module, "ops", None) if ops_module is None: ops_module = importlib.import_module(f"{self.arch}.ops") arch_specialized_ops.extend(self.get_functions_from_module(ops_module)) @@ -145,12 +145,12 @@ def _import_module_safe(module_name, vendor_name, module_type): return importlib.import_module(module_name) except ModuleNotFoundError: print( - f"[Note] No specialized {module_type} operators were found in " - f"the {vendor_name} implementation, and general {module_type} operators are used by default." + f"[Note] No specialized {module_type} operators were found for " + f"the {vendor_name}, generic {module_type} operators will be used by default." ) return None except Exception as e: - raise RuntimeError(f"Import vendor extra lib failed: {e}") + raise RuntimeError(f"Failed to import vendor extra lib: {e}") def import_vendor_extra_lib(vendor_name=None): From 6240b44a329ab82c550a61987653713150901af1 Mon Sep 17 00:00:00 2001 From: yrl <2535184404@qq.com> Date: Fri, 3 Apr 2026 16:18:26 +0800 Subject: [PATCH 3/3] fix codestyle --- src/flag_gems/runtime/backend/__init__.py | 8 +++- src/flag_gems/runtime/backend/device.py | 18 ++++++-- src/flag_gems/runtime/common.py | 55 +++++++++++++---------- src/flag_gems/runtime/configloader.py | 10 +++-- src/flag_gems/utils/libentry.py | 8 ++-- 5 files changed, 64 insertions(+), 35 deletions(-) diff --git a/src/flag_gems/runtime/backend/__init__.py b/src/flag_gems/runtime/backend/__init__.py index 8d1f927c20..2da018300c 100644 --- a/src/flag_gems/runtime/backend/__init__.py +++ b/src/flag_gems/runtime/backend/__init__.py @@ -156,8 +156,12 @@ def _import_module_safe(module_name, vendor_name, module_type): def import_vendor_extra_lib(vendor_name=None): if _state.vendor_extra_lib_imported: return - _state.ops_module = _import_module_safe(f"_{vendor_name}.ops", vendor_name, "common") - _state.fused_module = _import_module_safe(f"_{vendor_name}.fused", vendor_name, "fused") + _state.ops_module = _import_module_safe( + f"_{vendor_name}.ops", vendor_name, "common" + ) + _state.fused_module = _import_module_safe( + f"_{vendor_name}.fused", vendor_name, "fused" + ) _state.vendor_extra_lib_imported = True diff --git a/src/flag_gems/runtime/backend/device.py b/src/flag_gems/runtime/backend/device.py index 730cdd90d8..ff3f0a1601 100755 --- a/src/flag_gems/runtime/backend/device.py +++ b/src/flag_gems/runtime/backend/device.py @@ -6,7 +6,14 @@ import torch # noqa: F401 from .. import backend, error -from ..common import vendors, UNSUPPORT_FP64, UNSUPPORT_BF16, UNSUPPORT_INT64, _VENDOR_TORCH_ATTR +from ..common import ( + _VENDOR_TORCH_ATTR, + UNSUPPORT_BF16, + UNSUPPORT_FP64, + UNSUPPORT_INT64, + vendors, +) + class DeviceDetector: """Singleton class to manage device context.""" @@ -28,7 +35,9 @@ def __init__(self, vendor_name=None): self.name = self.info.device_name self.vendor = vendors.get_all_vendors()[self.vendor_name] self.dispatch_key = self.info.dispatch_key or self.name.upper() - self.device_count = backend.gen_torch_device_object(self.vendor_name).device_count() + self.device_count = backend.gen_torch_device_object( + self.vendor_name + ).device_count() self.support_fp64 = self.vendor not in UNSUPPORT_FP64 self.support_bf16 = self.vendor not in UNSUPPORT_BF16 self.support_int64 = self.vendor not in UNSUPPORT_INT64 @@ -54,6 +63,7 @@ def _get_vendor_from_quick_cmd(self): return vendor_name try: import torch_npu + for vendor_name, attr in _VENDOR_TORCH_ATTR.items(): if hasattr(torch_npu, attr): return vendor_name @@ -77,7 +87,9 @@ def check_vendor(info): return None with ThreadPoolExecutor() as executor: - futures = {executor.submit(check_vendor, info): info for info in vendor_infos} + futures = { + executor.submit(check_vendor, info): info for info in vendor_infos + } for future in as_completed(futures): result = future.result() if result: diff --git a/src/flag_gems/runtime/common.py b/src/flag_gems/runtime/common.py index e57f715c39..5cd06e04e5 100755 --- a/src/flag_gems/runtime/common.py +++ b/src/flag_gems/runtime/common.py @@ -23,29 +23,36 @@ def get_all_vendors(cls) -> dict: vendorDict[member.name.lower()] = member return vendorDict -UNSUPPORT_FP64 = frozenset({ - vendors.CAMBRICON, - vendors.ILUVATAR, - vendors.KUNLUNXIN, - vendors.MTHREADS, - vendors.AIPU, - vendors.ASCEND, - vendors.TSINGMICRO, - vendors.SUNRISE, - vendors.ENFLAME, -}) - -UNSUPPORT_BF16 = frozenset({ - vendors.AIPU, - vendors.SUNRISE, -}) - -UNSUPPORT_INT64 = frozenset({ - vendors.AIPU, - vendors.TSINGMICRO, - vendors.SUNRISE, - vendors.ENFLAME, -}) + +UNSUPPORT_FP64 = frozenset( + { + vendors.CAMBRICON, + vendors.ILUVATAR, + vendors.KUNLUNXIN, + vendors.MTHREADS, + vendors.AIPU, + vendors.ASCEND, + vendors.TSINGMICRO, + vendors.SUNRISE, + vendors.ENFLAME, + } +) + +UNSUPPORT_BF16 = frozenset( + { + vendors.AIPU, + vendors.SUNRISE, + } +) + +UNSUPPORT_INT64 = frozenset( + { + vendors.AIPU, + vendors.TSINGMICRO, + vendors.SUNRISE, + vendors.ENFLAME, + } +) # Mapping from vendor name to torch attribute for quick detection @@ -56,4 +63,4 @@ def get_all_vendors(cls) -> dict: "ascend": "npu", "sunrise": "ptpu", "enflame": "gcu", -} \ No newline at end of file +} diff --git a/src/flag_gems/runtime/configloader.py b/src/flag_gems/runtime/configloader.py index e69d262142..6d34f2537f 100755 --- a/src/flag_gems/runtime/configloader.py +++ b/src/flag_gems/runtime/configloader.py @@ -38,7 +38,11 @@ def __init__(self): self.gen_key = "gen" # loaded_triton_config is wrapped in triton.Config according to primitive_yaml_config self.loaded_triton_config = {} - self.triton_config_default = {"num_stages": 2, "num_warps": 4, "num_ctas": 1} + self.triton_config_default = { + "num_stages": 2, + "num_warps": 4, + "num_ctas": 1, + } if self.device.vendor_name == "hygon": self.triton_config_default["num_ldmatrixes"] = 0 self.load_all() @@ -47,8 +51,8 @@ def update_config_from_arch(self): try: archEvent = backend.BackendArchEvent() if archEvent.has_arch: - self.arch_specialized_yaml_config = (archEvent.autotune_configs) - self.arch_heuristics_config = (archEvent.heuristics_configs) + self.arch_specialized_yaml_config = archEvent.autotune_configs + self.arch_heuristics_config = archEvent.heuristics_configs except Exception as err: print(f"[INFO] : {err}") diff --git a/src/flag_gems/utils/libentry.py b/src/flag_gems/utils/libentry.py index efd02a50c7..1e338f6209 100755 --- a/src/flag_gems/utils/libentry.py +++ b/src/flag_gems/utils/libentry.py @@ -157,15 +157,17 @@ def __new__(cls, *args, **kwargs): def __init__(self, db_url: Optional[str] = None): self.global_cache: Dict = {} self.volumn: Dict = {} + device_name = _state.vendor_module.vendor_info.device_name if db_url is None: try: device_name: str = torch_device_fn.get_device_name().replace(" ", "_") except AttributeError: - device_name: str = _state.vendor_module.vendor_info.device_name + device_name: str = device_name + cache_file_name: str = ( f"TunedConfig_{device_name}_triton_{major_version}_{minor_version}.db" - if _state.vendor_module.vendor_info.vendor_name == "nvidia" - else f"TunedConfig_{_state.vendor_module.vendor_info.vendor_name}_triton_{major_version}_{minor_version}.db" + if device_name == "nvidia" + else f"TunedConfig_{device_name}_triton_{major_version}_{minor_version}.db" ) cache_path: Path = config_cache_dir() / cache_file_name self.db_url: str = f"sqlite:///{cache_path}"