diff --git a/pyproject.toml b/pyproject.toml
index ecf846f7..31c906a9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,7 @@ version = "0.1.6"
 license = { file = "LICENSE" }
 dependencies = [
     "asyncio",
-    "pytrickle @ git+https://github.com/livepeer/pytrickle.git@v0.1.4",
+    "pytrickle @ git+https://github.com/livepeer/pytrickle.git@3d0a4a36d62f76aa9779265ecb40dc7cffc0fc18",
     "comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@e62df3a8811d8c652a195d4669f4fb27f6c9a9ba",
     "aiortc",
     "aiohttp",
diff --git a/requirements.txt b/requirements.txt
index b35bb52c..7352978b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 asyncio
-pytrickle @ git+https://github.com/livepeer/pytrickle.git@v0.1.4
+pytrickle @ git+https://github.com/livepeer/pytrickle.git@3d0a4a36d62f76aa9779265ecb40dc7cffc0fc18
 comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@e62df3a8811d8c652a195d4669f4fb27f6c9a9ba
 aiortc
 aiohttp
diff --git a/server/byoc.py b/server/byoc.py
index 7a941815..74022168 100644
--- a/server/byoc.py
+++ b/server/byoc.py
@@ -139,10 +139,6 @@ def force_print(*args, **kwargs):
     # Set the stream processor reference for text data publishing
     frame_processor.set_stream_processor(processor)
 
-    # Create async startup function to load model
-    async def load_model_on_startup(app):
-        await processor._frame_processor.load_model()
-
    # Create async startup function for orchestrator registration
     async def register_orchestrator_startup(app):
         try:
@@ -171,8 +167,7 @@ async def register_orchestrator_startup(app):
         except Exception as e:
             logger.error(f"Orchestrator registration failed: {e}")
 
-    # Add model loading and registration to startup hooks
-    processor.server.app.on_startup.append(load_model_on_startup)
+    # Add registration to startup hooks
    processor.server.app.on_startup.append(register_orchestrator_startup)
 
     # Add warmup endpoint: accepts same body as prompts update
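The `byoc.py` change above removes eager model loading from boot: only orchestrator registration stays on the aiohttp startup hooks, and model loading moves into `load_model()` (see `frame_processor.py` below), reachable via the warmup endpoint. For context, a minimal sketch of the aiohttp `on_startup` contract this relies on; the handler body here is illustrative, not the real registration logic:

```python
# Minimal sketch of aiohttp's on_startup hook (handler body is illustrative).
from aiohttp import web

async def register_orchestrator_startup(app: web.Application):
    # Runs once after the event loop starts, before traffic is served.
    # The real handler registers this worker with the orchestrator.
    app["orchestrator_registered"] = True

app = web.Application()
app.on_startup.append(register_orchestrator_startup)
# web.run_app(app)  # startup hooks fire here
```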
diff --git a/server/frame_processor.py b/server/frame_processor.py
index 39272313..823c5a41 100644
--- a/server/frame_processor.py
+++ b/server/frame_processor.py
@@ -8,7 +8,7 @@
 from pytrickle.frame_processor import FrameProcessor
 from pytrickle.frames import VideoFrame, AudioFrame
 from comfystream.pipeline import Pipeline
-from comfystream.utils import convert_prompt, ComfyStreamParamsUpdateRequest
+from comfystream.utils import convert_prompt, ComfyStreamParamsUpdateRequest, get_default_workflow
 
 logger = logging.getLogger(__name__)
 
@@ -35,6 +35,7 @@ def __init__(self, text_poll_interval: float = 0.25, **load_params):
         self._text_forward_task = None
         self._background_tasks = []
         self._stop_event = asyncio.Event()
+        self._runner_active = False
         super().__init__()
 
     def set_stream_processor(self, stream_processor):
@@ -113,13 +114,14 @@ async def on_stream_stop(self):
         # Set stop event to signal all background tasks to stop
         self._stop_event.set()
 
-        # Stop the ComfyStream client's prompt execution
+        # Stop the ComfyStream client's prompt execution immediately to avoid no-input logs
         if self.pipeline and self.pipeline.client:
             logger.info("Stopping ComfyStream client prompt execution")
             try:
-                await self.pipeline.client.cleanup()
+                await self.pipeline.client.stop_prompts_immediately()
             except Exception as e:
                 logger.error(f"Error stopping ComfyStream client: {e}")
+        self._runner_active = False
 
         # Stop text forwarder
         await self._stop_text_forwarder()
@@ -150,9 +152,9 @@ def _reset_stop_event(self):
         self._stop_event.clear()
 
     async def load_model(self, **kwargs):
-        """Load model and initialize the pipeline."""
+        """Load model, initialize pipeline, set default workflow once, and warm up."""
         params = {**self._load_params, **kwargs}
-
+
         if self.pipeline is None:
             self.pipeline = Pipeline(
                 width=int(params.get('width', 512)),
@@ -165,6 +167,22 @@ async def load_model(self, **kwargs):
                 blacklist_nodes=["ComfyUI-Manager"]
             )
 
+        # Only set the default workflow if no prompts are currently configured
+        has_prompts = False
+        try:
+            has_prompts = bool(getattr(self.pipeline.client, "current_prompts", []))
+        except Exception:
+            has_prompts = False
+
+        if not has_prompts:
+            default_workflow = get_default_workflow()
+            # Apply the default prompt first (starts the prompt task), then warm up synchronously
+            await self.update_params({"prompts": default_workflow})
+            await self.warmup()
+        else:
+            # Prompts already exist; just warm up synchronously
+            await self.warmup()
+
     async def warmup(self):
         """Warm up the pipeline."""
         if not self.pipeline:
@@ -173,6 +191,10 @@ async def warmup(self):
         logger.info("Running pipeline warmup...")
 
         try:
+            # Ensure the runner exists and is enabled for warmup
+            await self.pipeline.client.ensure_prompt_tasks_running()
+            self.pipeline.client.resume()
+
             capabilities = self.pipeline.get_workflow_io_capabilities()
             logger.info(f"Detected I/O capabilities: {capabilities}")
 
@@ -184,6 +206,12 @@ async def warmup(self):
 
         except Exception as e:
             logger.error(f"Warmup failed: {e}")
+        finally:
+            # Pause the prompt loop after warmup; it resumes on the first real input
+            try:
+                self.pipeline.client.pause()
+            except Exception:
+                logger.debug("Failed to pause prompt loop after warmup", exc_info=True)
 
     def _schedule_warmup(self) -> None:
         """Schedule warmup in background if not already running."""
@@ -200,6 +228,11 @@ def _schedule_warmup(self) -> None:
     async def process_video_async(self, frame: VideoFrame) -> VideoFrame:
         """Process video frame through ComfyStream Pipeline."""
         try:
+            # On the first frame of an active stream, start/resume the runner
+            if not self._runner_active and self.pipeline and self.pipeline.client:
+                await self.pipeline.client.ensure_prompt_tasks_running()
+                self.pipeline.client.resume()
+                self._runner_active = True
             # Convert pytrickle VideoFrame to av.VideoFrame
             av_frame = frame.to_av_frame(frame.tensor)
 
@@ -223,6 +256,11 @@ async def process_audio_async(self, frame: AudioFrame) -> List[AudioFrame]:
         try:
             if not self.pipeline:
                 return [frame]
+            # On the first frame of an active stream, start/resume the runner
+            if not self._runner_active and self.pipeline and self.pipeline.client:
+                await self.pipeline.client.ensure_prompt_tasks_running()
+                self.pipeline.client.resume()
+                self._runner_active = True
 
             # Audio processing needed - use pipeline
             av_frame = frame.to_av_frame()
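The processor now gates GPU work behind the `_runner_active` latch: the first frame of a stream starts/resumes the client's prompt runner, and `on_stream_stop` cancels in-flight work and re-arms the latch for the next stream. A condensed sketch of that latch, assuming only the client methods added in this diff (`ensure_prompt_tasks_running`, `resume`, `stop_prompts_immediately`); the wrapper class itself is illustrative:

```python
# Condensed sketch of the first-frame latch used in process_video_async /
# process_audio_async. `client` stands in for pipeline.client.
class RunnerLatch:
    def __init__(self, client):
        self.client = client
        self._runner_active = False

    async def on_frame(self, frame):
        if not self._runner_active:
            # First frame of the stream: create the runner task if needed
            # and lift the pause gate, exactly once per stream.
            await self.client.ensure_prompt_tasks_running()
            self.client.resume()
            self._runner_active = True
        return frame  # the real code forwards the frame into the pipeline

    async def on_stream_stop(self):
        # Cancel in-flight prompt execution and re-arm the latch.
        await self.client.stop_prompts_immediately()
        self._runner_active = False
```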
diff --git a/src/comfystream/client.py b/src/comfystream/client.py
index 7686fca1..29547451 100644
--- a/src/comfystream/client.py
+++ b/src/comfystream/client.py
@@ -1,6 +1,7 @@
 import asyncio
-from typing import List
 import logging
+from typing import List
+import contextlib
 
 from comfystream import tensor_cache
 from comfystream.utils import convert_prompt
@@ -17,12 +18,17 @@ class ComfyStreamClient:
     def __init__(self, max_workers: int = 1, **kwargs):
         config = Configuration(**kwargs)
         self.comfy_client = EmbeddedComfyClient(config, max_workers=max_workers)
-        self.running_prompts = {}  # To be used for cancelling tasks
+        self.running_prompts = {}
         self.current_prompts = []
         self._cleanup_lock = asyncio.Lock()
         self._prompt_update_lock = asyncio.Lock()
         self._stop_event = asyncio.Event()
+        # PromptRunner state
+        self._shutdown_event = asyncio.Event()
+        self._run_enabled_event = asyncio.Event()
+        self._runner_task = None
 
     async def set_prompts(self, prompts: List[PromptDictInput]):
         """Set new prompts, replacing any existing ones.
@@ -36,15 +42,15 @@ async def set_prompts(self, prompts: List[PromptDictInput]):
         if not prompts:
             raise ValueError("Cannot set empty prompts list")
 
-        # Cancel existing prompts first to avoid conflicts
-        await self.cancel_running_prompts()
-        # Reset stop event for new prompts
-        self._stop_event.clear()
+        # Pause the runner while swapping prompts to avoid interleaving
+        was_running = self._run_enabled_event.is_set()
+        self._run_enabled_event.clear()
         self.current_prompts = [convert_prompt(prompt) for prompt in prompts]
-        logger.info(f"Queuing {len(self.current_prompts)} prompt(s) for execution")
-        for idx in range(len(self.current_prompts)):
-            task = asyncio.create_task(self.run_prompt(idx))
-            self.running_prompts[idx] = task
+        logger.info(f"Configured {len(self.current_prompts)} prompt(s)")
+        # Ensure the runner exists (IDLE until resumed)
+        await self.ensure_prompt_tasks_running()
+        if was_running:
+            self._run_enabled_event.set()
 
     async def update_prompts(self, prompts: List[PromptDictInput]):
         async with self._prompt_update_lock:
@@ -57,34 +63,61 @@ async def update_prompts(self, prompts: List[PromptDictInput]):
             for idx, prompt in enumerate(prompts):
                 converted_prompt = convert_prompt(prompt)
                 try:
+                    # Lightweight validation by queueing is retained for compatibility
                     await self.comfy_client.queue_prompt(converted_prompt)
                     self.current_prompts[idx] = converted_prompt
                 except Exception as e:
                     raise Exception(f"Prompt update failed: {str(e)}") from e
 
-    async def run_prompt(self, prompt_index: int):
-        while not self._stop_event.is_set():
-            async with self._prompt_update_lock:
-                try:
-                    await self.comfy_client.queue_prompt(self.current_prompts[prompt_index])
-                except asyncio.CancelledError:
-                    raise
-                except ComfyStreamInputTimeoutError:
-                    # Timeout errors are expected during stream switching - just continue
-                    logger.info(f"Input for prompt {prompt_index} timed out, continuing")
-                    continue
-                except Exception as e:
-                    await self.cleanup()
-                    logger.error(f"Error running prompt: {str(e)}")
-                    raise
+    async def ensure_prompt_tasks_running(self):
+        # Ensure the single runner task exists (does not force running)
+        if self._runner_task and not self._runner_task.done():
+            return
+        if not self.current_prompts:
+            return
+        self._shutdown_event.clear()
+        self._runner_task = asyncio.create_task(self._runner_loop())
+
+    async def _runner_loop(self):
+        try:
+            while not self._shutdown_event.is_set():
+                # IDLE until running is enabled
+                await self._run_enabled_event.wait()
+                # Snapshot prompts without holding the lock during network I/O
+                async with self._prompt_update_lock:
+                    prompts_snapshot = list(self.current_prompts)
+                for prompt_index, prompt in enumerate(prompts_snapshot):
+                    if self._shutdown_event.is_set() or not self._run_enabled_event.is_set():
+                        break
+                    try:
+                        await self.comfy_client.queue_prompt(prompt)
+                    except asyncio.CancelledError:
+                        raise
+                    except ComfyStreamInputTimeoutError:
+                        logger.info(f"Input for prompt {prompt_index} timed out, continuing")
+                        continue
+                    except Exception as e:
+                        logger.error(f"Error running prompt: {str(e)}")
+                        await asyncio.sleep(0.05)
+                        continue
+        except asyncio.CancelledError:
+            pass
 
     async def cleanup(self):
-        # Set stop event to signal prompt loops to exit
+        # Signal the runner to shut down
         self._stop_event.set()
-
-        await self.cancel_running_prompts()
+        self._shutdown_event.set()
+        if self._runner_task:
+            self._runner_task.cancel()
+            with contextlib.suppress(asyncio.CancelledError):
+                await self._runner_task
+            self._runner_task = None
+
+        # Pause running
+        self._run_enabled_event.clear()
+
         async with self._cleanup_lock:
-            if self.comfy_client.is_running:
+            if getattr(self.comfy_client, "is_running", False):
                 try:
                     await self.comfy_client.__aexit__()
                 except Exception as e:
@@ -94,15 +127,8 @@ async def cleanup(self):
 
         logger.info("Client cleanup complete")
 
     async def cancel_running_prompts(self):
-        async with self._cleanup_lock:
-            tasks_to_cancel = list(self.running_prompts.values())
-            for task in tasks_to_cancel:
-                task.cancel()
-                try:
-                    await task
-                except asyncio.CancelledError:
-                    pass
-            self.running_prompts.clear()
+        """Compatibility: pause the runner without destroying it."""
+        self._run_enabled_event.clear()
 
     async def cleanup_queues(self):
@@ -121,6 +147,22 @@ async def cleanup_queues(self):
         while not tensor_cache.text_outputs.empty():
             await tensor_cache.text_outputs.get()
 
+    # Explicit lifecycle helpers for external controllers (FrameProcessor)
+    def resume(self):
+        self._run_enabled_event.set()
+
+    def pause(self):
+        self._run_enabled_event.clear()
+
+    async def stop_prompts_immediately(self):
+        """Cancel the runner task to immediately stop any in-flight prompt execution."""
+        self._run_enabled_event.clear()
+        if self._runner_task:
+            self._runner_task.cancel()
+            with contextlib.suppress(asyncio.CancelledError):
+                await self._runner_task
+            self._runner_task = None
+
     def put_video_input(self, frame):
         if tensor_cache.image_inputs.full():
             tensor_cache.image_inputs.get(block=True)
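Net effect in `client.py`: the per-prompt `run_prompt` tasks are replaced by a single pausable runner with two distinct signals, `_run_enabled_event` for pause/resume (stream boundaries, prompt swaps, warmup) and `_shutdown_event` plus task cancellation for teardown. A hypothetical driver showing the intended call order; client construction and `prompts` are placeholders:

```python
import asyncio

async def demo(client, prompts):
    await client.set_prompts(prompts)           # configure; runner is created idle
    client.resume()                             # warmup / first frame lifts the gate
    await asyncio.sleep(1.0)                    # runner loops over queue_prompt()
    client.pause()                              # idle between streams; task survives
    await client.stop_prompts_immediately()     # stream stop: cancel in-flight work
    await client.ensure_prompt_tasks_running()  # a later stream recreates the task...
    client.resume()                             # ...and lifts the gate again
    await client.cleanup()                      # full shutdown
```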