diff --git a/src/gpu_tracker/tracker.py b/src/gpu_tracker/tracker.py index 724582a..1b88462 100644 --- a/src/gpu_tracker/tracker.py +++ b/src/gpu_tracker/tracker.py @@ -297,7 +297,7 @@ class State(enum.Enum): def __init__( self, sleep_time: float = 1.0, ram_unit: str = 'gigabytes', gpu_ram_unit: str = 'gigabytes', time_unit: str = 'hours', n_expected_cores: int = None, gpu_uuids: set[str] = None, disable_logs: bool = False, process_id: int = None, - n_join_attempts: int = 5, join_timeout: float = 10.0): + resource_usage_file: str | None = None, n_join_attempts: int = 5, join_timeout: float = 10.0): """ :param sleep_time: The number of seconds to sleep in between usage-collection iterations. :param ram_unit: One of 'bytes', 'kilobytes', 'megabytes', 'gigabytes', or 'terabytes'. @@ -307,6 +307,7 @@ def __init__( :param gpu_uuids: The UUIDs of the GPUs to track utilization for. The length of this set is used as the denominator when calculating the hardware percentages of the GPU utilization (i.e. n_expected_gpus). Defaults to all the GPUs in the system. :param disable_logs: If set, warnings are suppressed during tracking. Otherwise, the Tracker logs warnings as usual. :param process_id: The ID of the process to track. Defaults to the current process. + :param resource_usage_file: The file path to the pickle file containing the ``resource_usage`` attribute. This file is automatically deleted and the ``resource_usage`` attribute is set in memory if the tracking successfully completes. But if the tracking is interrupted, the tracking information will be saved in this file as a backup. Defaults to a randomly generated file name in the current working directory of the format ``.gpu-tracker_<UUID>.pkl``. :param n_join_attempts: The number of times the tracker attempts to join its underlying sub-process. :param join_timeout: The amount of time the tracker waits for its underlying sub-process to join. :raises ValueError: Raised if invalid units are provided. 
@@ -319,7 +320,7 @@ def __init__( legit_child_ids = {process.pid for process in current_process.children()} self._stop_event = mproc.Event() extraneous_ids = {process.pid for process in current_process.children()} - legit_child_ids - self._resource_usage_file = f'.gpu-tracker_{uuid.uuid1()}.pkl' + self._resource_usage_file = f'.gpu-tracker_{uuid.uuid1()}.pkl' if resource_usage_file is None else resource_usage_file self._tracking_process = _TrackingProcess( self._stop_event, sleep_time, ram_unit, gpu_ram_unit, time_unit, n_expected_cores, gpu_uuids, disable_logs, process_id if process_id is not None else current_process_id, self._resource_usage_file, extraneous_ids) diff --git a/tests/test_tracker.py b/tests/test_tracker.py index 6d923ed..4c8cc30 100644 --- a/tests/test_tracker.py +++ b/tests/test_tracker.py @@ -313,3 +313,18 @@ def test_state(mocker): with pt.raises(RuntimeError) as error: tracker.stop() assert str(error.value) == 'Cannot stop tracking when tracking has already stopped.' + + +def test_resource_usage_file(mocker): + class EventMock: + @staticmethod + def is_set() -> bool: + return True + + mocker.patch('gpu_tracker.tracker.mproc.Event', wraps=EventMock) + file_path = 'my-file.pkl' + tracker = gput.Tracker(resource_usage_file=file_path) + assert not os.path.isfile(file_path) + tracker._tracking_process.run() + assert os.path.isfile(file_path) + os.remove(file_path)