
Commit 6dbed69

[rocm, test] feat: add ROCm controller and test tooling (#56)

* Implement ROCm GPU controller and tests
* Gate ROCm tests behind marker/flag
* Add developer testing notes for CUDA/ROCm
* Refactor GPU controllers and ROCm tests
* Cleanup ROCm controller tweaks and tests

1 parent a40eb74 · commit 6dbed69

9 files changed: 268 additions & 26 deletions

File tree

- README.md
- docs/getting-started.md
- pyproject.toml
- src/keep_gpu/global_gpu_controller/global_gpu_controller.py
- src/keep_gpu/single_gpu_controller/base_gpu_controller.py
- src/keep_gpu/single_gpu_controller/cuda_gpu_controller.py
- src/keep_gpu/single_gpu_controller/rocm_gpu_controller.py (new)
- tests/conftest.py (new)
- tests/rocm_controller/… (new)

README.md

Lines changed: 7 additions & 0 deletions

@@ -82,6 +82,13 @@ with GlobalGPUController(gpu_ids=[0, 1], vram_to_keep="750MB", interval=90, busy
 - CLI + API parity: same controllers power both code paths.
 - Continuous docs + CI: mkdocs + mkdocstrings build in CI to keep guidance up to date.
 
+## For developers
+
+- Install dev extras: `pip install -e ".[dev]"` (add `.[rocm]` if you need ROCm SMI).
+- Fast CUDA checks: `pytest tests/cuda_controller tests/global_controller tests/utilities/test_platform_manager.py tests/test_cli_thresholds.py`
+- ROCm-only tests carry `@pytest.mark.rocm`; run with `pytest --run-rocm tests/rocm_controller`.
+- Markers: `rocm` (needs ROCm stack) and `large_memory` (opt-in locally).
+
 ## Contributing
 
 Contributions are welcome—especially around ROCm support, platform fallbacks, and scheduler-specific recipes. Open an issue or PR if you hit edge cases on your cluster.

docs/getting-started.md

Lines changed: 6 additions & 0 deletions

@@ -39,6 +39,12 @@ understand the minimum knobs you need to keep a GPU occupied.
     pip install keep-gpu
     ```
 
+## For contributors
+
+- Install dev extras: `pip install -e ".[dev]"` (append `.[rocm]` if you need ROCm SMI).
+- Fast CUDA checks: `pytest tests/cuda_controller tests/global_controller tests/utilities/test_platform_manager.py tests/test_cli_thresholds.py`
+- ROCm-only tests are marked `rocm`; run with `pytest --run-rocm tests/rocm_controller`.
+
 === "Editable dev install"
     ```bash
     git clone https://github.com/Wangmerlyn/KeepGPU.git

pyproject.toml

Lines changed: 6 additions & 0 deletions

@@ -139,3 +139,9 @@ exclude = ["build", "dist", ".venv"]
 known-first-party = ["keep_gpu"]
 combine-as-imports = true
 force-single-line = false
+
+[tool.pytest.ini_options]
+markers = [
+    "rocm: tests that require ROCm stack",
+    "large_memory: tests that use large VRAM",
+]

src/keep_gpu/global_gpu_controller/global_gpu_controller.py

Lines changed: 21 additions & 13 deletions

@@ -30,25 +30,33 @@ def __init__(
                 CudaGPUController,
             )
 
-            if gpu_ids is None:
-                self.gpu_ids = list(range(torch.cuda.device_count()))
-            else:
-                self.gpu_ids = gpu_ids
+            controller_cls = CudaGPUController
+        elif self.computing_platform == ComputingPlatform.ROCM:
+            from keep_gpu.single_gpu_controller.rocm_gpu_controller import (
+                RocmGPUController,
+            )
 
-            self.controllers = [
-                CudaGPUController(
-                    rank=i,
-                    interval=interval,
-                    vram_to_keep=vram_to_keep,
-                    busy_threshold=busy_threshold,
-                )
-                for i in self.gpu_ids
-            ]
+            controller_cls = RocmGPUController
         else:
             raise NotImplementedError(
                 f"GlobalGPUController not implemented for platform {self.computing_platform}"
             )
 
+        if gpu_ids is None:
+            self.gpu_ids = list(range(torch.cuda.device_count()))
+        else:
+            self.gpu_ids = gpu_ids
+
+        self.controllers = [
+            controller_cls(
+                rank=i,
+                interval=interval,
+                vram_to_keep=vram_to_keep,
+                busy_threshold=busy_threshold,
+            )
+            for i in self.gpu_ids
+        ]
+
     def keep(self) -> None:
         for ctrl in self.controllers:
             ctrl.keep()
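
With this refactor the platform branch only selects `controller_cls`; construction is shared. A minimal usage sketch (the context-manager form and kwargs mirror the README example above; the workload call is hypothetical):

```python
from keep_gpu.global_gpu_controller.global_gpu_controller import GlobalGPUController

# Same call site on CUDA and ROCm hosts; the controller class is
# selected from the detected computing platform.
with GlobalGPUController(gpu_ids=[0, 1], vram_to_keep="750MB", interval=90):
    run_training()  # hypothetical workload to keep the GPUs claimed around
```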

src/keep_gpu/single_gpu_controller/base_gpu_controller.py

Lines changed: 17 additions & 4 deletions

@@ -1,12 +1,25 @@
+from typing import Union
+
+from keep_gpu.utilities.humanized_input import parse_size
+
+
 class BaseGPUController:
-    def __init__(self, vram_to_keep: int, interval: float):
+    def __init__(self, vram_to_keep: Union[int, str], interval: float):
         """
         Base class for GPU controllers.
 
         Args:
-            vram_to_keep (int): Amount of VRAM (in MB) to keep free.
-            interval (int): Time interval (in seconds) for checks or actions.
-        """
+            vram_to_keep (int or str): Amount of VRAM to keep busy. Accepts integers
+                (tensor element count) or human strings like "1GiB" (converted to
+                element count for float32 tensors).
+            interval (float): Time interval (in seconds) between keep-alive cycles.
+        """
+        if isinstance(vram_to_keep, str):
+            vram_to_keep = parse_size(vram_to_keep)
+        elif not isinstance(vram_to_keep, int):
+            raise TypeError(
+                f"vram_to_keep must be str or int, got {type(vram_to_keep)}"
+            )
         self.vram_to_keep = vram_to_keep
         self.interval = interval
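
With parsing hoisted into the base class, every controller accepts either form. For intuition, a sketch of the conversion the docstring describes (assuming `parse_size` maps a human-readable byte size to a float32 element count; the exact rounding is `parse_size`'s concern):

```python
from keep_gpu.utilities.humanized_input import parse_size

# Assumed conversion per the docstring: "1GiB" of float32 elements is
# 1 * 1024**3 bytes / 4 bytes per element == 268_435_456 elements.
n_elements = parse_size("1GiB")
assert isinstance(n_elements, int)
```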

src/keep_gpu/single_gpu_controller/cuda_gpu_controller.py

Lines changed: 1 addition & 9 deletions

@@ -61,14 +61,6 @@ def __init__(
         hogging the GPU.
 
         """
-        if isinstance(vram_to_keep, str):
-            vram_to_keep = self.parse_size(vram_to_keep)
-        elif isinstance(vram_to_keep, int):
-            vram_to_keep = vram_to_keep
-        else:
-            raise TypeError(
-                f"vram_to_keep must be str or int, got {type(vram_to_keep)}"
-            )
         super().__init__(vram_to_keep=vram_to_keep, interval=interval)
         self.rank = rank
         self.device = torch.device(f"cuda:{rank}")
@@ -185,7 +177,7 @@ def _run_mat_batch(self, matrix: torch.Tensor) -> None:
         toc = time.time()
 
         logger.debug(
-            "rank %s: mat ops batch done avg %.2f ms",
+            "rank %s: mat ops batch done - avg %.2f ms",
             self.rank,
             (toc - tic) * 1000 / self.matmul_iterations,
        )
src/keep_gpu/single_gpu_controller/rocm_gpu_controller.py

Lines changed: 143 additions & 0 deletions (new file)

@@ -0,0 +1,143 @@
+import threading
+import time
+from typing import Optional
+
+import torch
+
+from keep_gpu.single_gpu_controller.base_gpu_controller import BaseGPUController
+from keep_gpu.utilities.logger import setup_logger
+
+logger = setup_logger(__name__)
+
+
+class RocmGPUController(BaseGPUController):
+    """
+    Keep a single ROCm GPU busy by running lightweight elementwise ops
+    in a background thread. Requires a ROCm-enabled torch build.
+    """
+
+    def __init__(
+        self,
+        *,
+        rank: int,
+        interval: float = 1.0,
+        vram_to_keep: str | int = "1000 MB",
+        busy_threshold: int = 10,
+        iterations: int = 5000,
+    ):
+        super().__init__(vram_to_keep=vram_to_keep, interval=interval)
+        self.rank = rank
+        self.device = torch.device(f"cuda:{rank}")
+        self.busy_threshold = busy_threshold
+        self.iterations = iterations
+        self._stop_evt: Optional[threading.Event] = None
+        self._thread: Optional[threading.Thread] = None
+
+        # Lazy rocm_smi import; keep handle for reuse
+        try:
+            import rocm_smi  # type: ignore
+
+            self._rocm_smi = rocm_smi
+        except Exception as exc:  # pragma: no cover - env-specific
+            logger.debug("rocm_smi not available: %s", exc)
+            self._rocm_smi = None
+
+    def keep(self) -> None:
+        if self._thread and self._thread.is_alive():
+            logger.warning("rank %s: keep thread already running", self.rank)
+            return
+        if self._rocm_smi:
+            try:
+                self._rocm_smi.rsmi_init()
+            except Exception as exc:  # pragma: no cover - env-specific
+                logger.debug("rsmi_init failed: %s", exc)
+
+        self._stop_evt = threading.Event()
+        self._thread = threading.Thread(
+            target=self._keep_loop,
+            name=f"gpu-keeper-rocm-{self.rank}",
+            daemon=True,
+        )
+        self._thread.start()
+        logger.info("rank %s: ROCm keep thread started", self.rank)
+
+    def release(self) -> None:
+        if not (self._thread and self._thread.is_alive()):
+            logger.warning("rank %s: keep thread not running", self.rank)
+            return
+        self._stop_evt.set()
+        self._thread.join()
+        torch.cuda.empty_cache()
+        if self._rocm_smi:
+            try:
+                self._rocm_smi.rsmi_shut_down()
+            except Exception as exc:  # pragma: no cover - best effort
+                logger.debug("rsmi_shut_down failed: %s", exc)
+        logger.info("rank %s: keep thread stopped & cache cleared", self.rank)
+
+    def __enter__(self):
+        self.keep()
+        return self
+
+    def __exit__(self, exc_type, exc, tb):
+        self.release()
+
+    def _query_utilization(self) -> Optional[int]:
+        if not self._rocm_smi:
+            return None
+        try:
+            util = self._rocm_smi.rsmi_dev_busy_percent_get(self.rank)
+            return int(util)
+        except Exception as exc:  # pragma: no cover - env-specific
+            logger.debug("ROCm utilization query failed: %s", exc)
+            return None
+
+    def _keep_loop(self) -> None:
+        torch.cuda.set_device(self.rank)
+        tensor = None
+        while not self._stop_evt.is_set():
+            try:
+                tensor = torch.rand(
+                    self.vram_to_keep,
+                    device=self.device,
+                    dtype=torch.float32,
+                    requires_grad=False,
+                )
+                break
+            except RuntimeError:
+                logger.exception("rank %s: failed to allocate tensor", self.rank)
+                time.sleep(self.interval)
+        if tensor is None:
+            logger.error("rank %s: failed to allocate tensor, exiting loop", self.rank)
+            raise RuntimeError("Failed to allocate tensor for ROCm GPU keeping")
+
+        while not self._stop_evt.is_set():
+            try:
+                util = self._query_utilization()
+                if util is not None and util > self.busy_threshold:
+                    logger.debug("rank %s: GPU busy (%d%%), sleeping", self.rank, util)
+                else:
+                    self._run_batch(tensor)
+                time.sleep(self.interval)
+            except RuntimeError as exc:
+                if "out of memory" in str(exc).lower():
+                    torch.cuda.empty_cache()
+                time.sleep(self.interval)
+            except Exception:
+                logger.exception("rank %s: unexpected error", self.rank)
+                time.sleep(self.interval)
+
+    @torch.no_grad()
+    def _run_batch(self, tensor: torch.Tensor) -> None:
+        tic = time.time()
+        for _ in range(self.iterations):
+            torch.relu_(tensor)
+            if self._stop_evt.is_set():
+                break
+        torch.cuda.synchronize()
+        toc = time.time()
+        logger.debug(
+            "rank %s: elementwise batch done - avg %.2f ms",
+            self.rank,
+            (toc - tic) * 1000 / max(1, self.iterations),
+        )
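
Like its CUDA counterpart, the controller doubles as a context manager (`__enter__` starts the keep thread, `__exit__` releases it), so a single-GPU keep-alive can be scoped to a block:

```python
from keep_gpu.single_gpu_controller.rocm_gpu_controller import RocmGPUController

# Values are illustrative; all arguments are keyword-only per the signature above.
with RocmGPUController(rank=0, vram_to_keep="500 MB", interval=2.0):
    prepare_next_batch()  # hypothetical host-side work to overlap
```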

tests/conftest.py

Lines changed: 31 additions & 0 deletions (new file)

@@ -0,0 +1,31 @@
+import pytest
+import torch
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--run-rocm",
+        action="store_true",
+        default=False,
+        help="run tests marked as rocm (require ROCm stack)",
+    )
+
+
+def pytest_configure(config):
+    config.addinivalue_line("markers", "rocm: tests that require ROCm stack")
+    config.addinivalue_line("markers", "large_memory: tests that use large VRAM")
+
+
+def pytest_collection_modifyitems(config, items):
+    if config.getoption("--run-rocm"):
+        return
+
+    skip_rocm = pytest.mark.skip(reason="need --run-rocm option to run")
+    for item in items:
+        if "rocm" in item.keywords:
+            item.add_marker(skip_rocm)
+
+
+@pytest.fixture
+def rocm_available():
+    return bool(torch.cuda.is_available() and getattr(torch.version, "hip", None))
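
Given these hooks, a `rocm`-marked test is collected but skipped unless pytest is invoked with `--run-rocm`; a minimal sketch of the pattern the suite uses:

```python
import pytest


@pytest.mark.rocm
def test_requires_rocm(rocm_available):
    # Skipped at collection time without --run-rocm; skipped at runtime
    # if torch lacks a HIP build (see the rocm_available fixture).
    if not rocm_available:
        pytest.skip("ROCm stack not available")
```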
tests/rocm_controller/…

Lines changed: 36 additions & 0 deletions (new file)

@@ -0,0 +1,36 @@
+import sys
+
+import pytest
+
+from keep_gpu.single_gpu_controller import rocm_gpu_controller as rgc
+
+
+@pytest.mark.rocm
+def test_query_rocm_utilization_with_mock(monkeypatch, rocm_available):
+    if not rocm_available:
+        pytest.skip("ROCm stack not available")
+
+    class DummyRocmSMI:
+        calls = 0
+
+        @classmethod
+        def rsmi_init(cls):
+            cls.calls += 1
+
+        @staticmethod
+        def rsmi_dev_busy_percent_get(index):
+            assert index == 1
+            return 42
+
+        @classmethod
+        def rsmi_shut_down(cls):
+            cls.calls += 1
+
+    # Ensure the counter is reset to avoid leaking state between tests
+    DummyRocmSMI.calls = 0
+    monkeypatch.setitem(sys.modules, "rocm_smi", DummyRocmSMI)
+    util = rgc._query_rocm_utilization(1)
+    assert util == 42
+    assert DummyRocmSMI.calls == 2  # init + shutdown
+    # Reset after test for cleanliness
+    DummyRocmSMI.calls = 0
