class CacheProbeRequest(BaseModel):
    """Request model for probing per-prompt cache state.

    Tokenizes a chat message list with the target model's tokenizer, then
    classifies each block's location in the cache hierarchy:
    - Hot SSD (in-RAM copy of SSD cache, ready to mount without disk read)
    - Disk SSD (persisted only, needs disk read to reuse)
    - Cold (fully uncached — would require full prefill)
    """

    model_id: str
    messages: List[Dict[str, Any]]
    tools: Optional[List[Dict[str, Any]]] = None
    chat_template_kwargs: Optional[Dict[str, Any]] = None


@router.post("/api/cache/probe")
async def probe_cache(
    request: CacheProbeRequest,
    is_admin: bool = Depends(require_admin),
):
    """Probe cache state for a chat message list.

    Classifies each block of the rendered prompt into one of three buckets:
    - ``blocks_ssd_hot``: in the SSD manager's hot cache (RAM copy of cold
      blocks, ready to mount without disk read)
    - ``blocks_ssd_disk``: only in the SSD index on disk
    - ``blocks_cold``: not cached anywhere (requires full prefill)

    The split is computed via a walk of the chain-hashed block sequence — the
    same hashing the scheduler uses at prefill time. The model must be loaded
    for the probe to run; unloaded models return ``model_loaded: false``.

    Raises:
        HTTPException: 503 if the engine pool is not initialized, 404 for an
            unknown model, 400 for tokenizer/templating failures, 500 when
            the scheduler or cache block size is unavailable.
    """
    engine_pool = _get_engine_pool()
    if engine_pool is None:
        raise HTTPException(status_code=503, detail="Engine pool not initialized")

    entry = engine_pool._entries.get(request.model_id)
    if entry is None:
        raise HTTPException(
            status_code=404, detail=f"Model not found: {request.model_id}"
        )
    if entry.engine is None:
        # Not an error: probing is only meaningful against a live engine, so
        # report the unloaded state instead of raising.
        return {
            "model_id": request.model_id,
            "model_loaded": False,
            "reason": "Model is not loaded — load it to enable cache probing.",
        }

    engine = entry.engine
    tokenizer = getattr(engine, "_tokenizer", None)
    if tokenizer is None or not hasattr(tokenizer, "apply_chat_template"):
        raise HTTPException(
            status_code=400,
            detail="Model tokenizer does not support chat templating.",
        )

    # Reach into the scheduler to access the prefix index and SSD manager.
    async_core = getattr(engine, "_engine", None)
    core = getattr(async_core, "engine", None) if async_core is not None else None
    scheduler = getattr(core, "scheduler", None) if core is not None else None
    if scheduler is None:
        raise HTTPException(
            status_code=500, detail="Scheduler unavailable for loaded model."
        )

    prefix_cache = getattr(scheduler, "block_aware_cache", None)
    ssd_manager = getattr(scheduler, "paged_ssd_cache_manager", None)
    paged_cache = getattr(scheduler, "paged_cache_manager", None)
    # Prefer the scheduler config's block size; fall back to the prefix
    # cache's own attribute when the config does not carry one.
    block_size = getattr(
        getattr(scheduler, "config", None), "paged_cache_block_size", 0
    )
    if not block_size and prefix_cache is not None:
        block_size = getattr(prefix_cache, "block_size", 0)
    if not block_size:
        raise HTTPException(
            status_code=500,
            detail="Cache block size unavailable — cache may not be enabled.",
        )

    # Render + tokenize the prompt using the same path as generation so the
    # hashes line up with what the scheduler would produce at prefill.
    try:
        messages = request.messages
        if hasattr(engine, "_preprocess_messages"):
            messages = engine._preprocess_messages(messages)
        try:
            from ..api.tool_calling import convert_tools_for_template  # type: ignore

            template_tools = (
                convert_tools_for_template(request.tools) if request.tools else None
            )
        except Exception:
            # Best-effort: if the converter is unavailable or fails, pass the
            # raw tool dicts through to the template.
            template_tools = request.tools or None
        if hasattr(engine, "_apply_chat_template"):
            prompt = engine._apply_chat_template(
                messages,
                template_tools,
                chat_template_kwargs=request.chat_template_kwargs,
            )
        else:
            prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        token_ids = list(tokenizer.encode(prompt))
    except Exception as exc:
        # Chain the cause so server logs retain the original traceback.
        raise HTTPException(
            status_code=400, detail=f"Failed to tokenize messages: {exc}"
        ) from exc

    total_tokens = len(token_ids)
    if total_tokens == 0:
        # Nothing to probe — return an all-zero report.
        return {
            "model_id": request.model_id,
            "model_loaded": True,
            "total_tokens": 0,
            "block_size": block_size,
            "total_blocks": 0,
            "blocks_ssd_hot": 0,
            "blocks_ssd_disk": 0,
            "blocks_cold": 0,
            "ssd_hit_tokens": 0,
            "cold_tokens": 0,
        }

    # Compute chain-hashed block sequence.
    from ..cache.paged_cache import compute_block_hash

    model_name = getattr(paged_cache, "model_name", None) if paged_cache else None
    ssd_index = getattr(ssd_manager, "_index", None) if ssd_manager else None
    ssd_hot = getattr(ssd_manager, "_hot_cache", None) if ssd_manager else None

    # The cache is a contiguous prefix (each block chain-hashed from the
    # previous), so we walk block-by-block until the first retrievability
    # miss — after that, every subsequent block is necessarily cold.
    #
    # Ground truth for "cached" in paged-SSD mode is retrievability:
    # hot_cache (RAM copy) OR ssd_index (on disk). BlockAwarePrefixCache's
    # internal prefix index is deliberately NOT consulted — it tracks every
    # hash the scheduler has seen and isn't cleared by clear_ssd_cache(),
    # so relying on it would report false positives after a manual wipe.
    blocks_ssd_hot = 0
    blocks_ssd_disk = 0
    ssd_hit_tokens = 0

    parent_hash = b""
    total_blocks = (total_tokens + block_size - 1) // block_size

    for start in range(0, total_tokens, block_size):
        end = min(start + block_size, total_tokens)
        # Note: total_tokens > 0 here, so every slice is non-empty.
        block_tokens = token_ids[start:end]

        block_hash = compute_block_hash(
            parent_hash,
            block_tokens,
            extra_keys=None,
            model_name=model_name,
        )
        parent_hash = block_hash

        in_ssd_hot = ssd_hot is not None and block_hash in ssd_hot
        in_ssd_disk = False
        if ssd_index is not None:
            try:
                in_ssd_disk = ssd_index.contains(block_hash)
            except Exception:
                # A failing index lookup is treated as a miss, not an error.
                in_ssd_disk = False

        if not (in_ssd_hot or in_ssd_disk):
            break

        if in_ssd_hot:
            blocks_ssd_hot += 1
        else:
            blocks_ssd_disk += 1
        ssd_hit_tokens += len(block_tokens)

    cached_blocks = blocks_ssd_hot + blocks_ssd_disk
    blocks_cold = max(total_blocks - cached_blocks, 0)

    return {
        "model_id": request.model_id,
        "model_loaded": True,
        "total_tokens": total_tokens,
        "block_size": block_size,
        "total_blocks": total_blocks,
        "blocks_ssd_hot": blocks_ssd_hot,
        "blocks_ssd_disk": blocks_ssd_disk,
        "blocks_cold": blocks_cold,
        "ssd_hit_tokens": ssd_hit_tokens,
        "cold_tokens": max(total_tokens - ssd_hit_tokens, 0),
    }