Skip to content

Commit 68d2030

Browse files
committed
Adds option to override the file path to the resource usage pickle file
1 parent 15309e2 commit 68d2030

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

src/gpu_tracker/tracker.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ class State(enum.Enum):
297297
def __init__(
298298
self, sleep_time: float = 1.0, ram_unit: str = 'gigabytes', gpu_ram_unit: str = 'gigabytes', time_unit: str = 'hours',
299299
n_expected_cores: int = None, gpu_uuids: set[str] = None, disable_logs: bool = False, process_id: int = None,
300-
n_join_attempts: int = 5, join_timeout: float = 10.0):
300+
resource_usage_file: str | None = None, n_join_attempts: int = 5, join_timeout: float = 10.0):
301301
"""
302302
:param sleep_time: The number of seconds to sleep in between usage-collection iterations.
303303
:param ram_unit: One of 'bytes', 'kilobytes', 'megabytes', 'gigabytes', or 'terabytes'.
@@ -307,6 +307,7 @@ def __init__(
307307
:param gpu_uuids: The UUIDs of the GPUs to track utilization for. The length of this set is used as the denominator when calculating the hardware percentages of the GPU utilization (i.e. n_expected_gpus). Defaults to all the GPUs in the system.
308308
:param disable_logs: If set, warnings are suppressed during tracking. Otherwise, the Tracker logs warnings as usual.
309309
:param process_id: The ID of the process to track. Defaults to the current process.
310+
:param resource_usage_file: The file path to the pickle file containing the ``resource_usage`` attribute. This file is automatically deleted and the ``resource_usage`` attribute is set in memory if the tracking successfully completes. But if the tracking is interrupted, the tracking information will be saved in this file as a backup. Defaults to a randomly generated file name in the current working directory of the format ``.gpu-tracker_<random UUID>.pkl``.
310311
:param n_join_attempts: The number of times the tracker attempts to join its underlying sub-process.
311312
:param join_timeout: The amount of time the tracker waits for its underlying sub-process to join.
312313
:raises ValueError: Raised if invalid units are provided.
@@ -319,7 +320,7 @@ def __init__(
319320
legit_child_ids = {process.pid for process in current_process.children()}
320321
self._stop_event = mproc.Event()
321322
extraneous_ids = {process.pid for process in current_process.children()} - legit_child_ids
322-
self._resource_usage_file = f'.gpu-tracker_{uuid.uuid1()}.pkl'
323+
self._resource_usage_file = f'.gpu-tracker_{uuid.uuid1()}.pkl' if resource_usage_file is None else resource_usage_file
323324
self._tracking_process = _TrackingProcess(
324325
self._stop_event, sleep_time, ram_unit, gpu_ram_unit, time_unit, n_expected_cores, gpu_uuids, disable_logs,
325326
process_id if process_id is not None else current_process_id, self._resource_usage_file, extraneous_ids)

tests/test_tracker.py

+15
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,18 @@ def test_state(mocker):
313313
with pt.raises(RuntimeError) as error:
314314
tracker.stop()
315315
assert str(error.value) == 'Cannot stop tracking when tracking has already stopped.'
316+
317+
318+
def test_resource_usage_file(mocker):
    """
    Check that a caller-supplied ``resource_usage_file`` path is the one the tracking process writes to.

    The multiprocessing ``Event`` used by the tracker is replaced with a stand-in whose ``is_set()`` always
    returns ``True``, so the tracking process can be run directly in this process.
    NOTE(review): presumably an already-set stop event makes ``run()`` return after a single pass --
    confirm against ``_TrackingProcess.run``.
    """
    class EventMock:
        # Stand-in for the stop event: always reports that it has been set.
        @staticmethod
        def is_set() -> bool:
            return True

    mocker.patch('gpu_tracker.tracker.mproc.Event', wraps=EventMock)
    file_path = 'my-file.pkl'
    tracker = gput.Tracker(resource_usage_file=file_path)
    try:
        # Constructing the tracker alone must not create the file.
        assert not os.path.isfile(file_path)
        tracker._tracking_process.run()
        # Running the tracking process must create the file at the overridden path.
        assert os.path.isfile(file_path)
    finally:
        # Always clean up, even when an assertion fails, so a stale file from a failed run
        # cannot leak into the working directory or break the first assertion on the next run.
        if os.path.isfile(file_path):
            os.remove(file_path)

0 commit comments

Comments
 (0)