diff --git a/byte_micro_perf/backends/GPU/backend_gpu.py b/byte_micro_perf/backends/GPU/backend_gpu.py index 9ad84fee..c6ee78d4 100644 --- a/byte_micro_perf/backends/GPU/backend_gpu.py +++ b/byte_micro_perf/backends/GPU/backend_gpu.py @@ -45,25 +45,36 @@ def get_backend_info(self): device_properties = torch.cuda.get_device_properties(0) info_dict["device_memory_mb"] = device_properties.total_memory / (1024 ** 2) - - __torch_version = torch.__version__ - __cuda_version = torch.version.cuda + __torch_device_type = "" + __torch_device_version = "" __driver_version = '' - nvidia_smi_output = subprocess.run( - ['nvidia-smi', '-q', '-i', '0'], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True - ) - for line in nvidia_smi_output.stdout.split('\n'): - if 'Driver Version' in line: - __driver_version = line.split(':')[1].strip() - break + __smi_cmd = [] + + if torch.version.cuda: + __torch_device_type = "torch_cuda" + __torch_device_version = torch.version.cuda + __smi_cmd = ['nvidia-smi', '-q', '-i', '0'] + elif torch.version.hip: + __torch_device_type = "torch_hip" + __torch_device_version = torch.version.hip + __smi_cmd = ['rocm-smi', '--showdriverversion'] + + if __smi_cmd: + smi_output = subprocess.run( + __smi_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + for line in smi_output.stdout.split('\n'): + if 'driver version' in line.lower(): + __driver_version = line.split(':')[1].strip() + break info_dict["torch_version"] = __torch_version - info_dict["torch_cuda_version"] = __cuda_version + info_dict[__torch_device_type] = __torch_device_version info_dict["driver_version"] = __driver_version return info_dict