From fcb767518f1e2e8e36c42b733ba29bb325964aef Mon Sep 17 00:00:00 2001 From: adi Date: Sat, 11 Apr 2026 10:51:07 +0100 Subject: [PATCH] feat(admin): add cache probe endpoint for prompt prefix lookup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit POST /admin/api/cache/probe accepts {model_id, messages, tools?, chat_template_kwargs?} and reports how much of the rendered prompt is already resident in the loaded model's SSD cache, broken down by tier (hot cache / disk index / cold). The walk chain-hashes each block the same way the scheduler does at prefill so the answer matches what a real request would see. Motivating use case is a cache-aware chat UI: when a user is about to send (or when branching), show whether the prefill will hit the cache or pay the full cost. Works for both batched and VLM engines and requires admin auth. Ground truth for "cached" is retrievability via the paged SSD cache manager — hot_cache for RAM-resident blocks or _index for on-disk files. BlockAwarePrefixCache._prefix_index is intentionally not consulted because it survives clear_ssd_cache() and would report false positives after a manual wipe. Co-Authored-By: Claude Opus 4.6 (1M context) --- omlx/admin/routes.py | 197 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 197 insertions(+) diff --git a/omlx/admin/routes.py b/omlx/admin/routes.py index 020a7170..5917f76b 100644 --- a/omlx/admin/routes.py +++ b/omlx/admin/routes.py @@ -72,6 +72,22 @@ class DeleteSubKeyRequest(BaseModel): key: str +class CacheProbeRequest(BaseModel): + """Request model for probing per-prompt cache state. + + Tokenizes a chat message list with the target model's tokenizer, then + classifies each block's location in the cache hierarchy: + - Hot SSD (in-RAM copy of SSD cache, ready to mount without disk read) + - Disk SSD (persisted only, needs disk read to reuse) + - Cold (fully uncached — would require full prefill) + """ + + model_id: str + messages: List[Dict[str, Any]] + tools: Optional[List[Dict[str, Any]]] = None + chat_template_kwargs: Optional[Dict[str, Any]] = None + + class ModelSettingsRequest(BaseModel): """Request model for updating per-model settings.""" @@ -2772,6 +2788,187 @@ async def clear_ssd_cache(is_admin: bool = Depends(require_admin)): return {"status": "ok", "total_deleted": total_deleted} +@router.post("/api/cache/probe") +async def probe_cache( + request: CacheProbeRequest, + is_admin: bool = Depends(require_admin), +): + """Probe cache state for a chat message list. + + Classifies each block of the rendered prompt into one of three buckets: + - ``blocks_ssd_hot``: in the SSD manager's hot cache (RAM copy of cold + blocks, ready to mount without disk read) + - ``blocks_ssd_disk``: only in the SSD index on disk + - ``blocks_cold``: not cached anywhere (requires full prefill) + + The split is computed via a walk of the chain-hashed block sequence — the + same hashing the scheduler uses at prefill time. The model must be loaded + for the probe to run; unloaded models return ``model_loaded: false``. + """ + engine_pool = _get_engine_pool() + if engine_pool is None: + raise HTTPException(status_code=503, detail="Engine pool not initialized") + + entry = engine_pool._entries.get(request.model_id) + if entry is None: + raise HTTPException( + status_code=404, detail=f"Model not found: {request.model_id}" + ) + if entry.engine is None: + return { + "model_id": request.model_id, + "model_loaded": False, + "reason": "Model is not loaded — load it to enable cache probing.", + } + + engine = entry.engine + tokenizer = getattr(engine, "_tokenizer", None) + if tokenizer is None or not hasattr(tokenizer, "apply_chat_template"): + raise HTTPException( + status_code=400, + detail="Model tokenizer does not support chat templating.", + ) + + # Reach into the scheduler to access the prefix index and SSD manager. + async_core = getattr(engine, "_engine", None) + core = getattr(async_core, "engine", None) if async_core is not None else None + scheduler = getattr(core, "scheduler", None) if core is not None else None + if scheduler is None: + raise HTTPException( + status_code=500, detail="Scheduler unavailable for loaded model." + ) + + prefix_cache = getattr(scheduler, "block_aware_cache", None) + ssd_manager = getattr(scheduler, "paged_ssd_cache_manager", None) + paged_cache = getattr(scheduler, "paged_cache_manager", None) + block_size = getattr( + getattr(scheduler, "config", None), "paged_cache_block_size", 0 + ) + if not block_size and prefix_cache is not None: + block_size = getattr(prefix_cache, "block_size", 0) + if not block_size: + raise HTTPException( + status_code=500, + detail="Cache block size unavailable — cache may not be enabled.", + ) + + # Render + tokenize the prompt using the same path as generation so the + # hashes line up with what the scheduler would produce at prefill. + try: + messages = request.messages + if hasattr(engine, "_preprocess_messages"): + messages = engine._preprocess_messages(messages) + try: + from ..api.tool_calling import convert_tools_for_template # type: ignore + template_tools = ( + convert_tools_for_template(request.tools) if request.tools else None + ) + except Exception: + template_tools = request.tools or None + if hasattr(engine, "_apply_chat_template"): + prompt = engine._apply_chat_template( + messages, + template_tools, + chat_template_kwargs=request.chat_template_kwargs, + ) + else: + prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + token_ids = list(tokenizer.encode(prompt)) + except Exception as exc: + raise HTTPException( + status_code=400, detail=f"Failed to tokenize messages: {exc}" + ) + + total_tokens = len(token_ids) + if total_tokens == 0: + return { + "model_id": request.model_id, + "model_loaded": True, + "total_tokens": 0, + "block_size": block_size, + "total_blocks": 0, + "blocks_ssd_hot": 0, + "blocks_ssd_disk": 0, + "blocks_cold": 0, + "ssd_hit_tokens": 0, + "cold_tokens": 0, + } + + # Compute chain-hashed block sequence. + from ..cache.paged_cache import compute_block_hash + + model_name = getattr(paged_cache, "model_name", None) if paged_cache else None + ssd_index = getattr(ssd_manager, "_index", None) if ssd_manager else None + ssd_hot = getattr(ssd_manager, "_hot_cache", None) if ssd_manager else None + + # The cache is a contiguous prefix (each block chain-hashed from the + # previous), so we walk block-by-block until the first retrievability + # miss — after that, every subsequent block is necessarily cold. + # + # Ground truth for "cached" in paged-SSD mode is retrievability: + # hot_cache (RAM copy) OR ssd_index (on disk). BlockAwarePrefixCache's + # internal prefix index is deliberately NOT consulted — it tracks every + # hash the scheduler has seen and isn't cleared by clear_ssd_cache(), + # so relying on it would report false positives after a manual wipe. + blocks_ssd_hot = 0 + blocks_ssd_disk = 0 + ssd_hit_tokens = 0 + + parent_hash = b"" + total_blocks = (total_tokens + block_size - 1) // block_size + + for start in range(0, total_tokens, block_size): + end = min(start + block_size, total_tokens) + block_tokens = token_ids[start:end] + if not block_tokens: + break + + block_hash = compute_block_hash( + parent_hash, + block_tokens, + extra_keys=None, + model_name=model_name, + ) + parent_hash = block_hash + + in_ssd_hot = ssd_hot is not None and block_hash in ssd_hot + in_ssd_disk = False + if ssd_index is not None: + try: + in_ssd_disk = ssd_index.contains(block_hash) + except Exception: + in_ssd_disk = False + + if not (in_ssd_hot or in_ssd_disk): + break + + if in_ssd_hot: + blocks_ssd_hot += 1 + else: + blocks_ssd_disk += 1 + ssd_hit_tokens += len(block_tokens) + + cached_blocks = blocks_ssd_hot + blocks_ssd_disk + blocks_cold = max(total_blocks - cached_blocks, 0) + + return { + "model_id": request.model_id, + "model_loaded": True, + "total_tokens": total_tokens, + "block_size": block_size, + "total_blocks": total_blocks, + "blocks_ssd_hot": blocks_ssd_hot, + "blocks_ssd_disk": blocks_ssd_disk, + "blocks_cold": blocks_cold, + "ssd_hit_tokens": ssd_hit_tokens, + "cold_tokens": max(total_tokens - ssd_hit_tokens, 0), + } + + # ============================================================================= # HuggingFace Downloader API Routes # =============================================================================