Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
python-version: '3.13'

- name: Install dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ jobs:

- name: Run pre-commit (no fix)
run: |
pre-commit run --all-files --hook-stage manual --show-diff-on-failure --color always
pre-commit run --all-files --show-diff-on-failure --color always
4 changes: 2 additions & 2 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install flake8 pytest
pip install torch --index-url https://download.pytorch.org/whl/cpu
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install -e .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
Expand Down
4 changes: 3 additions & 1 deletion src/keep_gpu/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def keep(rank, args):
toc = time.time()

logger.info(
f"benchmark {rank} matmul: time span: {(toc - tic) * 1000 / 5000:.2f}ms"
"benchmark %s matmul: time span: %.2fms",
rank,
(toc - tic) * 1000 / args.matmul_iterations,
)

time.sleep(args.interval)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
controller_cls(
rank=i,
interval=interval,
vram_to_keep=vram_to_keep,
vram_to_keep=self.vram_to_keep,
busy_threshold=busy_threshold,
)
for i in self.gpu_ids
Expand Down
4 changes: 2 additions & 2 deletions src/keep_gpu/single_gpu_controller/base_gpu_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def monitor(self):

def keep(self):
"""
Method to keep the specified amount of VRAM free.
Method to keep the specified amount of VRAM busy/occupied.
Should be implemented by subclasses.
"""
raise NotImplementedError("Subclasses must implement this method.")
Expand All @@ -46,7 +46,7 @@ def rest(self):

async def _keep(self):
"""
Asynchronous method to keep the specified amount of VRAM free.
Asynchronous method to keep the specified amount of VRAM busy/occupied.
This is a placeholder for subclasses to implement their logic.
"""
raise NotImplementedError("Subclasses must implement this method.")
Expand Down
9 changes: 6 additions & 3 deletions src/keep_gpu/single_gpu_controller/cuda_gpu_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
Args:
rank (int): Local CUDA device index to occupy.
interval (float, optional): Sleep time (seconds) between workload
batches. Defaults to 0.5.
batches. Defaults to 1.0.
matmul_iterations (int, optional): Number of matmul ops per batch.
vram_to_keep (int or str, optional): Amount of VRAM to keep busy,
e.g. `"1000 MB"`, `"20 GB"`, or an integer like `1000 * 1000`.
Expand Down Expand Up @@ -126,8 +126,11 @@ def _keep_loop(self) -> None:
matrix = None
while not self._stop_evt.is_set():
try:
num_elements = int(self.vram_to_keep)
if num_elements <= 0:
raise ValueError("vram_to_keep must be positive")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Catch invalid VRAM sizes before worker thread starts

The new num_elements <= 0 guard raises ValueError inside _keep_loop, but this loop only catches RuntimeError, so keep() can report success and then the background thread immediately dies when vram_to_keep resolves to 0 (for example --vram 0 from CLI/API). This leaves the controller inactive without a synchronous failure path; validation should happen before starting the thread or ValueError should be handled and surfaced.

Useful? React with 👍 / 👎.

matrix = torch.rand(
self.vram_to_keep,
num_elements,
device=self.device,
dtype=torch.float32,
requires_grad=False,
Expand Down Expand Up @@ -166,7 +169,7 @@ def _keep_loop(self) -> None:
# ------------------------------------------------------------------
@torch.no_grad()
def _run_mat_batch(self, matrix: torch.Tensor) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The method name _run_mat_batch is misleading as the implementation uses torch.relu_ instead of a matrix multiplication. The updated docstring correctly identifies the operation as ReLU, which makes the method name even more confusing.

To improve clarity and maintainability, consider renaming this method to something like _run_relu_batch.

Additionally, the parameter matmul_iterations and the debug log message "mat ops batch done" are also inconsistent with the ReLU operation and should be updated accordingly throughout the class.

"""Run a batch of dummy matmuls to keep GPU busy."""
"""Run a batch of in-place ReLU ops to keep GPU busy."""

tic = time.time()
for _ in range(self.matmul_iterations):
Expand Down
44 changes: 33 additions & 11 deletions src/keep_gpu/single_gpu_controller/rocm_gpu_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,17 @@ def __init__(
vram_to_keep: str | int = "1000 MB",
busy_threshold: int = 10,
iterations: int = 5000,
max_allocation_retries: int = 3,
):
super().__init__(vram_to_keep=vram_to_keep, interval=interval)
self.rank = rank
self.device = torch.device(f"cuda:{rank}")
self.busy_threshold = busy_threshold
self.iterations = iterations
self.max_allocation_retries = max_allocation_retries
self._stop_evt: Optional[threading.Event] = None
self._thread: Optional[threading.Thread] = None
self._failure_exc: Optional[Exception] = None

# Lazy rocm_smi import; keep handle for reuse
try:
Expand All @@ -46,6 +49,7 @@ def keep(self) -> None:
if self._thread and self._thread.is_alive():
logger.warning("rank %s: keep thread already running", self.rank)
return
self._failure_exc = None
if self._rocm_smi:
try:
self._rocm_smi.rsmi_init()
Expand All @@ -62,12 +66,12 @@ def keep(self) -> None:
logger.info("rank %s: ROCm keep thread started", self.rank)

def release(self) -> None:
if not (self._thread and self._thread.is_alive()):
if self._thread and self._thread.is_alive():
self._stop_evt.set()
self._thread.join()
torch.cuda.empty_cache()
else:
logger.warning("rank %s: keep thread not running", self.rank)
return
self._stop_evt.set()
self._thread.join()
torch.cuda.empty_cache()
if self._rocm_smi:
try:
self._rocm_smi.rsmi_shut_down()
Expand Down Expand Up @@ -95,21 +99,35 @@ def _query_utilization(self) -> Optional[int]:
def _keep_loop(self) -> None:
torch.cuda.set_device(self.rank)
tensor = None
while not self._stop_evt.is_set():
attempts = 0
while not self._stop_evt.is_set() and attempts < self.max_allocation_retries:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep retrying ROCm allocation instead of exiting early

The startup allocation loop now stops after max_allocation_retries attempts and returns, which means the keeper thread permanently exits if VRAM is temporarily unavailable at startup; if memory frees up moments later, this controller never resumes unless the caller manually restarts it. This is a functional regression from the previous indefinite retry behavior and can cause keep-alive to silently stop in transiently busy environments.

Useful? React with 👍 / 👎.

try:
num_elements = int(self.vram_to_keep)
if num_elements <= 0:
raise ValueError("vram_to_keep must be positive")
tensor = torch.rand(
self.vram_to_keep,
num_elements,
device=self.device,
dtype=torch.float32,
requires_grad=False,
)
break
except RuntimeError:
logger.exception("rank %s: failed to allocate tensor", self.rank)
except (RuntimeError, ValueError) as exc:
attempts += 1
logger.error(
"rank %s: failed to allocate tensor (attempt %d/%d): %s",
self.rank,
attempts,
self.max_allocation_retries,
exc,
)
time.sleep(self.interval)
if tensor is None:
logger.error("rank %s: failed to allocate tensor, exiting loop", self.rank)
raise RuntimeError("Failed to allocate tensor for ROCm GPU keeping")
self._failure_exc = RuntimeError(
f"rank {self.rank}: failed to allocate tensor after {attempts} attempts"
)
logger.error("%s", self._failure_exc)
return

while not self._stop_evt.is_set():
try:
Expand All @@ -127,6 +145,10 @@ def _keep_loop(self) -> None:
logger.exception("rank %s: unexpected error", self.rank)
time.sleep(self.interval)

def allocation_status(self) -> Optional[Exception]:
"""Return allocation failure exception captured in worker thread, if any."""
return self._failure_exc

@torch.no_grad()
def _run_batch(self, tensor: torch.Tensor) -> None:
tic = time.time()
Expand Down
11 changes: 10 additions & 1 deletion src/keep_gpu/utilities/humanized_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,23 @@


def parse_size(text: str) -> int:
"""
Parse human-readable memory strings into float32 element counts.

The return value is the number of float32 elements needed to occupy the
requested memory size. When no unit is provided, the default unit is GB.
Supported units are the keys in `_UNITS`.
"""
text = text.strip().replace(" ", "")
m = re.fullmatch(r"([0-9]*\.?[0-9]+)([A-Za-z]*)", text)
if not m:
raise ValueError(f"invalid format: {text}, should be like '1000 MB'")
value, unit = m.groups()
unit = unit or "GB"
if len(unit) > 1:
unit = unit[:-1].upper() + unit[-1]
# Treat all-lowercase units as byte units ("gb" -> "GB", "gib" -> "GIB")
# while preserving explicit mixed-case bit forms ("Gb", "GIb").
unit = unit.upper() if unit.islower() else unit[:-1].upper() + unit[-1]
if unit not in _UNITS:
raise ValueError(f"unknown unit: {unit}, should be one of {_UNITS.keys()}")
return int(float(value) * _UNITS[unit] / 4) # float32 takes 4 bytes
11 changes: 7 additions & 4 deletions src/keep_gpu/utilities/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,13 @@ def _build_console_handler(level: int) -> logging.Handler:
"""Create a colored console handler with filename:lineno."""
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(level)
fmt = "%(log_color)s%(asctime)s [%(levelname)s] %(name)s (%(filename)s:%(lineno)d): %(message)s"
color_fmt = "%(log_color)s%(asctime)s [%(levelname)s] %(name)s (%(filename)s:%(lineno)d): %(message)s"
plain_fmt = (
"%(asctime)s [%(levelname)s] %(name)s (%(filename)s:%(lineno)d): %(message)s"
)
if ColoredFormatter:
formatter = ColoredFormatter(
fmt,
color_fmt,
datefmt="%H:%M:%S",
log_colors={
"DEBUG": "cyan",
Expand All @@ -43,7 +46,7 @@ def _build_console_handler(level: int) -> logging.Handler:
},
)
else:
formatter = logging.Formatter(fmt, "%H:%M:%S")
formatter = logging.Formatter(plain_fmt, "%H:%M:%S")
handler.setFormatter(formatter)
return handler

Expand All @@ -53,7 +56,7 @@ def _build_file_handler(
) -> logging.Handler:
"""Create a file handler with filename:lineno."""
log_dir = Path(log_dir)
log_dir.mkdir(exist_ok=True)
log_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
file_path = log_dir / f"{name}_{timestamp}.log"

Expand Down
10 changes: 8 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
import torch


def pytest_addoption(parser):
Expand All @@ -23,4 +22,11 @@ def pytest_collection_modifyitems(config, items):

@pytest.fixture
def rocm_available():
return bool(torch.cuda.is_available() and getattr(torch.version, "hip", None))
try:
import torch
except Exception:
return False
try:
return bool(torch.cuda.is_available() and getattr(torch.version, "hip", None))
except Exception:
return False
24 changes: 15 additions & 9 deletions tests/cuda_controller/context_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import time
import torch
import pytest
import torch

from keep_gpu.single_gpu_controller.cuda_gpu_controller import CudaGPUController


Expand All @@ -10,15 +11,20 @@
)
def test_cuda_controller_context_manager():
ctrl = CudaGPUController(
rank=torch.cuda.device_count() - 1, interval=10, vram_to_keep="1GB"
rank=torch.cuda.device_count() - 1,
interval=0.05,
vram_to_keep="8MB",
matmul_iterations=64,
)

torch.cuda.set_device(ctrl.rank)
before_reserved = torch.cuda.memory_reserved(ctrl.rank)
with ctrl:
print("GPU kept busy for 10 seconds.")
time.sleep(10)
print("GPU released.")
print("Test completed successfully.")

time.sleep(0.3)
assert ctrl._thread and ctrl._thread.is_alive()
during_reserved = torch.cuda.memory_reserved(ctrl.rank)
assert during_reserved >= before_reserved

if __name__ == "__main__":
test_cuda_controller_context_manager()
if ctrl._thread:
ctrl._thread.join(timeout=2)
assert not (ctrl._thread and ctrl._thread.is_alive())
19 changes: 10 additions & 9 deletions tests/cuda_controller/test_2_32pow_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,16 @@
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_large_vram_allocation():
"""Tests controller with a large VRAM allocation."""
# Using a smaller allocation for general testing. The original 2**32 can be used on machines with sufficient VRAM.
# torch has some indexing issues on very large tensors
# e.g. tensors with more than 2**32-1 elements may cause issues
# just a test to see if it is real.
vram_elements = 2**32 # Allocates 16GiB to test large tensor handling
# Intentionally using full 2**32 float32 elements (~16 GiB) for large-tensor testing.
# Torch may expose indexing issues around this boundary on some systems.
vram_elements = 2**32
required_bytes = vram_elements * 4
free_bytes, _ = torch.cuda.mem_get_info(0)
if free_bytes < required_bytes:
pytest.skip(
f"Insufficient free VRAM for large test: need {required_bytes}, have {free_bytes}"
)

controller = CudaGPUController(
rank=0,
interval=0.5,
Expand All @@ -27,7 +32,3 @@ def test_large_vram_allocation():
assert controller._thread is not None and controller._thread.is_alive()
finally:
controller.release()


if __name__ == "__main__":
test_large_vram_allocation()
10 changes: 6 additions & 4 deletions tests/cuda_controller/test_keep_and_release.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,22 @@ def test_cuda_controller_basic():
time.sleep(0.2)
assert ctrl._thread and ctrl._thread.is_alive()

assert ctrl._thread is not None
ctrl.release()
ctrl._thread.join(timeout=2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This call to _thread.join(timeout=2) is redundant. The ctrl.release() method, called on the previous line, already blocks and waits for the thread to terminate by calling _thread.join() internally. Adding another join here is unnecessary and can be removed to simplify the test. This redundant pattern is repeated multiple times in this test file.

If there's a concern about the thread hanging, the timeout logic should be part of the release() method itself.

assert not (ctrl._thread and ctrl._thread.is_alive())

ctrl.keep()
time.sleep(0.2)
assert ctrl._thread and ctrl._thread.is_alive()
assert ctrl._thread is not None
ctrl.release()
ctrl._thread.join(timeout=2)
assert not (ctrl._thread and ctrl._thread.is_alive())

with ctrl:
assert ctrl._thread and ctrl._thread.is_alive()
time.sleep(0.2)
assert ctrl._thread is not None
ctrl._thread.join(timeout=2)
assert not (ctrl._thread and ctrl._thread.is_alive())


if __name__ == "__main__":
test_cuda_controller_basic()
4 changes: 0 additions & 4 deletions tests/global_controller/global_keep_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,3 @@ def test_global_controller():
controller.release()
for ctrl in controller.controllers:
assert not (ctrl._thread and ctrl._thread.is_alive())


if __name__ == "__main__":
test_global_controller()