Skip to content

Commit e58d519

Browse files
authored
fix: multiprocess broke local development (#430)
* restored worker_state.py to pre-multiprocess
* refactor: JobsProgress persists from pickled file .runpod_jobs.pkl
* fix: added and updated tests for pkl-persisted JobsProgress
* feat: worker simulation to rehearse how jobs scheduling + heartbeat perform
1 parent 6edf3a0 commit e58d519

File tree

15 files changed

+556
-187
lines changed

15 files changed

+556
-187
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,4 @@ dmypy.json
137137
# Pyre type checker
138138
.pyre/
139139
runpod/_version.py
140+
.runpod_jobs.pkl

runpod/serverless/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from . import worker
1717
from .modules import rp_fastapi
1818
from .modules.rp_logger import RunPodLogger
19-
from .modules.rp_progress import progress_update
2019

2120
log = RunPodLogger()
2221

runpod/serverless/modules/rp_job.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ async def handle_job(session: ClientSession, config: Dict[str, Any], job) -> dic
149149
job_result["stopPod"] = True
150150

151151
# If rp_debugger is set, debugger output will be returned.
152-
if config["rp_args"].get("rp_debugger", False) and isinstance(job_result, dict):
152+
if config.get("rp_args", {}).get("rp_debugger", False) and isinstance(job_result, dict):
153153
job_result["output"]["rp_debugger"] = rp_debugger.get_debugger_output()
154154
log.debug("rp_debugger | Flag set, returning debugger output.", job["id"])
155155

runpod/serverless/modules/rp_ping.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from runpod.version import __version__ as runpod_version
1616

1717
log = RunPodLogger()
18-
jobs = JobsProgress() # Contains the list of jobs that are currently running.
1918

2019

2120
class Heartbeat:
@@ -97,6 +96,7 @@ def _send_ping(self):
9796
"""
9897
Sends a heartbeat to the Runpod server.
9998
"""
99+
jobs = JobsProgress() # Get the singleton instance
100100
job_ids = jobs.get_job_list()
101101
ping_params = {"job_id": job_ids, "runpod_version": runpod_version}
102102

runpod/serverless/modules/rp_scale.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from .worker_state import JobsProgress, IS_LOCAL_TEST
1616

1717
log = RunPodLogger()
18-
job_progress = JobsProgress()
1918

2019

2120
def _handle_uncaught_exception(exc_type, exc_value, exc_traceback):
@@ -47,6 +46,7 @@ def __init__(self, config: Dict[str, Any]):
4746
self._shutdown_event = asyncio.Event()
4847
self.current_concurrency = 1
4948
self.config = config
49+
self.job_progress = JobsProgress() # Cache the singleton instance
5050

5151
self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency)
5252

@@ -149,7 +149,7 @@ def kill_worker(self):
149149

150150
def current_occupancy(self) -> int:
151151
current_queue_count = self.jobs_queue.qsize()
152-
current_progress_count = job_progress.get_job_count()
152+
current_progress_count = self.job_progress.get_job_count()
153153

154154
log.debug(
155155
f"JobScaler.status | concurrency: {self.current_concurrency}; queue: {current_queue_count}; progress: {current_progress_count}"
@@ -188,7 +188,7 @@ async def get_jobs(self, session: ClientSession):
188188

189189
for job in acquired_jobs:
190190
await self.jobs_queue.put(job)
191-
job_progress.add(job)
191+
self.job_progress.add(job)
192192
log.debug("Job Queued", job["id"])
193193

194194
log.info(f"Jobs in queue: {self.jobs_queue.qsize()}")
@@ -268,6 +268,6 @@ async def handle_job(self, session: ClientSession, job: dict):
268268
self.jobs_queue.task_done()
269269

270270
# Job is no longer in progress
271-
job_progress.remove(job)
271+
self.job_progress.remove(job)
272272

273273
log.debug("Finished Job", job["id"])

runpod/serverless/modules/worker_state.py

Lines changed: 113 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
import os
66
import time
77
import uuid
8-
from multiprocessing import Manager
9-
from multiprocessing.managers import SyncManager
10-
from typing import Any, Dict, Optional
8+
import pickle
9+
import fcntl
10+
import tempfile
11+
from typing import Any, Dict, Optional, Set
1112

1213
from .rp_logger import RunPodLogger
1314

@@ -63,149 +64,150 @@ def __str__(self) -> str:
6364
# ---------------------------------------------------------------------------- #
6465
# Tracker #
6566
# ---------------------------------------------------------------------------- #
66-
class JobsProgress:
67-
"""Track the state of current jobs in progress using shared memory."""
68-
69-
_instance: Optional['JobsProgress'] = None
70-
_manager: SyncManager
71-
_shared_data: Any
72-
_lock: Any
67+
class JobsProgress(Set[Job]):
68+
"""Track the state of current jobs in progress with persistent state."""
69+
70+
_instance = None
71+
_STATE_DIR = os.getcwd()
72+
_STATE_FILE = os.path.join(_STATE_DIR, ".runpod_jobs.pkl")
7373

7474
def __new__(cls):
75-
if cls._instance is None:
76-
instance = object.__new__(cls)
77-
# Initialize instance variables
78-
instance._manager = Manager()
79-
instance._shared_data = instance._manager.dict()
80-
instance._shared_data['jobs'] = instance._manager.list()
81-
instance._lock = instance._manager.Lock()
82-
cls._instance = instance
83-
return cls._instance
75+
if JobsProgress._instance is None:
76+
os.makedirs(cls._STATE_DIR, exist_ok=True)
77+
JobsProgress._instance = set.__new__(cls)
78+
# Initialize as empty set before loading state
79+
set.__init__(JobsProgress._instance)
80+
JobsProgress._instance._load_state()
81+
return JobsProgress._instance
8482

8583
def __init__(self):
86-
# Everything is already initialized in __new__
84+
# This should never clear data in a singleton
85+
# Don't call parent __init__ as it would clear the set
8786
pass
88-
87+
8988
def __repr__(self) -> str:
9089
return f"<{self.__class__.__name__}>: {self.get_job_list()}"
9190

91+
def _load_state(self):
92+
"""Load jobs state from pickle file with file locking."""
93+
try:
94+
if (
95+
os.path.exists(self._STATE_FILE)
96+
and os.path.getsize(self._STATE_FILE) > 0
97+
):
98+
with open(self._STATE_FILE, "rb") as f:
99+
fcntl.flock(f, fcntl.LOCK_SH)
100+
try:
101+
loaded_jobs = pickle.load(f)
102+
# Clear current state and add loaded jobs
103+
super().clear()
104+
for job in loaded_jobs:
105+
set.add(
106+
self, job
107+
) # Use set.add to avoid triggering _save_state
108+
109+
except (EOFError, pickle.UnpicklingError):
110+
# Handle empty or corrupted file
111+
log.debug(
112+
"JobsProgress: Failed to load state file, starting with empty state"
113+
)
114+
pass
115+
finally:
116+
fcntl.flock(f, fcntl.LOCK_UN)
117+
118+
except FileNotFoundError:
119+
log.debug("JobsProgress: No state file found, starting with empty state")
120+
pass
121+
122+
def _save_state(self):
123+
"""Save jobs state to pickle file with atomic write and file locking."""
124+
try:
125+
# Use temporary file for atomic write
126+
with tempfile.NamedTemporaryFile(
127+
dir=self._STATE_DIR, delete=False, mode="wb"
128+
) as temp_f:
129+
fcntl.flock(temp_f, fcntl.LOCK_EX)
130+
try:
131+
pickle.dump(set(self), temp_f)
132+
finally:
133+
fcntl.flock(temp_f, fcntl.LOCK_UN)
134+
135+
# Atomically replace the state file
136+
os.replace(temp_f.name, self._STATE_FILE)
137+
except Exception as e:
138+
log.error(f"Failed to save job state: {e}")
139+
92140
def clear(self) -> None:
93-
with self._lock:
94-
self._shared_data['jobs'][:] = []
141+
super().clear()
142+
self._save_state()
95143

96144
def add(self, element: Any):
97145
"""
98146
Adds a Job object to the set.
99-
"""
100-
if isinstance(element, str):
101-
job_dict = {'id': element}
102-
elif isinstance(element, dict):
103-
job_dict = element
104-
elif hasattr(element, 'id'):
105-
job_dict = {'id': element.id}
106-
else:
107-
raise TypeError("Only Job objects can be added to JobsProgress.")
108147
109-
with self._lock:
110-
# Check if job already exists
111-
job_list = self._shared_data['jobs']
112-
for existing_job in job_list:
113-
if existing_job['id'] == job_dict['id']:
114-
return # Job already exists
115-
116-
# Add new job
117-
job_list.append(job_dict)
118-
log.debug(f"JobsProgress | Added job: {job_dict['id']}")
119-
120-
def get(self, element: Any) -> Optional[Job]:
121-
"""
122-
Retrieves a Job object from the set.
148+
If the added element is a string, then `Job(id=element)` is added
123149
124-
If the element is a string, searches for Job with that id.
150+
If the added element is a dict, that `Job(**element)` is added
125151
"""
126152
if isinstance(element, str):
127-
search_id = element
128-
elif isinstance(element, Job):
129-
search_id = element.id
130-
else:
131-
raise TypeError("Only Job objects can be retrieved from JobsProgress.")
153+
element = Job(id=element)
132154

133-
with self._lock:
134-
for job_dict in self._shared_data['jobs']:
135-
if job_dict['id'] == search_id:
136-
log.debug(f"JobsProgress | Retrieved job: {job_dict['id']}")
137-
return Job(**job_dict)
138-
139-
return None
155+
if isinstance(element, dict):
156+
element = Job(**element)
157+
158+
if not isinstance(element, Job):
159+
raise TypeError("Only Job objects can be added to JobsProgress.")
160+
161+
result = super().add(element)
162+
self._save_state()
163+
return result
140164

141165
def remove(self, element: Any):
142166
"""
143167
Removes a Job object from the set.
168+
169+
If the element is a string, then `Job(id=element)` is removed
170+
171+
If the element is a dict, then `Job(**element)` is removed
144172
"""
145173
if isinstance(element, str):
146-
job_id = element
147-
elif isinstance(element, dict):
148-
job_id = element.get('id')
149-
elif hasattr(element, 'id'):
150-
job_id = element.id
151-
else:
174+
element = Job(id=element)
175+
176+
if isinstance(element, dict):
177+
element = Job(**element)
178+
179+
if not isinstance(element, Job):
152180
raise TypeError("Only Job objects can be removed from JobsProgress.")
153181

154-
with self._lock:
155-
job_list = self._shared_data['jobs']
156-
# Find and remove the job
157-
for i, job_dict in enumerate(job_list):
158-
if job_dict['id'] == job_id:
159-
del job_list[i]
160-
log.debug(f"JobsProgress | Removed job: {job_dict['id']}")
161-
break
182+
result = super().discard(element)
183+
self._save_state()
184+
return result
185+
186+
def get(self, element: Any) -> Optional[Job]:
187+
if isinstance(element, str):
188+
element = Job(id=element)
189+
190+
if not isinstance(element, Job):
191+
raise TypeError("Only Job objects can be retrieved from JobsProgress.")
192+
193+
for job in self:
194+
if job == element:
195+
return job
196+
return None
162197

163198
def get_job_list(self) -> Optional[str]:
164199
"""
165200
Returns the list of job IDs as comma-separated string.
166201
"""
167-
with self._lock:
168-
job_list = list(self._shared_data['jobs'])
169-
170-
if not job_list:
202+
self._load_state()
203+
204+
if not len(self):
171205
return None
172206

173-
log.debug(f"JobsProgress | Jobs in progress: {job_list}")
174-
return ",".join(str(job_dict['id']) for job_dict in job_list)
207+
return ",".join(str(job) for job in self)
175208

176209
def get_job_count(self) -> int:
177210
"""
178211
Returns the number of jobs.
179212
"""
180-
with self._lock:
181-
return len(self._shared_data['jobs'])
182-
183-
def __iter__(self):
184-
"""Make the class iterable - returns Job objects"""
185-
with self._lock:
186-
# Create a snapshot of jobs to avoid holding lock during iteration
187-
job_dicts = list(self._shared_data['jobs'])
188-
189-
# Return an iterator of Job objects
190-
return iter(Job(**job_dict) for job_dict in job_dicts)
191-
192-
def __len__(self):
193-
"""Support len() operation"""
194-
return self.get_job_count()
195-
196-
def __contains__(self, element: Any) -> bool:
197-
"""Support 'in' operator"""
198-
if isinstance(element, str):
199-
search_id = element
200-
elif isinstance(element, Job):
201-
search_id = element.id
202-
elif isinstance(element, dict):
203-
search_id = element.get('id')
204-
else:
205-
return False
206-
207-
with self._lock:
208-
for job_dict in self._shared_data['jobs']:
209-
if job_dict['id'] == search_id:
210-
return True
211-
return False
213+
return len(self)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
RUNPOD_AI_API_KEY=XXX
2+
RUNPOD_API_URL=http://localhost:8080/graphql
3+
RUNPOD_DEBUG_LEVEL=INFO
4+
RUNPOD_ENDPOINT_ID=test-endpoint
5+
RUNPOD_PING_INTERVAL=1000
6+
RUNPOD_POD_ID=test-worker
7+
RUNPOD_WEBHOOK_GET_JOB=http://localhost:8080/v2/test-endpoint/job-take/$RUNPOD_POD_ID?gpu=NVIDIA+GeForce+RTX+4090
8+
RUNPOD_WEBHOOK_PING=http://localhost:8080/v2/test-endpoint/ping/$RUNPOD_POD_ID?gpu=NVIDIA+GeForce+RTX+4090
9+
RUNPOD_WEBHOOK_POST_OUTPUT=http://localhost:8080/v2/test-endpoint/job-done/$RUNPOD_POD_ID?gpu=NVIDIA+GeForce+RTX+4090
10+
RUNPOD_WEBHOOK_JOB_STREAM=http://localhost:8080/v2/test-endpoint/job-stream/$RUNPOD_POD_ID?gpu=NVIDIA+GeForce+RTX+4090
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
.PHONY: localhost worker all
2+
3+
all: localhost worker
4+
5+
localhost:
6+
python localhost.py &
7+
8+
worker:
9+
python worker.py
10+
11+
clean:
12+
find . -type f -name ".runpod_jobs.pkl" -delete
13+
find . -type f -name "*.pyc" -delete
14+
find . -type d -name "__pycache__" -delete

Comments: 0 commit comments