Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 115 additions & 120 deletions src/flag_gems/runtime/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,37 @@
from ..common import vendors
from . import backend_utils

vendor_module = None
device_name = None
torch_device_object = None
torch_device_fn_device = None
tl_extra_backend_module = None
ops_module = None
fused_module = None
heuristic_config_module = None
vendor_extra_lib_imported = False
device_fn_cache = {}
customized_ops = None

class BackendState:
    """Singleton holding the backend's module-level state.

    ``__new__`` always returns the one shared instance, and ``__init__``
    assigns the default values only the first time that instance is
    constructed. Calling ``BackendState()`` again therefore never resets
    values that were stored earlier.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Read by __init__ so repeated constructions skip the defaults.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self._initialized = True
        self.vendor_module = None
        self.device_name = None
        self.torch_device_object = None
        self.torch_device_fn_device = None
        self.tl_extra_backend_module = None
        self.ops_module = None
        self.fused_module = None
        self.heuristic_config_module = None
        self.vendor_extra_lib_imported = False
        self.device_fn_cache = {}
        self.customized_ops = None


# Global singleton instance
_state = BackendState()


class BackendArchEvent:
Expand Down Expand Up @@ -51,32 +71,29 @@ def get_functions_from_module(self, module):
return inspect.getmembers(module, inspect.isfunction) if module else []

def get_heuristics_configs(self):
heuristic_module = None
try:
heuristic_module = self.arch_module
except Exception: # noqa E722
except Exception:
sys.path.insert(0, str(self.current_arch_path))
heuristic_module = importlib.import_module("heuristics_config_utils")
sys.path.remove(str(self.current_arch_path))
if hasattr(heuristic_module, "HEURISTICS_CONFIGS"):
return heuristic_module.HEURISTICS_CONFIGS
return None
return getattr(heuristic_module, "HEURISTICS_CONFIGS", None)

def get_autotune_configs(self):
    """Return the result of ``backend_utils.get_tune_config`` for this arch.

    ``self.current_arch_path`` is passed as the ``file_path`` argument.
    """
    path = self.current_arch_path
    return backend_utils.get_tune_config(file_path=path)

def get_arch(self, device=0):
if not hasattr(vendor_module, "ARCH_MAP"):
if not hasattr(_state.vendor_module, "ARCH_MAP"):
return
arch_map = vendor_module.ARCH_MAP
arch_map = _state.vendor_module.ARCH_MAP
arch_string = os.environ.get("ARCH", "")
arch_string_num = arch_string.split("_")[-1][0] if arch_string else arch_string
if not arch_string_num:
try:
if not torch_device_object.is_available():
if not _state.torch_device_object.is_available():
return False
props = torch_device_object.get_device_properties(device)
props = _state.torch_device_object.get_device_properties(device)
arch_string_num = str(props.major)
except Exception:
self.has_arch = False
Expand All @@ -89,16 +106,14 @@ def get_arch(self, device=0):
return arch_map[arch_string_num]

def _get_supported_archs(self, path=None):
path = path or vendor_module.__path__[0]
excluded = ("ops", "fused")
path = Path(path)
path = Path(path or _state.vendor_module.__path__[0])
path = path.parent if path.is_file() else path
archs = {}
for p in path.iterdir():
name = str(p).split("/")[-1]
if p.is_dir() and name not in excluded and not name.startswith("_"):
archs.update({name: str(p)})
return archs
excluded = ("ops", "fused")
return {
p.name: str(p)
for p in path.iterdir()
if p.is_dir() and p.name not in excluded and not p.name.startswith("_")
}

def get_supported_archs(self):
    """Return the keys of ``self.supported_archs`` as a new list."""
    return list(self.supported_archs.keys())
Expand All @@ -113,51 +128,41 @@ def get_arch_module(self):

def get_arch_ops(self):
arch_specialized_ops = []
modules = []
sys.path.append(self.current_arch_path)
ops_module = importlib.import_module(f"{self.arch}.ops")
ops_module = getattr(self.arch_module, "ops", None)
try:
ops_module = self.arch_module.ops
modules.append(ops_module)
except Exception:
try:
sys.path.append(self.current_arch_path)
if ops_module is None:
ops_module = importlib.import_module(f"{self.arch}.ops")
modules.append(ops_module)
except Exception as err_msg:
self.error_msgs.append(err_msg)

for mod in modules:
arch_specialized_ops.extend(self.get_functions_from_module(mod))

arch_specialized_ops.extend(self.get_functions_from_module(ops_module))
except Exception as err_msg:
self.error_msgs.append(err_msg)
return arch_specialized_ops


def import_vendor_extra_lib(vendor_name=None):
global vendor_extra_lib_imported
if vendor_extra_lib_imported is True:
return
global ops_module, fused_module
def _import_module_safe(module_name, vendor_name, module_type):
"""Helper to import a module with proper error handling."""
try:
ops_module = importlib.import_module(f"_{vendor_name}.ops")
return importlib.import_module(module_name)
except ModuleNotFoundError:
print(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider changing this to `logger.info`?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the `except` branch is entered, this [Note] needs to be shown to the user in any case.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even in that case, the error message is not supposed to be printed to stdout.
This module is at the core of the project, and the whole project can be scripted.
I noticed this unconditional print when writing test scripts.

By the way, it should be fine (not good though ...) if a vendor has no operators implemented.

f"[Note] No specialized common operators were found in"
f"the {vendor_name} implementation, and general common operators are used by default."
f"[Note] No specialized {module_type} operators were found for "
f"the {vendor_name}, generic {module_type} operators will be used by default."
)
return None
except Exception as e:
raise RuntimeError(f"Import vendor extra lib failed: {e}")
raise RuntimeError(f"Failed to import vendor extra lib: {e}")

try:
fused_module = importlib.import_module(f"_{vendor_name}.fused")
except ModuleNotFoundError:
print(
f"[Note] No specialized fused operators were found in"
f"the {vendor_name} implementation, and general fused operators are used by default."
)
except Exception as e:
raise RuntimeError(f"Import vendor extra lib failed: {e}")
vendor_extra_lib_imported = True

def import_vendor_extra_lib(vendor_name=None):
    """Import the vendor's ``ops`` and ``fused`` modules once and cache them.

    The modules ``_{vendor_name}.ops`` and ``_{vendor_name}.fused`` are
    loaded through ``_import_module_safe`` and stored on the shared
    ``_state`` object. Once the first call has completed, later calls return
    immediately regardless of the ``vendor_name`` they pass.

    Args:
        vendor_name: Vendor name without the leading underscore.
    """
    if _state.vendor_extra_lib_imported:
        return
    _state.ops_module = _import_module_safe(
        f"_{vendor_name}.ops", vendor_name, "common"
    )
    _state.fused_module = _import_module_safe(
        f"_{vendor_name}.fused", vendor_name, "fused"
    )
    _state.vendor_extra_lib_imported = True


def get_codegen_result(code, result_key):
Expand All @@ -172,8 +177,7 @@ def get_codegen_result(code, result_key):

@functools.lru_cache(maxsize=32)
def gen_torch_tensor_attr_res(tensor, attr_name):
global device_name
device_name = device_name or get_vendor_info().device_name
_state.device_name = _state.device_name or get_vendor_info().device_name
code = f"""
import torch
res = {tensor}.{attr_name}
Expand All @@ -182,43 +186,40 @@ def gen_torch_tensor_attr_res(tensor, attr_name):


def set_tl_extra_backend_module(vendor_name=None):
global device_name, tl_extra_backend_module
vendor_info = get_vendor_info(vendor_name)
device_name = device_name or vendor_info.device_name
extra_name = vendor_info.triton_extra_name or device_name
_state.device_name = _state.device_name or vendor_info.device_name
extra_name = vendor_info.triton_extra_name or _state.device_name
module_str = f"triton.language.extra.{extra_name}.libdevice"
tl_extra_backend_module = importlib.import_module(module_str)
_state.tl_extra_backend_module = importlib.import_module(module_str)


def get_tl_extra_backend_module():
return tl_extra_backend_module
return _state.tl_extra_backend_module


def set_torch_backend_device_fn(vendor_name=None):
global device_name, torch_device_fn_device
device_name = device_name or get_vendor_info(vendor_name).device_name
module_str = f"torch.backends.{device_name}"
if device_name in ("musa", "aipu", "npu", "txda", "ptpu", "gcu"):
torch_device_fn_device = None
_state.device_name = _state.device_name or get_vendor_info(vendor_name).device_name
module_str = f"torch.backends.{_state.device_name}"
if _state.device_name in ("musa", "aipu", "npu", "txda", "ptpu", "gcu"):
_state.torch_device_fn_device = None
Comment on lines +203 to +204
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like this logic ... it is hard to maintain.
Please consider using a try ... except construct here.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a try ... except construct here doesn't match the original intent of the code. The module behind torch_device_fn_device is not necessarily missing; the vendor may simply not want to use it, or it may not be needed at all.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean `except: pass`.
Don't enumerate the device names here;
no one will remember that we planted hardcoded, device-specific code here.

else:
torch_device_fn_device = importlib.import_module(module_str)
_state.torch_device_fn_device = importlib.import_module(module_str)


def get_torch_backend_device_fn():
return torch_device_fn_device
return _state.torch_device_fn_device


def gen_torch_device_object(vendor_name=None):
global device_name, torch_device_object
if torch_device_object is not None:
return torch_device_object
device_name = device_name or get_vendor_info(vendor_name).device_name
if _state.torch_device_object is not None:
return _state.torch_device_object
_state.device_name = _state.device_name or get_vendor_info(vendor_name).device_name
code = f"""
import torch
fn = torch.{device_name}
fn = torch.{_state.device_name}
"""
torch_device_object = get_codegen_result(code, "fn")
return torch_device_object
_state.torch_device_object = get_codegen_result(code, "fn")
return _state.torch_device_object


def get_vendor_module(vendor_name, query=False):
Expand All @@ -233,73 +234,67 @@ def get_module(vendor_name):
): # The purpose of a query is to provide the user with the instance that he wants to import
return get_module(vendor_name)

global vendor_module
if vendor_module is None:
vendor_module = get_module("_" + vendor_name)
return vendor_module
if _state.vendor_module is None:
_state.vendor_module = get_module("_" + vendor_name)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this `get_module` call exception-free?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, in principle, there is no exception. If there is an exception for unknown reasons, it must be reported.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general there are two approaches:
either you catch all exceptions inside get_module, so the function never throws an exception and you may get an empty string here (check it);
or you allow it to throw an exception and catch that exception further up the call chain.

return _state.vendor_module


def get_vendor_info(vendor_name=None, query=False):
if query:
return get_vendor_module(vendor_name, query).vendor_info
global vendor_module # noqa: F824
get_vendor_module(vendor_name)
return vendor_module.vendor_info
return _state.vendor_module.vendor_info


def get_vendor_infos():
infos = []
for vendor_name in vendors.get_all_vendors():
vendor_name = "_" + vendor_name
try:
single_info = get_vendor_info(vendor_name, query=True)
infos.append(single_info)
infos.append(get_vendor_info(f"_{vendor_name}", query=True))
except Exception:
pass

continue
return infos


def get_current_device_extend_op(vendor_name=None):
import_vendor_extra_lib(vendor_name)
global customized_ops
if customized_ops is not None:
return customized_ops
customized_ops = []
if ops_module is not None:
ops = inspect.getmembers(ops_module, inspect.isfunction)
customized_ops += ops
if fused_module is not None:
fused_ops = inspect.getmembers(fused_module, inspect.isfunction)
customized_ops += fused_ops
return customized_ops
if _state.customized_ops is not None:
return _state.customized_ops
_state.customized_ops = []
if _state.ops_module is not None:
ops = inspect.getmembers(_state.ops_module, inspect.isfunction)
_state.customized_ops += ops
if _state.fused_module is not None:
fused_ops = inspect.getmembers(_state.fused_module, inspect.isfunction)
_state.customized_ops += fused_ops
return _state.customized_ops


def get_curent_device_unused_op(vendor_name=None):
global vendor_module # noqa: F824
get_vendor_module(vendor_name)
return list(vendor_module.CUSTOMIZED_UNUSED_OPS)
return list(_state.vendor_module.CUSTOMIZED_UNUSED_OPS)


def get_heuristic_config(vendor_name=None):
global heuristic_config_module
try:
heuristic_config_module = importlib.import_module(
f"_{vendor_name}.heuristics_config_utils"
)
except: # noqa E722
heuristic_config_module = importlib.import_module(
"_nvidia.heuristics_config_utils"
)
if hasattr(heuristic_config_module, "HEURISTICS_CONFIGS"):
return heuristic_config_module.HEURISTICS_CONFIGS
return None
config_name = "heuristics_config_utils"
default_backend = "nvidia"
for backend in (vendor_name, default_backend):
mod_name = f"_{backend}.{config_name}"
try:
_state.heuristic_config_module = importlib.import_module(mod_name)
except Exception:
continue
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be 'pass'?

Suggested change
continue
pass

return getattr(_state.heuristic_config_module, "HEURISTICS_CONFIGS", None)


def get_tune_config(vendor_name=None):
global vendor_module # noqa: F824
get_vendor_module(vendor_name)
return backend_utils.get_tune_config(vendor_name)


def get_backend_state() -> BackendState:
    """Return the module-level ``_state`` object (the shared ``BackendState``)."""
    return _state


__all__ = ["*"]
Loading
Loading