 import os
 import time
 import uuid
-from multiprocessing import Manager
-from multiprocessing.managers import SyncManager
-from typing import Any, Dict, Optional
+import pickle
+import fcntl
+import tempfile
+from typing import Any, Dict, Optional, Set

 from .rp_logger import RunPodLogger

@@ -63,149 +64,150 @@ def __str__(self) -> str:
 # ---------------------------------------------------------------------------- #
 #                                   Tracker                                    #
 # ---------------------------------------------------------------------------- #
-class JobsProgress:
-    """Track the state of current jobs in progress using shared memory."""
-
-    _instance: Optional['JobsProgress'] = None
-    _manager: SyncManager
-    _shared_data: Any
-    _lock: Any
+class JobsProgress(Set[Job]):
+    """Track the state of current jobs in progress with persistent state."""
+
+    _instance = None
+    _STATE_DIR = os.getcwd()
+    _STATE_FILE = os.path.join(_STATE_DIR, ".runpod_jobs.pkl")

     def __new__(cls):
-        if cls._instance is None:
-            instance = object.__new__(cls)
-            # Initialize instance variables
-            instance._manager = Manager()
-            instance._shared_data = instance._manager.dict()
-            instance._shared_data['jobs'] = instance._manager.list()
-            instance._lock = instance._manager.Lock()
-            cls._instance = instance
-        return cls._instance
+        if JobsProgress._instance is None:
+            os.makedirs(cls._STATE_DIR, exist_ok=True)
+            JobsProgress._instance = set.__new__(cls)
+            # Initialize as empty set before loading state
+            set.__init__(JobsProgress._instance)
+            JobsProgress._instance._load_state()
+        return JobsProgress._instance

     def __init__(self):
-        # Everything is already initialized in __new__
+        # This should never clear data in a singleton
+        # Don't call parent __init__ as it would clear the set
         pass
-
+
     def __repr__(self) -> str:
         return f"<{self.__class__.__name__}>: {self.get_job_list()}"

+    def _load_state(self):
+        """Load jobs state from pickle file with file locking."""
+        try:
+            if (
+                os.path.exists(self._STATE_FILE)
+                and os.path.getsize(self._STATE_FILE) > 0
+            ):
+                with open(self._STATE_FILE, "rb") as f:
+                    fcntl.flock(f, fcntl.LOCK_SH)
+                    try:
+                        loaded_jobs = pickle.load(f)
+                        # Clear current state and add loaded jobs
+                        super().clear()
+                        for job in loaded_jobs:
+                            set.add(
+                                self, job
+                            )  # Use set.add to avoid triggering _save_state
+
+                    except (EOFError, pickle.UnpicklingError):
+                        # Handle empty or corrupted file
+                        log.debug(
+                            "JobsProgress: Failed to load state file, starting with empty state"
+                        )
+                        pass
+                    finally:
+                        fcntl.flock(f, fcntl.LOCK_UN)
+
+        except FileNotFoundError:
+            log.debug("JobsProgress: No state file found, starting with empty state")
+            pass
+
+    def _save_state(self):
+        """Save jobs state to pickle file with atomic write and file locking."""
+        try:
+            # Use temporary file for atomic write
+            with tempfile.NamedTemporaryFile(
+                dir=self._STATE_DIR, delete=False, mode="wb"
+            ) as temp_f:
+                fcntl.flock(temp_f, fcntl.LOCK_EX)
+                try:
+                    pickle.dump(set(self), temp_f)
+                finally:
+                    fcntl.flock(temp_f, fcntl.LOCK_UN)
+
+            # Atomically replace the state file
+            os.replace(temp_f.name, self._STATE_FILE)
+        except Exception as e:
+            log.error(f"Failed to save job state: {e}")
+
     def clear(self) -> None:
-        with self._lock:
-            self._shared_data['jobs'][:] = []
+        super().clear()
+        self._save_state()

     def add(self, element: Any):
         """
         Adds a Job object to the set.
-        """
-        if isinstance(element, str):
-            job_dict = {'id': element}
-        elif isinstance(element, dict):
-            job_dict = element
-        elif hasattr(element, 'id'):
-            job_dict = {'id': element.id}
-        else:
-            raise TypeError("Only Job objects can be added to JobsProgress.")

-        with self._lock:
-            # Check if job already exists
-            job_list = self._shared_data['jobs']
-            for existing_job in job_list:
-                if existing_job['id'] == job_dict['id']:
-                    return  # Job already exists
-
-            # Add new job
-            job_list.append(job_dict)
-            log.debug(f"JobsProgress | Added job: {job_dict['id']}")
-
-    def get(self, element: Any) -> Optional[Job]:
-        """
-        Retrieves a Job object from the set.
+        If the added element is a string, then `Job(id=element)` is added

-        If the element is a string, searches for Job with that id.
+        If the added element is a dict, then `Job(**element)` is added
         """
         if isinstance(element, str):
-            search_id = element
-        elif isinstance(element, Job):
-            search_id = element.id
-        else:
-            raise TypeError("Only Job objects can be retrieved from JobsProgress.")
+            element = Job(id=element)

-        with self._lock:
-            for job_dict in self._shared_data['jobs']:
-                if job_dict['id'] == search_id:
-                    log.debug(f"JobsProgress | Retrieved job: {job_dict['id']}")
-                    return Job(**job_dict)
-
-        return None
+        if isinstance(element, dict):
+            element = Job(**element)
+
+        if not isinstance(element, Job):
+            raise TypeError("Only Job objects can be added to JobsProgress.")
+
+        result = super().add(element)
+        self._save_state()
+        return result

     def remove(self, element: Any):
         """
         Removes a Job object from the set.
+
+        If the element is a string, then `Job(id=element)` is removed
+
+        If the element is a dict, then `Job(**element)` is removed
         """
         if isinstance(element, str):
-            job_id = element
-        elif isinstance(element, dict):
-            job_id = element.get('id')
-        elif hasattr(element, 'id'):
-            job_id = element.id
-        else:
+            element = Job(id=element)
+
+        if isinstance(element, dict):
+            element = Job(**element)
+
+        if not isinstance(element, Job):
             raise TypeError("Only Job objects can be removed from JobsProgress.")

-        with self._lock:
-            job_list = self._shared_data['jobs']
-            # Find and remove the job
-            for i, job_dict in enumerate(job_list):
-                if job_dict['id'] == job_id:
-                    del job_list[i]
-                    log.debug(f"JobsProgress | Removed job: {job_dict['id']}")
-                    break
+        result = super().discard(element)
+        self._save_state()
+        return result
+
+    def get(self, element: Any) -> Optional[Job]:
+        if isinstance(element, str):
+            element = Job(id=element)
+
+        if not isinstance(element, Job):
+            raise TypeError("Only Job objects can be retrieved from JobsProgress.")
+
+        for job in self:
+            if job == element:
+                return job
+        return None

     def get_job_list(self) -> Optional[str]:
         """
         Returns the list of job IDs as comma-separated string.
         """
-        with self._lock:
-            job_list = list(self._shared_data['jobs'])
-
-        if not job_list:
+        self._load_state()
+
+        if not len(self):
             return None

-        log.debug(f"JobsProgress | Jobs in progress: {job_list}")
-        return ",".join(str(job_dict['id']) for job_dict in job_list)
+        return ",".join(str(job) for job in self)

     def get_job_count(self) -> int:
         """
         Returns the number of jobs.
         """
-        with self._lock:
-            return len(self._shared_data['jobs'])
-
-    def __iter__(self):
-        """Make the class iterable - returns Job objects"""
-        with self._lock:
-            # Create a snapshot of jobs to avoid holding lock during iteration
-            job_dicts = list(self._shared_data['jobs'])
-
-        # Return an iterator of Job objects
-        return iter(Job(**job_dict) for job_dict in job_dicts)
-
-    def __len__(self):
-        """Support len() operation"""
-        return self.get_job_count()
-
-    def __contains__(self, element: Any) -> bool:
-        """Support 'in' operator"""
-        if isinstance(element, str):
-            search_id = element
-        elif isinstance(element, Job):
-            search_id = element.id
-        elif isinstance(element, dict):
-            search_id = element.get('id')
-        else:
-            return False
-
-        with self._lock:
-            for job_dict in self._shared_data['jobs']:
-                if job_dict['id'] == search_id:
-                    return True
-        return False
+        return len(self)
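A minimal usage sketch (not part of the diff) of how the persisted singleton is expected to behave; the import path is an assumption based on where this module lives in the SDK, and the job IDs are illustrative:

    # Sketch only: exercising JobsProgress as defined in the diff above.
    # Assumes the class is exposed from runpod.serverless.modules.worker_state
    # (the import path may differ).
    from runpod.serverless.modules.worker_state import JobsProgress

    progress = JobsProgress()          # first call creates the singleton and loads .runpod_jobs.pkl
    assert progress is JobsProgress()  # later calls return the same instance

    progress.add("job-123")            # string is wrapped as Job(id="job-123"), then pickled to disk
    progress.add({"id": "job-456"})    # dict is wrapped as Job(**element)

    print(progress.get_job_list())     # e.g. "job-123,job-456" (set order is not guaranteed)
    print(progress.get_job_count())    # 2

    progress.remove("job-123")         # discards the job and rewrites the state file atomically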