From 96927215373a18788eed5d2ee46cf3249b345b93 Mon Sep 17 00:00:00 2001 From: Jerry Mannil Date: Tue, 6 Jan 2026 22:42:10 +0000 Subject: [PATCH] Add AMD gpu support --- byte_micro_perf/backends/GPU/backend_gpu.py | 37 ++++++++++++++------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/byte_micro_perf/backends/GPU/backend_gpu.py b/byte_micro_perf/backends/GPU/backend_gpu.py index e1e88a5d..49044469 100644 --- a/byte_micro_perf/backends/GPU/backend_gpu.py +++ b/byte_micro_perf/backends/GPU/backend_gpu.py @@ -72,21 +72,34 @@ def empty_cache(self): def get_backend_env(self): __torch_version = torch.__version__ - __cuda_version = torch.version.cuda + __torch_device_type = "" + __torch_device_version = "" __driver_version = '' - nvidia_smi_output = subprocess.run( - ['nvidia-smi', '-q', '-i', '0'], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True - ) - for line in nvidia_smi_output.stdout.split('\n'): - if 'Driver Version' in line: - __driver_version = line.split(':')[1].strip() - break + __smi_cmd = [] + + if torch.version.cuda: + __torch_device_type = "torch_cuda" + __torch_device_version = torch.version.cuda + __smi_cmd = ['nvidia-smi', '-q', '-i', '0'] + elif torch.version.hip: + __torch_device_type = "torch_hip" + __torch_device_version = torch.version.hip + __smi_cmd = ['rocm-smi', '--showdriverversion'] + + if __smi_cmd: + smi_output = subprocess.run( + __smi_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + for line in smi_output.stdout.split('\n'): + if 'driver version' in line.lower(): + __driver_version = line.split(':')[1].strip() + break return { "torch": __torch_version, - "torch_cuda": __cuda_version, + __torch_device_type: __torch_device_version, "driver": __driver_version, }