13 changes: 8 additions & 5 deletions examples/process_video_example.py
@@ -24,20 +24,23 @@
background_tasks = []
background_task_started = False

# Static model warmup delay (seconds) so /health stays LOADING until complete
MODEL_LOAD_DELAY_SECONDS = 3.0

async def load_model(**kwargs):
"""Initialize processor state - called during model loading phase."""
global intensity, ready, processor

logger.info(f"load_model called with kwargs: {kwargs}")
logger.info("load_model starting (using static warmup delay)")

# Set processor variables from kwargs or use defaults
intensity = kwargs.get('intensity', 0.5)
intensity = max(0.0, min(1.0, intensity))
# Simulate synchronous warmup so /health stays LOADING during model load
if MODEL_LOAD_DELAY_SECONDS > 0:
logger.info(f"Simulating model warmup for {MODEL_LOAD_DELAY_SECONDS:.1f}s (server status should be LOADING)")
await asyncio.sleep(MODEL_LOAD_DELAY_SECONDS)

# Load the model here if needed
# model = torch.load('my_model.pth')

# Note: Cannot start background tasks here as event loop isn't running yet
# Background task will be started when first frame is processed
ready = True
logger.info(f"✅ OpenCV Green processor with horizontal flip ready (intensity: {intensity}, ready: {ready})")
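For context on the warmup change above: a client can observe the LOADING → IDLE transition by polling the health endpoint. The sketch below assumes the server exposes GET /health on port 8000 returning a JSON body with a `status` field; the path, port, and response shape are assumptions for illustration, not confirmed pytrickle API.

```python
# Hypothetical health poller; endpoint path, port, and JSON shape are assumed.
import json
import time
import urllib.request

def wait_until_ready(url: str = "http://localhost:8000/health", timeout: float = 30.0) -> str:
    """Poll the health endpoint until the pipeline leaves LOADING or the timeout expires."""
    deadline = time.monotonic() + timeout
    status = "UNREACHABLE"
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=2.0) as resp:
                status = json.load(resp).get("status", "UNKNOWN")
        except OSError:
            status = "UNREACHABLE"  # server not accepting connections yet
        if status not in ("LOADING", "UNREACHABLE"):
            return status  # e.g. IDLE once load_model() has completed
        time.sleep(0.5)
    raise TimeoutError(f"pipeline still {status} after {timeout}s")
```

With MODEL_LOAD_DELAY_SECONDS = 3.0, the poller should report LOADING for roughly three seconds before returning IDLE.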
5 changes: 5 additions & 0 deletions pytrickle/client.py
@@ -17,6 +17,7 @@
from . import ErrorCallback
from .frame_processor import FrameProcessor
from .frame_skipper import AdaptiveFrameSkipper, FrameSkipConfig, FrameProcessingResult
from .state import PipelineState

logger = logging.getLogger(__name__)

@@ -72,6 +73,7 @@ def __init__(
)
else:
self.frame_skipper = None


async def start(self, request_id: str = "default"):
"""Start the trickle client."""
@@ -87,6 +89,9 @@ async def start(self, request_id: str = "default"):
# Start the protocol
await self.protocol.start()

# Ensure model is loaded on the same event loop/thread before processing
await self.frame_processor.ensure_model_loaded()

# Start processing loops
self.running = True

16 changes: 15 additions & 1 deletion pytrickle/frame_processor.py
@@ -4,6 +4,7 @@
making it easy to integrate AI models and async pipelines with PyTrickle.
"""

import asyncio
import logging
from abc import ABC, abstractmethod
from typing import Optional, Any, Dict, List
@@ -54,12 +55,25 @@ def __init__(
**init_kwargs: Additional kwargs passed to load_model() method
"""
self.error_callback = error_callback
self.state: Optional[StreamState] = None
self.state = StreamState()

# Model loading protection
self._model_load_lock = asyncio.Lock()

def attach_state(self, state: StreamState) -> None:
"""Attach a pipeline state manager and set IDLE if model already loaded."""
self.state = state

async def ensure_model_loaded(self, **kwargs):
"""Thread-safe wrapper that ensures model is loaded exactly once."""
async with self._model_load_lock:
if self.state.get_state() != PipelineState.IDLE:
await self.load_model(**kwargs)
self.state.set_startup_complete()
logger.debug(f"Model loaded for {self.__class__.__name__}")
else:
logger.debug(f"Model already loaded for {self.__class__.__name__}")

@abstractmethod
async def load_model(self, **kwargs):
"""
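The ensure_model_loaded change is the standard lock-guarded load-once pattern: concurrent callers serialize on the asyncio.Lock, the first performs the load and flips the state, and later callers see the updated state and skip. A standalone sketch of the pattern, with illustrative names rather than the pytrickle API:

```python
# Minimal load-once sketch, independent of pytrickle.
import asyncio

class LoadOnce:
    def __init__(self):
        self._lock = asyncio.Lock()
        self._loaded = False

    async def ensure_loaded(self):
        async with self._lock:            # concurrent callers queue here
            if not self._loaded:          # only the first caller takes this branch
                await asyncio.sleep(0.1)  # stand-in for an expensive load_model()
                self._loaded = True

async def main():
    guard = LoadOnce()
    # Ten concurrent callers; the simulated load runs exactly once.
    await asyncio.gather(*(guard.ensure_loaded() for _ in range(10)))

asyncio.run(main())
```

One caveat: before Python 3.10, asyncio.Lock bound itself to the event loop current at construction, so creating the lock in __init__ and awaiting it from a different loop could fail. That is one reason the client.py change above awaits ensure_model_loaded on the same loop that runs the processing.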
4 changes: 1 addition & 3 deletions pytrickle/server.py
@@ -631,9 +631,7 @@ async def start_server(self):
site = web.TCPSite(runner, self.host, self.port)
await site.start()

# Set pipeline ready when server is up and ready to accept requests
self.state.set_state(PipelineState.IDLE)
self.state.set_startup_complete()
# Do not mark as IDLE here; remain LOADING until model load completes

logger.info(f"Server started on {self.host}:{self.port}")
return runner
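With this removal, /health keeps reporting LOADING after the TCP listener is up and only flips to IDLE once ensure_model_loaded calls set_startup_complete. Below is a sketch of how a health handler might map pipeline state to an HTTP response; the real pytrickle handler is not shown in this diff, and the sketch assumes PipelineState is an Enum and that returning 503 while loading is the desired convention.

```python
# Illustrative health handler; the actual pytrickle handler may differ.
from aiohttp import web

def make_health_handler(state):
    async def health(_request: web.Request) -> web.Response:
        current = state.get_state()  # e.g. PipelineState.LOADING / IDLE / ERROR
        # Assumed convention: 503 while loading so load balancers hold traffic.
        http_status = 503 if current.name == "LOADING" else 200
        return web.json_response({"status": current.name}, status=http_status)
    return health
```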
54 changes: 53 additions & 1 deletion pytrickle/stream_processor.py
@@ -7,6 +7,7 @@
from .frame_processor import FrameProcessor
from .server import StreamServer
from .frame_skipper import FrameSkipConfig
from .state import PipelineState

logger = logging.getLogger(__name__)

@@ -62,6 +63,9 @@ def __init__(
self.frame_skip_config = frame_skip_config
self.server_kwargs = server_kwargs

# Track background tasks: strong references prevent mid-flight garbage collection and let cleanup() cancel them
self._background_tasks = set()

# Create internal frame processor
self._frame_processor = _InternalFrameProcessor(
video_processor=video_processor,
@@ -79,6 +83,40 @@ def __init__(
frame_skip_config=frame_skip_config,
**server_kwargs
)

# Ensure state coherence: attach server state to processor for health transitions
try:
self._frame_processor.attach_state(self.server.state)
except Exception:
# If attach fails for any reason, log and continue (non-fatal)
logger.warning("Failed to attach server state to frame processor")

# Register server startup hook to preload model on same event loop
async def _on_startup(_app):
try:
task = asyncio.create_task(self._preload_model_background())
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
except Exception as e:
logger.error(f"Failed to schedule background preload: {e}")

try:
self.server.app.on_startup.append(_on_startup)
except Exception as e:
logger.error(f"Failed to register startup hook: {e}")

async def _preload_model_background(self):
"""Background model preloading with proper error handling."""
try:
# Use the coroutine-safe wrapper
await self._frame_processor.ensure_model_loaded()

logger.info(f"StreamProcessor '{self.name}' model preloaded on server startup")

except Exception as e:
self._frame_processor.state.set_error(str(e))
logger.error(f"Error preloading model on startup: {e}")

async def send_data(self, data: str):
"""Send data to the server."""
@@ -123,9 +161,23 @@ async def send_frame(self, frame: Union[VideoFrame, AudioFrame]):
logger.error(f"Error sending frame: {e}")
return False

async def cleanup(self):
"""Cancel all background tasks to prevent memory leaks."""
if self._background_tasks:
logger.info(f"Cancelling {len(self._background_tasks)} background tasks")
for task in self._background_tasks.copy():
if not task.done():
task.cancel()
# Wait for tasks to complete cancellation
await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._background_tasks.clear()

async def run_forever(self):
"""Run the stream processor server forever."""
await self.server.run_forever()
try:
await self.server.run_forever()
finally:
await self.cleanup()

def run(self):
"""Run the stream processor server (blocking)."""
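The tracked-task pattern in __init__ and cleanup() above addresses two real asyncio pitfalls: the event loop only keeps weak references to tasks, so an untracked task can be garbage-collected mid-flight, and a set that only grows with finished tasks is itself a leak — hence the strong-ref set plus the discard done-callback. A self-contained demo of the same pattern:

```python
# Self-contained demo of the tracked background-task pattern used above.
import asyncio

class TaskTracker:
    def __init__(self):
        self._tasks: set[asyncio.Task] = set()

    def spawn(self, coro) -> asyncio.Task:
        task = asyncio.create_task(coro)
        self._tasks.add(task)                        # strong ref: no mid-flight GC
        task.add_done_callback(self._tasks.discard)  # finished tasks drop out
        return task

    async def cleanup(self):
        for task in self._tasks.copy():
            if not task.done():
                task.cancel()
        # return_exceptions=True absorbs the CancelledErrors raised by cancel()
        await asyncio.gather(*self._tasks, return_exceptions=True)
        self._tasks.clear()

async def main():
    tracker = TaskTracker()
    tracker.spawn(asyncio.sleep(3600))  # long-running background job
    await tracker.cleanup()             # cancels and drains it before exit

asyncio.run(main())
```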
7 changes: 4 additions & 3 deletions tests/test_state_integration.py
@@ -183,10 +183,11 @@ def test_frame_processor_state_attachment(self):
processor = MockFrameProcessor()
state = StreamState()

# Initially no state attached
assert not hasattr(processor, 'state') or processor.state is None
# Initially has default state
assert hasattr(processor, 'state')
assert processor.state is not None

# Attach state
# Attach new state
processor.attach_state(state)
assert processor.state is state

Expand Down