[tests, cuda] test: strengthen VRAM target and release assertions for keep sessions #64
First hunk of the diff (`@@ -1,9 +1,19 @@`):

```python
import time

import pytest
import torch

from keep_gpu.single_gpu_controller.cuda_gpu_controller import CudaGPUController


def _wait_until(predicate, timeout_s: float = 3.0, interval_s: float = 0.05) -> bool:
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        if predicate():
            return True
        time.sleep(interval_s)
    return False


@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="Only run CUDA tests when CUDA is available",
```
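The `_wait_until` helper introduced in this hunk is a generic bounded-polling loop: it returns `True` as soon as the predicate holds and `False` once the timeout expires. A self-contained sketch of the same semantics, with a hypothetical slow condition standing in for VRAM reaching its target:

```python
import time


def _wait_until(predicate, timeout_s: float = 3.0, interval_s: float = 0.05) -> bool:
    # Same shape as the helper in the diff: poll until truthy or deadline passes.
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        if predicate():
            return True
        time.sleep(interval_s)
    return False


# Hypothetical condition that becomes true after ~0.2 s.
start = time.time()
became_true = _wait_until(lambda: time.time() - start > 0.2, timeout_s=1.0)

# A predicate that never holds makes the helper return False after the timeout.
timed_out = not _wait_until(lambda: False, timeout_s=0.3)
```

Because the deadline is checked before each poll, a predicate that is already true returns on the first iteration without sleeping.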
Second hunk (`@@ -32,3 +42,59 @@ def test_cuda_controller_basic():`):

```python
    assert ctrl._thread and ctrl._thread.is_alive()
    time.sleep(0.2)
    assert not (ctrl._thread and ctrl._thread.is_alive())


@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="Only run CUDA tests when CUDA is available",
)
def test_cuda_controller_respects_vram_target_during_keep():
    """Ensure keep() consumes roughly requested VRAM and release() frees it."""
    ctrl = CudaGPUController(
        rank=0,
        interval=0.05,
        vram_to_keep="32MB",
        relu_iterations=32,
    )
    torch.cuda.set_device(ctrl.rank)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    target_bytes = int(ctrl.vram_to_keep) * 4
    free_bytes, _ = torch.cuda.mem_get_info(ctrl.rank)
    if free_bytes < int(target_bytes * 1.2):
        pytest.skip(
            f"Insufficient free VRAM for assertion test: need ~{target_bytes}, have {free_bytes}"
        )

    before_alloc = torch.cuda.memory_allocated(ctrl.rank)
    before_reserved = torch.cuda.memory_reserved(ctrl.rank)
    alloc_tolerance = 8 * 1024 * 1024
    reserve_tolerance = 16 * 1024 * 1024

    ctrl.keep()

    reached = _wait_until(
        lambda: (
            max(0, torch.cuda.memory_allocated(ctrl.rank) - before_alloc)
            >= int(target_bytes * 0.95)
        ),
        timeout_s=3.0,
    )
    assert reached, "keep() did not reach expected VRAM allocation target in time"

    alloc_delta = max(0, torch.cuda.memory_allocated(ctrl.rank) - before_alloc)
    reserved_delta = max(0, torch.cuda.memory_reserved(ctrl.rank) - before_reserved)
    assert alloc_delta >= int(target_bytes * 0.95)
    assert reserved_delta >= alloc_delta

    ctrl.release()
    released = _wait_until(
        lambda: (
            max(0, torch.cuda.memory_allocated(ctrl.rank) - before_alloc)
            <= alloc_tolerance
            and max(0, torch.cuda.memory_reserved(ctrl.rank) - before_reserved)
            <= reserve_tolerance
        ),
        timeout_s=3.0,
    )
    assert released, "VRAM did not return near baseline after release()"
```
Reviewer comment:

This polling loop is very similar to the release-checking loop below (lines 69-81). Furthermore, in `test_keep_and_release.py` you've introduced a `_wait_until` helper function which abstracts this exact polling pattern.

To improve consistency and reduce code duplication across the test suite, consider refactoring this test to also use a polling helper. You could move `_wait_until` to a shared test utility file and import it in both test files. This would make the test logic cleaner and easier to maintain.
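One way to follow this suggestion, sketched under the assumption of a standard pytest layout; the `tests/_utils.py` path and the import below are hypothetical, not part of the PR:

```python
# tests/_utils.py (hypothetical shared test utility module)
import time
from typing import Callable


def wait_until(predicate: Callable[[], bool],
               timeout_s: float = 3.0,
               interval_s: float = 0.05) -> bool:
    """Poll `predicate` until it returns truthy or `timeout_s` elapses."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        if predicate():
            return True
        time.sleep(interval_s)
    return False
```

Both test files would then replace their local copies with `from tests._utils import wait_until`, leaving a single implementation to maintain.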