diff --git a/runpod/serverless/modules/rp_scale.py b/runpod/serverless/modules/rp_scale.py
index 5c7d79cc..e9890cf3 100644
--- a/runpod/serverless/modules/rp_scale.py
+++ b/runpod/serverless/modules/rp_scale.py
@@ -48,7 +48,7 @@ def __init__(self, config: Dict[str, Any]):
         self.config = config
 
         self.job_progress = JobsProgress()  # Cache the singleton instance
-        self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency)
+        self.jobs_queue = asyncio.Queue()
 
         self.concurrency_modifier = _default_concurrency_modifier
         self.jobs_fetcher = get_job
@@ -72,9 +72,9 @@ def __init__(self, config: Dict[str, Any]):
         self.jobs_handler = jobs_handler
 
     async def set_scale(self):
-        self.current_concurrency = self.concurrency_modifier(self.current_concurrency)
+        new_concurrency = self.concurrency_modifier(self.current_concurrency)
 
-        if self.jobs_queue and (self.current_concurrency == self.jobs_queue.maxsize):
+        if new_concurrency == self.current_concurrency:
             # no need to resize
             return
 
@@ -83,10 +83,8 @@ async def set_scale(self):
             await asyncio.sleep(1)
             continue
 
-        self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency)
-        log.debug(
-            f"JobScaler.set_scale | New concurrency set to: {self.current_concurrency}"
-        )
+        self.current_concurrency = new_concurrency
+        log.debug(f"JobScaler.set_scale | New concurrency set to: {self.current_concurrency}")
 
     def start(self):
         """
@@ -188,7 +186,6 @@ async def get_jobs(self, session: ClientSession):
                 for job in acquired_jobs:
                     await self.jobs_queue.put(job)
-                    self.job_progress.add(job)
                     log.debug("Job Queued", job["id"])
 
                 log.info(f"Jobs in queue: {self.jobs_queue.qsize()}")
 
@@ -211,7 +208,7 @@ async def get_jobs(self, session: ClientSession):
                 )
             finally:
                 # Yield control back to the event loop
-                await asyncio.sleep(0)
+                await asyncio.sleep(0.1)
 
     async def run_jobs(self, session: ClientSession):
         """
@@ -219,30 +216,33 @@ async def run_jobs(self, session: ClientSession):
         Runs the block in an infinite loop while the worker is alive or jobs queue is not empty.
         """
 
-        tasks = []  # Store the tasks for concurrent job processing
+        tasks: set[asyncio.Task] = set()  # Store the tasks for concurrent job processing
 
-        while self.is_alive() or not self.jobs_queue.empty():
+        while self.is_alive() or not self.jobs_queue.empty() or tasks:
             # Fetch as many jobs as the concurrency allows
             while len(tasks) < self.current_concurrency and not self.jobs_queue.empty():
+                # log.debug(f"About to get a job from the queue. Queue size: {self.jobs_queue.qsize()}")
                 job = await self.jobs_queue.get()
+                self.job_progress.add(job)
+                log.debug(f"Dequeued job {job['id']}, now running. Queue size: {self.jobs_queue.qsize()}")
 
                 # Create a new task for each job and add it to the task list
                 task = asyncio.create_task(self.handle_job(session, job))
-                tasks.append(task)
+                tasks.add(task)
 
-            # Wait for any job to finish
+            # If jobs are running, wait briefly for completions
             if tasks:
-                log.info(f"Jobs in progress: {len(tasks)}")
-
+                # Wait for at least one task to finish
                 done, pending = await asyncio.wait(
-                    tasks, return_when=asyncio.FIRST_COMPLETED
+                    tasks,
+                    timeout=0.1,
+                    return_when=asyncio.FIRST_COMPLETED,
                 )
-
-                # Remove completed tasks from the list
-                tasks = [t for t in tasks if t not in done]
-
-                # Yield control back to the event loop
-                await asyncio.sleep(0)
+                # Remove completed tasks
+                tasks.difference_update(done)
+            else:
+                # Nothing running, so don't spin the CPU
+                await asyncio.sleep(0.5)
 
         # Ensure all remaining tasks finish before stopping
         await asyncio.gather(*tasks)
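
Context on the `set_scale` change: `asyncio.Queue` exposes `maxsize` only as a read-only property, so a bounded queue cannot be resized in place. That is presumably why the old `set_scale()` had to wait for the queue to drain (the `sleep(1)`/`continue` loop above) and then construct a replacement queue; this diff keeps one unbounded queue and tracks `current_concurrency` separately instead. A minimal sketch illustrating the constraint (the names here are illustrative, not from the PR):

```python
import asyncio

async def main() -> None:
    q: asyncio.Queue = asyncio.Queue(maxsize=2)
    q.put_nowait("a")
    q.put_nowait("b")
    assert q.full()  # capacity is fixed at construction time

    try:
        q.maxsize = 4  # no public resize API: maxsize is a read-only property
    except AttributeError:
        print("asyncio.Queue cannot be resized in place")

    # The old set_scale() therefore drained and replaced the queue object;
    # the unbounded queue in this diff sidesteps the problem entirely.

asyncio.run(main())
```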
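And a minimal, runnable sketch of the scheduling pattern the new `run_jobs` adopts: an unbounded queue feeding a task set capped at a concurrency value that can change at runtime, polled with a short `asyncio.wait` timeout so the loop stays responsive. `fake_handler`, `CONCURRENCY`, and `run_jobs` below are stand-ins for `handle_job`, `current_concurrency`, and the worker loop, not names from the PR:

```python
import asyncio

CONCURRENCY = 2  # stands in for JobScaler.current_concurrency

async def fake_handler(job: dict) -> None:
    # Placeholder for JobScaler.handle_job(session, job)
    await asyncio.sleep(0.2)
    print(f"finished {job['id']}")

async def run_jobs(queue: asyncio.Queue) -> None:
    tasks: set[asyncio.Task] = set()
    while not queue.empty() or tasks:
        # Dequeue only while below the concurrency cap, mirroring the
        # inner `while len(tasks) < self.current_concurrency` loop.
        while len(tasks) < CONCURRENCY and not queue.empty():
            job = await queue.get()
            tasks.add(asyncio.create_task(fake_handler(job)))

        if tasks:
            # timeout=0.1 bounds the wait so the outer loop can pick up
            # concurrency changes even when no task completes for a while.
            done, _ = await asyncio.wait(
                tasks, timeout=0.1, return_when=asyncio.FIRST_COMPLETED
            )
            tasks.difference_update(done)
        else:
            await asyncio.sleep(0.5)  # idle: avoid a busy loop

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()  # unbounded, as in the diff
    for i in range(5):
        queue.put_nowait({"id": f"job-{i}"})
    await run_jobs(queue)

asyncio.run(main())
```

Note the `or tasks` clause added to the outer `while` in the diff: without it, the loop could exit while jobs were still in flight, leaving the final `asyncio.gather` as the only backstop.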