-
Notifications
You must be signed in to change notification settings - Fork 309
[Bug Fix] Fix_muti_backend_bugs #2226
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -9,17 +9,37 @@ | |||||
| from ..common import vendors | ||||||
| from . import backend_utils | ||||||
|
|
||||||
| vendor_module = None | ||||||
| device_name = None | ||||||
| torch_device_object = None | ||||||
| torch_device_fn_device = None | ||||||
| tl_extra_backend_module = None | ||||||
| ops_module = None | ||||||
| fused_module = None | ||||||
| heuristic_config_module = None | ||||||
| vendor_extra_lib_imported = False | ||||||
| device_fn_cache = {} | ||||||
| customized_ops = None | ||||||
|
|
||||||
| class BackendState: | ||||||
| """Singleton class to manage backend state variables.""" | ||||||
|
|
||||||
| _instance = None | ||||||
|
|
||||||
| def __new__(cls): | ||||||
| if cls._instance is None: | ||||||
| cls._instance = super().__new__(cls) | ||||||
| cls._instance._initialized = False | ||||||
| return cls._instance | ||||||
|
|
||||||
| def __init__(self): | ||||||
| if self._initialized: | ||||||
| return | ||||||
| self._initialized = True | ||||||
| self.vendor_module = None | ||||||
| self.device_name = None | ||||||
| self.torch_device_object = None | ||||||
| self.torch_device_fn_device = None | ||||||
| self.tl_extra_backend_module = None | ||||||
| self.ops_module = None | ||||||
| self.fused_module = None | ||||||
| self.heuristic_config_module = None | ||||||
| self.vendor_extra_lib_imported = False | ||||||
| self.device_fn_cache = {} | ||||||
| self.customized_ops = None | ||||||
|
|
||||||
|
|
||||||
| # Global singleton instance | ||||||
| _state = BackendState() | ||||||
|
|
||||||
|
|
||||||
| class BackendArchEvent: | ||||||
|
|
@@ -51,32 +71,29 @@ def get_functions_from_module(self, module): | |||||
| return inspect.getmembers(module, inspect.isfunction) if module else [] | ||||||
|
|
||||||
| def get_heuristics_configs(self): | ||||||
| heuristic_module = None | ||||||
| try: | ||||||
| heuristic_module = self.arch_module | ||||||
| except Exception: # noqa E722 | ||||||
| except Exception: | ||||||
| sys.path.insert(0, str(self.current_arch_path)) | ||||||
| heuristic_module = importlib.import_module("heuristics_config_utils") | ||||||
| sys.path.remove(str(self.current_arch_path)) | ||||||
| if hasattr(heuristic_module, "HEURISTICS_CONFIGS"): | ||||||
| return heuristic_module.HEURISTICS_CONFIGS | ||||||
| return None | ||||||
| return getattr(heuristic_module, "HEURISTICS_CONFIGS", None) | ||||||
Galaxy1458 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
|
||||||
| def get_autotune_configs(self): | ||||||
| path = self.current_arch_path | ||||||
| return backend_utils.get_tune_config(file_path=path) | ||||||
|
|
||||||
| def get_arch(self, device=0): | ||||||
| if not hasattr(vendor_module, "ARCH_MAP"): | ||||||
| if not hasattr(_state.vendor_module, "ARCH_MAP"): | ||||||
| return | ||||||
| arch_map = vendor_module.ARCH_MAP | ||||||
| arch_map = _state.vendor_module.ARCH_MAP | ||||||
| arch_string = os.environ.get("ARCH", "") | ||||||
| arch_string_num = arch_string.split("_")[-1][0] if arch_string else arch_string | ||||||
| if not arch_string_num: | ||||||
| try: | ||||||
| if not torch_device_object.is_available(): | ||||||
| if not _state.torch_device_object.is_available(): | ||||||
| return False | ||||||
| props = torch_device_object.get_device_properties(device) | ||||||
| props = _state.torch_device_object.get_device_properties(device) | ||||||
| arch_string_num = str(props.major) | ||||||
| except Exception: | ||||||
| self.has_arch = False | ||||||
|
|
@@ -89,16 +106,14 @@ def get_arch(self, device=0): | |||||
| return arch_map[arch_string_num] | ||||||
|
|
||||||
| def _get_supported_archs(self, path=None): | ||||||
| path = path or vendor_module.__path__[0] | ||||||
| excluded = ("ops", "fused") | ||||||
| path = Path(path) | ||||||
| path = Path(path or _state.vendor_module.__path__[0]) | ||||||
| path = path.parent if path.is_file() else path | ||||||
| archs = {} | ||||||
| for p in path.iterdir(): | ||||||
| name = str(p).split("/")[-1] | ||||||
| if p.is_dir() and name not in excluded and not name.startswith("_"): | ||||||
| archs.update({name: str(p)}) | ||||||
| return archs | ||||||
| excluded = ("ops", "fused") | ||||||
| return { | ||||||
| p.name: str(p) | ||||||
| for p in path.iterdir() | ||||||
| if p.is_dir() and p.name not in excluded and not p.name.startswith("_") | ||||||
| } | ||||||
|
|
||||||
| def get_supported_archs(self): | ||||||
| return list(self.supported_archs.keys()) | ||||||
|
|
@@ -113,51 +128,41 @@ def get_arch_module(self): | |||||
|
|
||||||
| def get_arch_ops(self): | ||||||
| arch_specialized_ops = [] | ||||||
| modules = [] | ||||||
| sys.path.append(self.current_arch_path) | ||||||
| ops_module = importlib.import_module(f"{self.arch}.ops") | ||||||
| ops_module = getattr(self.arch_module, "ops", None) | ||||||
| try: | ||||||
| ops_module = self.arch_module.ops | ||||||
| modules.append(ops_module) | ||||||
| except Exception: | ||||||
| try: | ||||||
| sys.path.append(self.current_arch_path) | ||||||
| if ops_module is None: | ||||||
| ops_module = importlib.import_module(f"{self.arch}.ops") | ||||||
Galaxy1458 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
| modules.append(ops_module) | ||||||
| except Exception as err_msg: | ||||||
| self.error_msgs.append(err_msg) | ||||||
|
|
||||||
| for mod in modules: | ||||||
| arch_specialized_ops.extend(self.get_functions_from_module(mod)) | ||||||
|
|
||||||
| arch_specialized_ops.extend(self.get_functions_from_module(ops_module)) | ||||||
| except Exception as err_msg: | ||||||
| self.error_msgs.append(err_msg) | ||||||
| return arch_specialized_ops | ||||||
|
|
||||||
|
|
||||||
| def import_vendor_extra_lib(vendor_name=None): | ||||||
| global vendor_extra_lib_imported | ||||||
| if vendor_extra_lib_imported is True: | ||||||
| return | ||||||
| global ops_module, fused_module | ||||||
| def _import_module_safe(module_name, vendor_name, module_type): | ||||||
| """Helper to import a module with proper error handling.""" | ||||||
| try: | ||||||
| ops_module = importlib.import_module(f"_{vendor_name}.ops") | ||||||
| return importlib.import_module(module_name) | ||||||
| except ModuleNotFoundError: | ||||||
| print( | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider changing this to logger.info?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Even in that case, the error messages are not supposed to be printed on stdout. By the way, it should be fine (though not good ...) if a vendor has no operators implemented. |
||||||
| f"[Note] No specialized common operators were found in" | ||||||
| f"the {vendor_name} implementation, and general common operators are used by default." | ||||||
| f"[Note] No specialized {module_type} operators were found for " | ||||||
| f"the {vendor_name}, generic {module_type} operators will be used by default." | ||||||
| ) | ||||||
| return None | ||||||
| except Exception as e: | ||||||
| raise RuntimeError(f"Import vendor extra lib failed: {e}") | ||||||
| raise RuntimeError(f"Failed to import vendor extra lib: {e}") | ||||||
|
|
||||||
| try: | ||||||
| fused_module = importlib.import_module(f"_{vendor_name}.fused") | ||||||
| except ModuleNotFoundError: | ||||||
| print( | ||||||
| f"[Note] No specialized fused operators were found in" | ||||||
| f"the {vendor_name} implementation, and general fused operators are used by default." | ||||||
| ) | ||||||
| except Exception as e: | ||||||
| raise RuntimeError(f"Import vendor extra lib failed: {e}") | ||||||
| vendor_extra_lib_imported = True | ||||||
|
|
||||||
| def import_vendor_extra_lib(vendor_name=None): | ||||||
| if _state.vendor_extra_lib_imported: | ||||||
| return | ||||||
| _state.ops_module = _import_module_safe( | ||||||
| f"_{vendor_name}.ops", vendor_name, "common" | ||||||
| ) | ||||||
| _state.fused_module = _import_module_safe( | ||||||
| f"_{vendor_name}.fused", vendor_name, "fused" | ||||||
| ) | ||||||
| _state.vendor_extra_lib_imported = True | ||||||
|
|
||||||
|
|
||||||
| def get_codegen_result(code, result_key): | ||||||
|
|
@@ -172,8 +177,7 @@ def get_codegen_result(code, result_key): | |||||
|
|
||||||
| @functools.lru_cache(maxsize=32) | ||||||
| def gen_torch_tensor_attr_res(tensor, attr_name): | ||||||
| global device_name | ||||||
| device_name = device_name or get_vendor_info().device_name | ||||||
| _state.device_name = _state.device_name or get_vendor_info().device_name | ||||||
| code = f""" | ||||||
| import torch | ||||||
| res = {tensor}.{attr_name} | ||||||
|
|
@@ -182,43 +186,40 @@ def gen_torch_tensor_attr_res(tensor, attr_name): | |||||
|
|
||||||
|
|
||||||
| def set_tl_extra_backend_module(vendor_name=None): | ||||||
| global device_name, tl_extra_backend_module | ||||||
| vendor_info = get_vendor_info(vendor_name) | ||||||
| device_name = device_name or vendor_info.device_name | ||||||
| extra_name = vendor_info.triton_extra_name or device_name | ||||||
| _state.device_name = _state.device_name or vendor_info.device_name | ||||||
| extra_name = vendor_info.triton_extra_name or _state.device_name | ||||||
| module_str = f"triton.language.extra.{extra_name}.libdevice" | ||||||
| tl_extra_backend_module = importlib.import_module(module_str) | ||||||
| _state.tl_extra_backend_module = importlib.import_module(module_str) | ||||||
|
|
||||||
|
|
||||||
| def get_tl_extra_backend_module(): | ||||||
| return tl_extra_backend_module | ||||||
| return _state.tl_extra_backend_module | ||||||
|
|
||||||
|
|
||||||
| def set_torch_backend_device_fn(vendor_name=None): | ||||||
| global device_name, torch_device_fn_device | ||||||
| device_name = device_name or get_vendor_info(vendor_name).device_name | ||||||
| module_str = f"torch.backends.{device_name}" | ||||||
| if device_name in ("musa", "aipu", "npu", "txda", "ptpu", "gcu"): | ||||||
| torch_device_fn_device = None | ||||||
| _state.device_name = _state.device_name or get_vendor_info(vendor_name).device_name | ||||||
| module_str = f"torch.backends.{_state.device_name}" | ||||||
| if _state.device_name in ("musa", "aipu", "npu", "txda", "ptpu", "gcu"): | ||||||
| _state.torch_device_fn_device = None | ||||||
|
Comment on lines
+203
to
+204
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't like this logic ... it is hard to maintain.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. using a
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean 'except: pass'. |
||||||
| else: | ||||||
| torch_device_fn_device = importlib.import_module(module_str) | ||||||
| _state.torch_device_fn_device = importlib.import_module(module_str) | ||||||
|
|
||||||
|
|
||||||
| def get_torch_backend_device_fn(): | ||||||
| return torch_device_fn_device | ||||||
| return _state.torch_device_fn_device | ||||||
|
|
||||||
|
|
||||||
| def gen_torch_device_object(vendor_name=None): | ||||||
| global device_name, torch_device_object | ||||||
| if torch_device_object is not None: | ||||||
| return torch_device_object | ||||||
| device_name = device_name or get_vendor_info(vendor_name).device_name | ||||||
| if _state.torch_device_object is not None: | ||||||
| return _state.torch_device_object | ||||||
| _state.device_name = _state.device_name or get_vendor_info(vendor_name).device_name | ||||||
| code = f""" | ||||||
| import torch | ||||||
| fn = torch.{device_name} | ||||||
| fn = torch.{_state.device_name} | ||||||
| """ | ||||||
| torch_device_object = get_codegen_result(code, "fn") | ||||||
| return torch_device_object | ||||||
| _state.torch_device_object = get_codegen_result(code, "fn") | ||||||
| return _state.torch_device_object | ||||||
|
|
||||||
|
|
||||||
| def get_vendor_module(vendor_name, query=False): | ||||||
|
|
@@ -233,73 +234,67 @@ def get_module(vendor_name): | |||||
| ): # The purpose of a query is to provide the user with the instance that he wants to import | ||||||
| return get_module(vendor_name) | ||||||
|
|
||||||
| global vendor_module | ||||||
| if vendor_module is None: | ||||||
| vendor_module = get_module("_" + vendor_name) | ||||||
| return vendor_module | ||||||
| if _state.vendor_module is None: | ||||||
| _state.vendor_module = get_module("_" + vendor_name) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, in principle, there is no exception. If there is an exception for unknown reasons, it must be reported.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there are in general two approaches ... |
||||||
| return _state.vendor_module | ||||||
|
|
||||||
|
|
||||||
| def get_vendor_info(vendor_name=None, query=False): | ||||||
| if query: | ||||||
| return get_vendor_module(vendor_name, query).vendor_info | ||||||
| global vendor_module # noqa: F824 | ||||||
| get_vendor_module(vendor_name) | ||||||
| return vendor_module.vendor_info | ||||||
| return _state.vendor_module.vendor_info | ||||||
|
|
||||||
|
|
||||||
| def get_vendor_infos(): | ||||||
| infos = [] | ||||||
| for vendor_name in vendors.get_all_vendors(): | ||||||
| vendor_name = "_" + vendor_name | ||||||
| try: | ||||||
| single_info = get_vendor_info(vendor_name, query=True) | ||||||
| infos.append(single_info) | ||||||
| infos.append(get_vendor_info(f"_{vendor_name}", query=True)) | ||||||
| except Exception: | ||||||
| pass | ||||||
|
|
||||||
| continue | ||||||
| return infos | ||||||
|
|
||||||
|
|
||||||
| def get_current_device_extend_op(vendor_name=None): | ||||||
| import_vendor_extra_lib(vendor_name) | ||||||
| global customized_ops | ||||||
| if customized_ops is not None: | ||||||
| return customized_ops | ||||||
| customized_ops = [] | ||||||
| if ops_module is not None: | ||||||
| ops = inspect.getmembers(ops_module, inspect.isfunction) | ||||||
| customized_ops += ops | ||||||
| if fused_module is not None: | ||||||
| fused_ops = inspect.getmembers(fused_module, inspect.isfunction) | ||||||
| customized_ops += fused_ops | ||||||
| return customized_ops | ||||||
| if _state.customized_ops is not None: | ||||||
| return _state.customized_ops | ||||||
| _state.customized_ops = [] | ||||||
| if _state.ops_module is not None: | ||||||
| ops = inspect.getmembers(_state.ops_module, inspect.isfunction) | ||||||
| _state.customized_ops += ops | ||||||
| if _state.fused_module is not None: | ||||||
| fused_ops = inspect.getmembers(_state.fused_module, inspect.isfunction) | ||||||
| _state.customized_ops += fused_ops | ||||||
| return _state.customized_ops | ||||||
|
|
||||||
|
|
||||||
| def get_curent_device_unused_op(vendor_name=None): | ||||||
| global vendor_module # noqa: F824 | ||||||
| get_vendor_module(vendor_name) | ||||||
| return list(vendor_module.CUSTOMIZED_UNUSED_OPS) | ||||||
| return list(_state.vendor_module.CUSTOMIZED_UNUSED_OPS) | ||||||
|
|
||||||
|
|
||||||
| def get_heuristic_config(vendor_name=None): | ||||||
| global heuristic_config_module | ||||||
| try: | ||||||
| heuristic_config_module = importlib.import_module( | ||||||
| f"_{vendor_name}.heuristics_config_utils" | ||||||
| ) | ||||||
| except: # noqa E722 | ||||||
| heuristic_config_module = importlib.import_module( | ||||||
| "_nvidia.heuristics_config_utils" | ||||||
| ) | ||||||
| if hasattr(heuristic_config_module, "HEURISTICS_CONFIGS"): | ||||||
| return heuristic_config_module.HEURISTICS_CONFIGS | ||||||
| return None | ||||||
| config_name = "heuristics_config_utils" | ||||||
| default_backend = "nvidia" | ||||||
| for backend in (vendor_name, default_backend): | ||||||
| mod_name = f"_{backend}.{config_name}" | ||||||
| try: | ||||||
| _state.heuristic_config_module = importlib.import_module(mod_name) | ||||||
| except Exception: | ||||||
| continue | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should be 'pass'?
Suggested change
|
||||||
| return getattr(_state.heuristic_config_module, "HEURISTICS_CONFIGS", None) | ||||||
|
|
||||||
|
|
||||||
| def get_tune_config(vendor_name=None): | ||||||
| global vendor_module # noqa: F824 | ||||||
| get_vendor_module(vendor_name) | ||||||
| return backend_utils.get_tune_config(vendor_name) | ||||||
|
|
||||||
|
|
||||||
| def get_backend_state() -> BackendState: | ||||||
| """Get the global BackendState singleton instance.""" | ||||||
| return _state | ||||||
|
|
||||||
|
|
||||||
| __all__ = ["*"] | ||||||
Uh oh!
There was an error while loading. Please reload this page.