2 changes: 1 addition & 1 deletion pyproject.toml
@@ -9,7 +9,7 @@ version = "0.1.6"
license = { file = "LICENSE" }
dependencies = [
"asyncio",
"pytrickle @ git+https://github.com/livepeer/pytrickle.git@v0.1.4",
"pytrickle @ git+https://github.com/livepeer/pytrickle.git@3d0a4a36d62f76aa9779265ecb40dc7cffc0fc18",
"comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@e62df3a8811d8c652a195d4669f4fb27f6c9a9ba",
"aiortc",
"aiohttp",
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
asyncio
pytrickle @ git+https://github.com/livepeer/pytrickle.git@v0.1.4
pytrickle @ git+https://github.com/livepeer/pytrickle.git@3d0a4a36d62f76aa9779265ecb40dc7cffc0fc18
comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@e62df3a8811d8c652a195d4669f4fb27f6c9a9ba
aiortc
aiohttp
7 changes: 1 addition & 6 deletions server/byoc.py
@@ -139,10 +139,6 @@ def force_print(*args, **kwargs):
# Set the stream processor reference for text data publishing
frame_processor.set_stream_processor(processor)

# Create async startup function to load model
async def load_model_on_startup(app):
await processor._frame_processor.load_model()

# Create async startup function for orchestrator registration
async def register_orchestrator_startup(app):
try:
@@ -171,8 +167,7 @@ async def register_orchestrator_startup(app):
except Exception as e:
logger.error(f"Orchestrator registration failed: {e}")

# Add model loading and registration to startup hooks
processor.server.app.on_startup.append(load_model_on_startup)
# Add registration to startup hooks
processor.server.app.on_startup.append(register_orchestrator_startup)

# Add warmup endpoint: accepts same body as prompts update
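Note: with the load_model_on_startup hook removed, model loading is driven explicitly through FrameProcessor.load_model() (extended in server/frame_processor.py below). A minimal sketch of how a warmup-style endpoint could invoke it on demand; the route, handler name, and the app["frame_processor"] reference are assumptions for illustration, not code from this PR:

from aiohttp import web

# Hypothetical handler: illustrates on-demand model loading now that the
# on_startup hook is gone. Payload handling and wiring are assumptions.
async def handle_warmup(request: web.Request) -> web.Response:
    params = await request.json() if request.can_read_body else {}
    # load_model() also applies the default workflow (when none is configured)
    # and runs warmup before returning; see frame_processor.py below.
    await request.app["frame_processor"].load_model(**params)
    return web.json_response({"status": "ok"})

# app.router.add_post("/warmup", handle_warmup)  # illustrative wiring only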
48 changes: 43 additions & 5 deletions server/frame_processor.py
@@ -8,7 +8,7 @@
from pytrickle.frame_processor import FrameProcessor
from pytrickle.frames import VideoFrame, AudioFrame
from comfystream.pipeline import Pipeline
from comfystream.utils import convert_prompt, ComfyStreamParamsUpdateRequest
from comfystream.utils import convert_prompt, ComfyStreamParamsUpdateRequest, get_default_workflow

logger = logging.getLogger(__name__)

@@ -35,6 +35,7 @@ def __init__(self, text_poll_interval: float = 0.25, **load_params):
self._text_forward_task = None
self._background_tasks = []
self._stop_event = asyncio.Event()
self._runner_active = False
super().__init__()

def set_stream_processor(self, stream_processor):
@@ -113,13 +114,14 @@ async def on_stream_stop(self):
# Set stop event to signal all background tasks to stop
self._stop_event.set()

# Stop the ComfyStream client's prompt execution
# Stop the ComfyStream client's prompt execution immediately to avoid no-input logs
if self.pipeline and self.pipeline.client:
logger.info("Stopping ComfyStream client prompt execution")
try:
await self.pipeline.client.cleanup()
await self.pipeline.client.stop_prompts_immediately()
except Exception as e:
logger.error(f"Error stopping ComfyStream client: {e}")
self._runner_active = False

# Stop text forwarder
await self._stop_text_forwarder()
@@ -150,9 +152,9 @@ def _reset_stop_event(self):
self._stop_event.clear()

async def load_model(self, **kwargs):
"""Load model and initialize the pipeline."""
"""Load model, initialize pipeline, set default workflow once, and warm up."""
params = {**self._load_params, **kwargs}

if self.pipeline is None:
self.pipeline = Pipeline(
width=int(params.get('width', 512)),
@@ -165,6 +167,22 @@
blacklist_nodes=["ComfyUI-Manager"]
)

# Only set the default workflow if no prompts are currently configured
has_prompts = False
try:
has_prompts = bool(getattr(self.pipeline.client, "current_prompts", []))
except Exception:
has_prompts = False

if not has_prompts:
default_workflow = get_default_workflow()
# Apply default prompt first (starts prompt task), then perform warmup synchronously
await self.update_params({"prompts": default_workflow})
await self.warmup()
else:
# Prompts exist; perform warmup synchronously
await self.warmup()

async def warmup(self):
"""Warm up the pipeline."""
if not self.pipeline:
@@ -173,6 +191,10 @@

logger.info("Running pipeline warmup...")
try:
# Ensure runner exists and is enabled for warmup
await self.pipeline.client.ensure_prompt_tasks_running()
self.pipeline.client.resume()

capabilities = self.pipeline.get_workflow_io_capabilities()
logger.info(f"Detected I/O capabilities: {capabilities}")

@@ -184,6 +206,12 @@

except Exception as e:
logger.error(f"Warmup failed: {e}")
finally:
# Pause prompt loop after warmup; will resume on first real input
try:
self.pipeline.client.pause()
except Exception:
logger.debug("Failed to stop prompt loop after warmup", exc_info=True)

def _schedule_warmup(self) -> None:
"""Schedule warmup in background if not already running."""
@@ -200,6 +228,11 @@ def _schedule_warmup(self) -> None:
async def process_video_async(self, frame: VideoFrame) -> VideoFrame:
"""Process video frame through ComfyStream Pipeline."""
try:
# On first frame of an active stream, start/resume runner
if not self._runner_active and self.pipeline and self.pipeline.client:
await self.pipeline.client.ensure_prompt_tasks_running()
self.pipeline.client.resume()
self._runner_active = True

# Convert pytrickle VideoFrame to av.VideoFrame
av_frame = frame.to_av_frame(frame.tensor)
@@ -223,6 +256,11 @@ async def process_audio_async(self, frame: AudioFrame) -> List[AudioFrame]:
try:
if not self.pipeline:
return [frame]
# On first frame of an active stream, start/resume runner
if not self._runner_active and self.pipeline and self.pipeline.client:
await self.pipeline.client.ensure_prompt_tasks_running()
self.pipeline.client.resume()
self._runner_active = True

# Audio processing needed - use pipeline
av_frame = frame.to_av_frame()
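Taken together, the frame processor changes tie the prompt runner's state to the stream lifecycle: resume for warmup, pause afterwards, resume again on the first real frame, and stop immediately when the stream ends. A runnable sketch of that call sequence against a stand-in client; the real methods are the ones added in src/comfystream/client.py below:

import asyncio

class _FakeClient:
    """Stand-in for pipeline.client; mirrors the methods this PR adds."""
    def __init__(self):
        self.running = False

    async def ensure_prompt_tasks_running(self):
        pass  # the real client creates its runner task here if missing

    def resume(self):
        self.running = True

    def pause(self):
        self.running = False

    async def stop_prompts_immediately(self):
        self.running = False  # the real client also cancels the runner task

async def lifecycle_demo():
    client = _FakeClient()
    # load_model()/warmup(): run briefly, then idle until a stream starts
    await client.ensure_prompt_tasks_running()
    client.resume()
    client.pause()
    # first frame of a stream: lazily start/resume the runner
    await client.ensure_prompt_tasks_running()
    client.resume()
    # on_stream_stop(): halt prompt execution immediately
    await client.stop_prompts_immediately()

asyncio.run(lifecycle_demo())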
118 changes: 80 additions & 38 deletions src/comfystream/client.py
@@ -1,6 +1,7 @@
import asyncio
from typing import List
import logging
from typing import List
import contextlib

from comfystream import tensor_cache
from comfystream.utils import convert_prompt
@@ -17,12 +18,17 @@ class ComfyStreamClient:
def __init__(self, max_workers: int = 1, **kwargs):
config = Configuration(**kwargs)
self.comfy_client = EmbeddedComfyClient(config, max_workers=max_workers)
self.running_prompts = {} # To be used for cancelling tasks
self.running_prompts = {}
self.current_prompts = []
self._cleanup_lock = asyncio.Lock()
self._prompt_update_lock = asyncio.Lock()
self._stop_event = asyncio.Event()

# PromptRunner state
self._shutdown_event = asyncio.Event()
self._run_enabled_event = asyncio.Event()
self._runner_task = None

async def set_prompts(self, prompts: List[PromptDictInput]):
"""Set new prompts, replacing any existing ones.

@@ -36,15 +42,15 @@ async def set_prompts(self, prompts: List[PromptDictInput]):
if not prompts:
raise ValueError("Cannot set empty prompts list")

# Cancel existing prompts first to avoid conflicts
await self.cancel_running_prompts()
# Reset stop event for new prompts
self._stop_event.clear()
# Pause runner while swapping prompts to avoid interleaving
was_running = self._run_enabled_event.is_set()
self._run_enabled_event.clear()
self.current_prompts = [convert_prompt(prompt) for prompt in prompts]
logger.info(f"Queuing {len(self.current_prompts)} prompt(s) for execution")
for idx in range(len(self.current_prompts)):
task = asyncio.create_task(self.run_prompt(idx))
self.running_prompts[idx] = task
logger.info(f"Configured {len(self.current_prompts)} prompt(s)")
# Ensure runner exists (IDLE until resumed)
await self.ensure_prompt_tasks_running()
if was_running:
self._run_enabled_event.set()

async def update_prompts(self, prompts: List[PromptDictInput]):
async with self._prompt_update_lock:
@@ -57,34 +63,61 @@ async def update_prompts(self, prompts: List[PromptDictInput]):
for idx, prompt in enumerate(prompts):
converted_prompt = convert_prompt(prompt)
try:
# Lightweight validation by queueing is retained for compatibility
await self.comfy_client.queue_prompt(converted_prompt)
self.current_prompts[idx] = converted_prompt
except Exception as e:
raise Exception(f"Prompt update failed: {str(e)}") from e

async def run_prompt(self, prompt_index: int):
while not self._stop_event.is_set():
async with self._prompt_update_lock:
try:
await self.comfy_client.queue_prompt(self.current_prompts[prompt_index])
except asyncio.CancelledError:
raise
except ComfyStreamInputTimeoutError:
# Timeout errors are expected during stream switching - just continue
logger.info(f"Input for prompt {prompt_index} timed out, continuing")
continue
except Exception as e:
await self.cleanup()
logger.error(f"Error running prompt: {str(e)}")
raise
async def ensure_prompt_tasks_running(self):
# Ensure the single runner task exists (does not force running)
if self._runner_task and not self._runner_task.done():
return
if not self.current_prompts:
return
self._shutdown_event.clear()
self._runner_task = asyncio.create_task(self._runner_loop())

async def _runner_loop(self):
try:
while not self._shutdown_event.is_set():
# IDLE until running is enabled
await self._run_enabled_event.wait()
# Snapshot prompts without holding the lock during network I/O
async with self._prompt_update_lock:
prompts_snapshot = list(self.current_prompts)
for prompt_index, prompt in enumerate(prompts_snapshot):
if self._shutdown_event.is_set() or not self._run_enabled_event.is_set():
break
try:
await self.comfy_client.queue_prompt(prompt)
except asyncio.CancelledError:
raise
except ComfyStreamInputTimeoutError:
logger.info(f"Input for prompt {prompt_index} timed out, continuing")
continue
except Exception as e:
logger.error(f"Error running prompt: {str(e)}")
await asyncio.sleep(0.05)
continue
except asyncio.CancelledError:
pass

async def cleanup(self):
# Set stop event to signal prompt loops to exit
# Signal runner to shutdown
self._stop_event.set()

await self.cancel_running_prompts()
self._shutdown_event.set()
if self._runner_task:
self._runner_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._runner_task
self._runner_task = None

# Pause running
self._run_enabled_event.clear()

async with self._cleanup_lock:
if self.comfy_client.is_running:
if getattr(self.comfy_client, "is_running", False):
try:
await self.comfy_client.__aexit__()
except Exception as e:
@@ -94,15 +127,8 @@ async def cleanup(self):
logger.info("Client cleanup complete")

async def cancel_running_prompts(self):
async with self._cleanup_lock:
tasks_to_cancel = list(self.running_prompts.values())
for task in tasks_to_cancel:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
self.running_prompts.clear()
"""Compatibility: pause the runner without destroying it."""
self._run_enabled_event.clear()


async def cleanup_queues(self):
@@ -121,6 +147,22 @@ async def cleanup_queues(self):
while not tensor_cache.text_outputs.empty():
await tensor_cache.text_outputs.get()

# Explicit lifecycle helpers for external controllers (FrameProcessor)
def resume(self):
self._run_enabled_event.set()

def pause(self):
self._run_enabled_event.clear()

async def stop_prompts_immediately(self):
"""Cancel the runner task to immediately stop any in-flight prompt execution."""
self._run_enabled_event.clear()
if self._runner_task:
self._runner_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._runner_task
self._runner_task = None

def put_video_input(self, frame):
if tensor_cache.image_inputs.full():
tensor_cache.image_inputs.get(block=True)
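The core of the client refactor is replacing one task per prompt with a single long-lived runner loop gated by asyncio events: _run_enabled_event pauses and resumes work, _shutdown_event ends the loop. A self-contained sketch of that pattern under placeholder names (not the comfystream API):

import asyncio
import contextlib
from typing import Optional

class PausableRunner:
    """Single worker loop that can be paused, resumed, and shut down cleanly."""

    def __init__(self):
        self._run_enabled = asyncio.Event()  # cleared = idle, set = running
        self._shutdown = asyncio.Event()
        self._task: Optional[asyncio.Task] = None

    async def ensure_running(self):
        # Create the loop task if it does not already exist; it stays idle
        # until resume() is called.
        if self._task and not self._task.done():
            return
        self._shutdown.clear()
        self._task = asyncio.create_task(self._loop())

    async def _loop(self):
        while not self._shutdown.is_set():
            await self._run_enabled.wait()  # block while paused
            await self._work_once()         # stand-in for queue_prompt()

    async def _work_once(self):
        await asyncio.sleep(0.05)           # placeholder for real work

    def resume(self):
        self._run_enabled.set()

    def pause(self):
        self._run_enabled.clear()

    async def stop_immediately(self):
        self._run_enabled.clear()
        if self._task:
            self._task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._task
            self._task = None

Pausing rather than cancelling keeps the task and its state alive between streams, which is why set_prompts above only clears and restores _run_enabled_event instead of tearing tasks down.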