@router.get("/api/restart-status")
async def get_restart_status(is_admin: bool = Depends(require_admin)):
    """Report engine restart state plus current MLX memory diagnostics.

    Returns the restart flag/reason from the engine pool, active/peak/cache
    memory figures, the watermark classification from the process memory
    enforcer (when one is configured), and the last eviction event.

    Raises:
        HTTPException: 503 when the engine pool has not been initialized.
    """
    from datetime import datetime, timezone

    engine_pool = _get_engine_pool()
    if engine_pool is None:
        raise HTTPException(status_code=503, detail="Server not initialized")

    import mlx.core as mx

    active = mx.get_active_memory()
    peak = mx.get_peak_memory()
    cache = mx.get_cache_memory()
    # NOTE(review): treats the Metal cache as reclaimable and included in
    # active memory — confirm against MLX allocator semantics.
    effective_active = active - cache

    enforcer = getattr(engine_pool, '_process_memory_enforcer', None)
    watermark_str = "unknown"
    utilization_pct = 0.0
    limit_gb = 0.0
    # Single getattr with a default replaces the hasattr+attribute pair.
    if enforcer is not None and getattr(enforcer, '_max_bytes', 0) > 0:
        utilization_pct = round(active / enforcer._max_bytes * 100, 1)
        limit_gb = round(enforcer._max_bytes / 1024**3, 2)
        from ..process_memory_enforcer import MemoryWatermark
        watermark_str = MemoryWatermark.from_utilization(
            active / enforcer._max_bytes
        ).value

    return {
        "restart_requested": engine_pool.restart_requested,
        "restart_reason": engine_pool.restart_reason,
        "memory": {
            "active_gb": round(active / 1024**3, 2),
            "peak_gb": round(peak / 1024**3, 2),
            "cache_gb": round(cache / 1024**3, 2),
            "effective_active_gb": round(effective_active / 1024**3, 2),
            "model_est_gb": round(
                engine_pool.current_model_memory / 1024**3, 2
            ),
            "loaded_models": engine_pool.loaded_model_count,
            "watermark": watermark_str,
            "utilization_pct": utilization_pct,
            "limit_gb": limit_gb,
            "loaded_model_details": engine_pool.get_loaded_model_details(),
        },
        "last_eviction": engine_pool.last_eviction,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
@router.post("/api/restart-engine")
async def restart_engine(is_admin: bool = Depends(require_admin)):
    """Clear a pending engine-restart request.

    This endpoint does not restart anything itself: it reads the pool's
    restart flag and reason, clears the flag if it was set, and tells the
    caller to terminate and relaunch omlx for an actual restart.

    Raises:
        HTTPException: 503 when the engine pool has not been initialized.
    """
    engine_pool = _get_engine_pool()
    if engine_pool is None:
        raise HTTPException(status_code=503, detail="Server not initialized")

    was_requested = engine_pool.restart_requested
    reason = engine_pool.restart_reason

    if was_requested:
        logger.warning(f"Restart requested: {reason}")
        engine_pool.clear_restart_request()

    return {
        "status": "ok",
        "restart_was_requested": was_requested,
        "reason": reason,
        "message": "Restart flag cleared. For actual restart, terminate and restart omlx.",
    }
self._entries.values() if e.engine is not None) + @property + def restart_requested(self) -> bool: + """True if memory barrier timed out and restart is needed.""" + return self._restart_requested + + @property + def restart_reason(self) -> str: + """Reason for the last restart request.""" + return self._restart_reason + + @property + def last_eviction(self) -> dict | None: + """Last eviction event details.""" + return self._last_eviction + + def get_loaded_model_details(self) -> list[dict]: + """Get details of all currently loaded models.""" + now = time.time() + details = [] + for mid, e in self._entries.items(): + if e.engine is not None: + has_active = False + try: + has_active = e.engine.has_active_requests() + except AttributeError: + pass + details.append({ + "id": mid, + "est_gb": round(e.estimated_size / 1024**3, 2), + "last_access_ago_s": round(now - e.last_access, 1) if e.last_access > 0 else None, + "is_pinned": e.is_pinned, + "engine_type": e.engine_type, + "active_requests": has_active, + }) + return details + + def request_restart(self, reason: str = "manual") -> None: + """Request engine restart.""" + self._restart_requested = True + self._restart_reason = reason + + def clear_restart_request(self) -> None: + """Clear restart request after restart is completed.""" + self._restart_requested = False + self._restart_reason = "" + def discover_models( self, model_dirs: str | list[str], pinned_models: list[str] | None = None ) -> None: @@ -403,6 +452,113 @@ async def get_engine( ), ) + # Pre-load watermark check with hot-cache LRU eviction + if self._process_memory_enforcer is not None: + enforcer = self._process_memory_enforcer + if hasattr(enforcer, 'pre_load_check'): + action, diagnostics = await enforcer.pre_load_check( + entry.estimated_size, engine_type=entry.engine_type, + ) + watermark_before = diagnostics.get('watermark', 'unknown') + logger.info( + f"Pre-load check: model={model_id} " + f"watermark={watermark_before} " + 
f"projected={diagnostics.get('projected_gb', '?')}GB " + f"effective={diagnostics.get('effective_current_gb', '?')}GB " + f"overhead={diagnostics.get('overhead_pct', '?')}% " + f"action={action.value}" + ) + + evicted_models = [] + freed_est_gb = 0.0 + + # Phase 1: Evict LRU non-active models to lower watermark + if action.value in ( + "reclaim_then_load", "restart_then_load", "queue_and_wait", + ): + candidates_info = [] + now = time.time() + for mid, e in self._entries.items(): + if e.engine is None or e.is_pinned: + continue + has_active = False + try: + has_active = e.engine.has_active_requests() + except AttributeError: + pass + candidates_info.append({ + "model": mid, + "est_gb": round(e.estimated_size / 1024**3, 1), + "idle_s": round(now - e.last_access, 0) if e.last_access > 0 else None, + "active_reqs": has_active, + "evictable": not has_active, + }) + logger.info(f"Eviction candidates: {candidates_info}") + + while True: + victim = self._find_lru_victim() + if victim is None: + break + victim_entry = self._entries.get(victim) + victim_gb = round( + victim_entry.estimated_size / 1024**3, 2, + ) if victim_entry else 0 + logger.info( + f"Evicting LRU model '{victim}' " + f"(est={victim_gb}GB) to lower watermark" + ) + await self._unload_engine(victim) + evicted_models.append(victim) + freed_est_gb += victim_gb + from datetime import datetime, timezone + self._last_eviction = { + "model": victim, + "reason": f"watermark_{watermark_before}", + "freed_est_gb": victim_gb, + "timestamp": datetime.now(timezone.utc).isoformat(), + } + action, diagnostics = await enforcer.pre_load_check( + entry.estimated_size, engine_type=entry.engine_type, + ) + new_wm = diagnostics.get('watermark', 'unknown') + logger.info( + f"After evicting '{victim}': " + f"watermark={new_wm} " + f"projected={diagnostics.get('projected_gb', '?')}GB" + ) + if action.value in ("load_directly", "reclaim_then_load"): + break + + # Phase 2: Lightweight emergency reclaim if still elevated + if 
action.value in ( + "reclaim_then_load", "restart_then_load", "queue_and_wait", + ): + await enforcer.emergency_reclaim() + action, diagnostics = await enforcer.pre_load_check( + entry.estimated_size, engine_type=entry.engine_type, + ) + + # Phase 3: Flag restart only if still red/fatal after all efforts + watermark_after = diagnostics.get('watermark', 'unknown') + if action.value in ("restart_then_load", "queue_and_wait"): + self._restart_requested = True + self._restart_reason = ( + f"Watermark {watermark_after.upper()} for model {model_id}: " + f"projected={diagnostics.get('projected_gb', '?')}GB" + ) + logger.error( + f"RESTART RECOMMENDED before loading {model_id}" + ) + + logger.info( + f"Pre-load summary: " + f"watermark_before={watermark_before} " + f"watermark_after={watermark_after} " + f"evicted={evicted_models} " + f"freed_est_gb={freed_est_gb} " + f"restart_requested={self._restart_requested}" + ) + # Now load the model await self._load_engine(model_id, force_lm=force_lm) @@ -435,18 +591,34 @@ async def _ensure_memory_available(self, required: int) -> None: ) await self._unload_engine(victim) - def _find_lru_victim(self) -> str | None: + def _find_lru_victim(self, exclude: set[str] | None = None) -> str | None: """ Find the least recently used non-pinned loaded model. + Skips models with active inference requests to avoid interrupting + in-flight generation. 
+ + Args: + exclude: Optional set of model IDs to skip + Returns: - Model ID of the LRU victim, or None if all models are pinned + Model ID of the LRU victim, or None if no evictable model found """ - candidates = [ - (e.last_access, mid) - for mid, e in self._entries.items() - if e.engine is not None and not e.is_pinned - ] + candidates = [] + for mid, e in self._entries.items(): + if e.engine is None or e.is_pinned: + continue + if exclude and mid in exclude: + continue + try: + if e.engine.has_active_requests(): + logger.debug( + f"Skipping victim '{mid}': has active requests" + ) + continue + except AttributeError: + pass + candidates.append((e.last_access, mid)) if not candidates: return None candidates.sort() # Sort by last_access (oldest first) @@ -454,9 +626,10 @@ def _find_lru_victim(self) -> str | None: async def _unload_engine(self, model_id: str) -> None: """ - Immediately stop and unload an engine. + Immediately stop and unload an engine with memory settle barrier. - This aborts any in-progress requests. + After stopping the engine, polls mx.get_active_memory() to verify + Metal buffers are actually reclaimed before proceeding. Args: model_id: The model ID to unload @@ -466,34 +639,92 @@ async def _unload_engine(self, model_id: str) -> None: return logger.info(f"Unloading model: {model_id} (immediate abort)") + pre_unload_active = mx.get_active_memory() try: await entry.engine.stop() except Exception as e: logger.warning(f"Error stopping engine for {model_id}: {e}") - # Release memory tracking - self._current_model_memory -= entry.estimated_size - - # Clear engine reference + # Clear engine reference before settle barrier entry.engine = None entry.last_access = 0.0 - # Force garbage collection to release memory. - # Run mx.clear_cache on the global MLX executor to avoid concurrent - # Metal operations with running engines. See issue #85. 
- # Synchronize before clearing to prevent releasing Metal buffers - # still referenced by in-flight command buffers. See issue #300. - gc.collect() + # Force Metal memory cleanup loop = asyncio.get_running_loop() + gc.collect() await loop.run_in_executor( get_mlx_executor(), lambda: (mx.synchronize(), mx.clear_cache()) ) - logger.info( - f"Unloaded model: {model_id}, " - f"memory usage: {format_size(self._current_model_memory)}" - ) + # Memory settle barrier: verify actual freed memory + settle_tolerance = 2 * 1024**3 # 2 GB tolerance + min_expected_freed = max(0, entry.estimated_size - settle_tolerance) + settled = False + for _settle_round in range(10): + active_now = mx.get_active_memory() + actual_freed = pre_unload_active - active_now + if actual_freed >= min_expected_freed: + settled = True + logger.debug( + f"Settle round {_settle_round + 1} for '{model_id}': " + f"freed={format_size(actual_freed)} " + f"(need>={format_size(min_expected_freed)}) — settled" + ) + break + logger.debug( + f"Settle round {_settle_round + 1} for '{model_id}': " + f"freed={format_size(actual_freed)} " + f"(need>={format_size(min_expected_freed)}) — retry" + ) + await asyncio.sleep(0.5) + gc.collect() + await loop.run_in_executor( + get_mlx_executor(), lambda: (mx.synchronize(), mx.clear_cache()) + ) + + # Release memory tracking AFTER barrier + self._current_model_memory -= entry.estimated_size + + if settled: + logger.info( + f"Unloaded model: {model_id}, " + f"freed={format_size(actual_freed)} " + f"(expected>={format_size(min_expected_freed)}), " + f"active_memory: {format_size(active_now)} (settled)" + ) + else: + logger.warning( + f"Settle barrier timed out for '{model_id}': " + f"freed={format_size(actual_freed)} " + f"(need>={format_size(min_expected_freed)}), " + f"pre_unload={format_size(pre_unload_active)}, " + f"active_now={format_size(active_now)}" + ) + # Emergency reclaim: 3 more rounds of aggressive GC + clear + for _emergency_round in range(3): + gc.collect() + 
# --- process_memory_enforcer.py additions -----------------------------------

class MemoryWatermark(Enum):
    """Memory pressure levels used for pre-load safety decisions."""

    GREEN = "green"    # < 65% utilization
    YELLOW = "yellow"  # 65-80% utilization
    RED = "red"        # 80-90% utilization
    FATAL = "fatal"    # >= 90% utilization

    @classmethod
    def from_utilization(cls, utilization: float) -> "MemoryWatermark":
        """Map a utilization ratio (0.0-1.0+) onto a watermark band."""
        if utilization < 0.65:
            return cls.GREEN
        if utilization < 0.80:
            return cls.YELLOW
        if utilization < 0.90:
            return cls.RED
        return cls.FATAL


class WatermarkAction(Enum):
    """What the pool should do before loading a model at a given watermark."""

    LOAD_DIRECTLY = "load_directly"
    RECLAIM_THEN_LOAD = "reclaim_then_load"
    RESTART_THEN_LOAD = "restart_then_load"
    QUEUE_AND_WAIT = "queue_and_wait"


# The following are methods of ProcessMemoryEnforcer.

def get_status(self) -> dict:
    """Get enforcer status for monitoring endpoints."""
    current = mx.get_active_memory() if self._running else 0
    utilization = current / self._max_bytes if self._max_bytes > 0 else 0.0
    return {
        "enabled": self._running,
        "max_bytes": self._max_bytes,
        "max_formatted": _format_gb(self._max_bytes),
        "current_bytes": current,
        "current_formatted": _format_gb(current),
        "utilization": utilization,
        "watermark": MemoryWatermark.from_utilization(utilization).value,
    }


def get_watermark_level(self) -> MemoryWatermark:
    """Current memory watermark; GREEN when no limit is configured."""
    if self._max_bytes <= 0:
        return MemoryWatermark.GREEN
    return MemoryWatermark.from_utilization(
        mx.get_active_memory() / self._max_bytes
    )


def get_memory_diagnostics(self) -> dict:
    """Comprehensive memory diagnostics (sizes in GB, rounded)."""
    current = mx.get_active_memory()
    peak = mx.get_peak_memory()
    cache = mx.get_cache_memory()
    model_est = self._engine_pool.current_model_memory
    utilization = current / self._max_bytes if self._max_bytes > 0 else 0.0
    watermark = MemoryWatermark.from_utilization(utilization)
    return {
        "active_gb": round(current / 1024**3, 2),
        "peak_gb": round(peak / 1024**3, 2),
        "cache_gb": round(cache / 1024**3, 2),
        "model_est_gb": round(model_est / 1024**3, 2),
        "loaded_models": self._engine_pool.loaded_model_count,
        "limit_gb": round(self._max_bytes / 1024**3, 2),
        "utilization_pct": round(utilization * 100, 1),
        "watermark": watermark.value,
    }


async def pre_load_check(
    self, new_model_size_bytes: int, engine_type: str = "batched",
) -> tuple:
    """Pre-load memory safety check.

    Projects post-load memory usage by deducting the reclaimable Metal
    cache from current usage and adding an engine-type-scaled runtime
    overhead, then maps the projected utilization to a WatermarkAction.

    Returns:
        (WatermarkAction, diagnostics_dict)
    """
    current_active = mx.get_active_memory()
    current_cache = mx.get_cache_memory()
    # NOTE(review): assumes the cache is counted inside active memory and
    # is fully reclaimable — confirm against MLX allocator semantics.
    effective_current = current_active - current_cache

    # Lightweight engines carry little runtime overhead; generation engines
    # get 10-25%, shrinking as the model itself grows.
    if engine_type in (
        "embedding", "reranker", "audio_stt", "audio_tts", "audio_sts",
    ):
        overhead_pct = 0.05
    else:
        overhead_pct = max(
            0.10, min(0.25, 0.30 - new_model_size_bytes / (200 * 1024**3)),
        )
    runtime_overhead = int(new_model_size_bytes * overhead_pct)

    projected = effective_current + new_model_size_bytes + runtime_overhead
    utilization = projected / self._max_bytes if self._max_bytes > 0 else 0.0
    watermark = MemoryWatermark.from_utilization(utilization)

    diagnostics = {
        "current_gb": round(current_active / 1024**3, 2),
        "cache_gb": round(current_cache / 1024**3, 2),
        "effective_current_gb": round(effective_current / 1024**3, 2),
        "new_model_gb": round(new_model_size_bytes / 1024**3, 2),
        "overhead_gb": round(runtime_overhead / 1024**3, 2),
        "overhead_pct": round(overhead_pct * 100, 1),
        "projected_gb": round(projected / 1024**3, 2),
        "limit_gb": round(self._max_bytes / 1024**3, 2),
        "utilization_pct": round(utilization * 100, 1),
        "watermark": watermark.value,
        "loaded_model_count": self._engine_pool.loaded_model_count,
    }

    action_map = {
        MemoryWatermark.GREEN: WatermarkAction.LOAD_DIRECTLY,
        MemoryWatermark.YELLOW: WatermarkAction.RECLAIM_THEN_LOAD,
        MemoryWatermark.RED: WatermarkAction.RESTART_THEN_LOAD,
        MemoryWatermark.FATAL: WatermarkAction.QUEUE_AND_WAIT,
    }
    return action_map[watermark], diagnostics


async def emergency_reclaim(self) -> bool:
    """Emergency memory reclaim: aggressive GC plus Metal cache clear.

    Runs synchronize/clear_cache on the dedicated MLX executor to avoid
    concurrent Metal operations with running engines.

    Returns:
        True if active memory dropped.
    """
    logger.warning("Performing emergency memory reclaim...")
    loop = asyncio.get_running_loop()
    before = mx.get_active_memory()

    for _ in range(3):
        gc.collect()
        await loop.run_in_executor(
            get_mlx_executor(), lambda: (mx.synchronize(), mx.clear_cache())
        )
        await asyncio.sleep(1.0)

    # One final, longer-settling pass before measuring the result.
    gc.collect()
    await loop.run_in_executor(
        get_mlx_executor(), lambda: (mx.synchronize(), mx.clear_cache())
    )
    await asyncio.sleep(2.0)
    after = mx.get_active_memory()

    logger.info(
        f"Emergency reclaim complete: freed {_format_gb(before - after)} "
        f"({_format_gb(before)} -> {_format_gb(after)})"
    )
    return after < before
pool_with_entries._entries["model-a"].engine = mock_engine_a + pool_with_entries._entries["model-a"].last_access = 50 + + mock_engine_b = MagicMock() + mock_engine_b.has_active_requests.return_value = False + pool_with_entries._entries["model-b"].engine = mock_engine_b + pool_with_entries._entries["model-b"].last_access = 200 + + victim = pool_with_entries._find_lru_victim(exclude={"model-a"}) + assert victim == "model-b" + + def test_find_lru_victim_all_active(self, pool_with_entries): + """Test that None is returned when all models have active requests.""" + for mid in ("model-a", "model-b"): + mock_engine = MagicMock() + mock_engine.has_active_requests.return_value = True + pool_with_entries._entries[mid].engine = mock_engine + pool_with_entries._entries[mid].last_access = 100 + + victim = pool_with_entries._find_lru_victim() + assert victim is None + + def test_find_lru_victim_no_has_active_requests(self, pool_with_entries): + """Test graceful handling when engine lacks has_active_requests.""" + mock_engine = MagicMock(spec=[]) # No has_active_requests + pool_with_entries._entries["model-a"].engine = mock_engine + pool_with_entries._entries["model-a"].last_access = 100 + + victim = pool_with_entries._find_lru_victim() + assert victim == "model-a" + + +class TestEnginePoolRestartState: + """Tests for restart request tracking.""" + + @pytest.fixture + def pool(self, small_mock_model_dir): + pool = EnginePool() + pool.discover_models(str(small_mock_model_dir)) + return pool + + def test_restart_defaults(self, pool): + """Test that restart state starts clean.""" + assert pool.restart_requested is False + assert pool.restart_reason == "" + + def test_request_and_clear_restart(self, pool): + """Test request/clear restart cycle.""" + pool.request_restart("memory barrier timeout") + assert pool.restart_requested is True + assert "memory barrier" in pool.restart_reason + + pool.clear_restart_request() + assert pool.restart_requested is False + assert pool.restart_reason 
== "" + + def test_last_eviction_default(self, pool): + """Test that last_eviction starts as None.""" + assert pool.last_eviction is None + + def test_get_loaded_model_details_empty(self, pool): + """Test model details when nothing loaded.""" + details = pool.get_loaded_model_details() + assert details == [] + + def test_get_loaded_model_details_with_loaded(self, pool): + """Test model details with loaded models.""" + mock_engine = MagicMock() + mock_engine.has_active_requests.return_value = False + pool._entries["model-a"].engine = mock_engine + pool._entries["model-a"].last_access = 1000.0 + + details = pool.get_loaded_model_details() + assert len(details) == 1 + assert details[0]["id"] == "model-a" + assert "est_gb" in details[0] + assert details[0]["active_requests"] is False + class TestEnginePoolAsync: """Async tests for EnginePool (mocked).""" diff --git a/tests/test_process_memory_enforcer.py b/tests/test_process_memory_enforcer.py index 9c28938f..c00b9f58 100644 --- a/tests/test_process_memory_enforcer.py +++ b/tests/test_process_memory_enforcer.py @@ -625,3 +625,42 @@ async def test_get_status_when_running(self, enforcer): assert status["enabled"] is True assert status["current_bytes"] == 5 * 1024**3 await enforcer.stop() + + +class TestMemoryWatermark: + """Tests for MemoryWatermark enum.""" + + def test_green_zone(self): + from omlx.process_memory_enforcer import MemoryWatermark + assert MemoryWatermark.from_utilization(0.0) == MemoryWatermark.GREEN + assert MemoryWatermark.from_utilization(0.5) == MemoryWatermark.GREEN + assert MemoryWatermark.from_utilization(0.64) == MemoryWatermark.GREEN + + def test_yellow_zone(self): + from omlx.process_memory_enforcer import MemoryWatermark + assert MemoryWatermark.from_utilization(0.65) == MemoryWatermark.YELLOW + assert MemoryWatermark.from_utilization(0.75) == MemoryWatermark.YELLOW + assert MemoryWatermark.from_utilization(0.79) == MemoryWatermark.YELLOW + + def test_red_zone(self): + from 
omlx.process_memory_enforcer import MemoryWatermark + assert MemoryWatermark.from_utilization(0.80) == MemoryWatermark.RED + assert MemoryWatermark.from_utilization(0.85) == MemoryWatermark.RED + assert MemoryWatermark.from_utilization(0.89) == MemoryWatermark.RED + + def test_fatal_zone(self): + from omlx.process_memory_enforcer import MemoryWatermark + assert MemoryWatermark.from_utilization(0.90) == MemoryWatermark.FATAL + assert MemoryWatermark.from_utilization(1.0) == MemoryWatermark.FATAL + assert MemoryWatermark.from_utilization(1.5) == MemoryWatermark.FATAL + + +class TestWatermarkAction: + """Tests for WatermarkAction enum values.""" + + def test_action_values(self): + from omlx.process_memory_enforcer import WatermarkAction + assert WatermarkAction.LOAD_DIRECTLY.value == "load_directly" + assert WatermarkAction.RECLAIM_THEN_LOAD.value == "reclaim_then_load" + assert WatermarkAction.RESTART_THEN_LOAD.value == "restart_then_load" + assert WatermarkAction.QUEUE_AND_WAIT.value == "queue_and_wait"