runpod/serverless/modules/rp_scale.py (22 additions, 22 deletions)
@@ -48,7 +48,7 @@ def __init__(self, config: Dict[str, Any]):
         self.config = config
         self.job_progress = JobsProgress()  # Cache the singleton instance

-        self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency)
+        self.jobs_queue = asyncio.Queue()

         self.concurrency_modifier = _default_concurrency_modifier
         self.jobs_fetcher = get_job
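
Note: a bounded asyncio.Queue suspends put() once maxsize items are waiting, which is why the old code had to allocate a fresh queue whenever concurrency changed. With the unbounded form, the concurrency limit is enforced at dequeue time in run_jobs instead (see the last hunk). A minimal stdlib-only sketch of the difference:

    import asyncio

    async def main():
        bounded = asyncio.Queue(maxsize=2)
        await bounded.put("job-1")
        await bounded.put("job-2")
        # A third put() would now suspend until a consumer calls get():
        # await bounded.put("job-3")

        unbounded = asyncio.Queue()  # maxsize=0: put() never blocks
        for i in range(1000):
            await unbounded.put(f"job-{i}")
        print(unbounded.qsize())  # 1000

    asyncio.run(main())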
@@ -72,9 +72,9 @@ def __init__(self, config: Dict[str, Any]):
         self.jobs_handler = jobs_handler

     async def set_scale(self):
-        self.current_concurrency = self.concurrency_modifier(self.current_concurrency)
+        new_concurrency = self.concurrency_modifier(self.current_concurrency)

-        if self.jobs_queue and (self.current_concurrency == self.jobs_queue.maxsize):
+        if new_concurrency == self.current_concurrency:
             # no need to resize
             return

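
Note: concurrency_modifier is the user-supplied callable that maps the current concurrency to the desired one; set_scale now compares its result against current_concurrency directly instead of against the queue's maxsize. A hedged sketch of a custom modifier (the backlog signal and the bounds are illustrative, not part of this PR):

    def concurrency_modifier(current: int) -> int:
        backlog = get_pending_request_count()  # hypothetical metric helper
        if backlog > current * 2:
            return min(current + 1, 8)  # scale up, capped at 8 (illustrative)
        if backlog == 0:
            return max(current - 1, 1)  # scale down, floor of 1 (illustrative)
        return current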
@@ -83,10 +83,8 @@ async def set_scale(self):
             await asyncio.sleep(1)
             continue

-        self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency)
-        log.debug(
-            f"JobScaler.set_scale | New concurrency set to: {self.current_concurrency}"
-        )
+        self.current_concurrency = new_concurrency
+        log.debug(f"JobScaler.set_scale | New concurrency set to: {self.current_concurrency}")

     def start(self):
         """
@@ -188,7 +186,6 @@ async def get_jobs(self, session: ClientSession):

                 for job in acquired_jobs:
                     await self.jobs_queue.put(job)
-                    self.job_progress.add(job)
                     log.debug("Job Queued", job["id"])

                 log.info(f"Jobs in queue: {self.jobs_queue.qsize()}")
@@ -211,38 +208,41 @@ async def get_jobs(self, session: ClientSession):
                 )
             finally:
                 # Yield control back to the event loop
-                await asyncio.sleep(0)
+                await asyncio.sleep(0.1)

     async def run_jobs(self, session: ClientSession):
         """
         Retrieve jobs from the jobs queue and process them concurrently.

         Runs the block in an infinite loop while the worker is alive or jobs queue is not empty.
         """
-        tasks = []  # Store the tasks for concurrent job processing
+        tasks: set[asyncio.Task] = set()  # Store the tasks for concurrent job processing

-        while self.is_alive() or not self.jobs_queue.empty():
+        while self.is_alive() or not self.jobs_queue.empty() or tasks:
             # Fetch as many jobs as the concurrency allows
             while len(tasks) < self.current_concurrency and not self.jobs_queue.empty():
-                # log.debug(f"About to get a job from the queue. Queue size: {self.jobs_queue.qsize()}")
                 job = await self.jobs_queue.get()
+                self.job_progress.add(job)
+                log.debug(f"Dequeued job {job['id']}, now running. Queue size: {self.jobs_queue.qsize()}")

                 # Create a new task for each job and add it to the task list
                 task = asyncio.create_task(self.handle_job(session, job))
-                tasks.append(task)
+                tasks.add(task)

-            # Wait for any job to finish
+            # 2. If jobs are running, wait a little for completions
             if tasks:
                 log.info(f"Jobs in progress: {len(tasks)}")

                 # Wait for at least one task to finish
                 done, pending = await asyncio.wait(
-                    tasks, return_when=asyncio.FIRST_COMPLETED
+                    tasks,
+                    timeout=0.1,
+                    return_when=asyncio.FIRST_COMPLETED,
                 )

-                # Remove completed tasks from the list
-                tasks = [t for t in tasks if t not in done]
-
-                # Yield control back to the event loop
-                await asyncio.sleep(0)
+                # Remove completed tasks
+                tasks.difference_update(done)
+            else:
+                # Nothing running — don’t spin CPU
+                await asyncio.sleep(0.5)

         # Ensure all remaining tasks finish before stopping
         await asyncio.gather(*tasks)
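
The new loop is the standard bounded-concurrency scheduler pattern: launch up to current_concurrency tasks, wait with a timeout so the loop can re-check liveness and the queue even when nothing finishes, prune completions with set.difference_update, and sleep when idle instead of busy-spinning. The added `or tasks` clause also keeps the loop alive until in-flight jobs drain. A self-contained sketch of the same pattern (job payloads and timings are illustrative):

    import asyncio
    import random

    async def job(i: int) -> int:
        await asyncio.sleep(random.random())  # simulate work
        return i

    async def scheduler(total: int, concurrency: int) -> None:
        queue: asyncio.Queue = asyncio.Queue()
        for i in range(total):
            queue.put_nowait(i)

        tasks: set[asyncio.Task] = set()
        while not queue.empty() or tasks:
            # Launch work up to the concurrency limit.
            while len(tasks) < concurrency and not queue.empty():
                tasks.add(asyncio.create_task(job(queue.get_nowait())))

            if tasks:
                # timeout=0.1 bounds the wait so the outer loop can re-check
                # its conditions even when no task has completed yet.
                done, _pending = await asyncio.wait(
                    tasks, timeout=0.1, return_when=asyncio.FIRST_COMPLETED
                )
                tasks.difference_update(done)  # cheap set removal per task
            else:
                await asyncio.sleep(0.5)  # idle: avoid a hot loop

        await asyncio.gather(*tasks)

    asyncio.run(scheduler(total=10, concurrency=3))

Using a set rather than a list makes removal of completed tasks O(1) each, and the truthiness of tasks doubles as the "work still in flight" signal for the outer while condition.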