Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions omlx/cache/paged_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,32 @@
BlockHash = NewType("BlockHash", bytes)


def resolve_block_extra_keys(
    block_end: int,
    extra_keys: Optional[Tuple[Any, ...]] = None,
    extra_key_token_start: Optional[int] = None,
    extra_key_ranges: Optional[List[Tuple[int, Tuple[Any, ...]]]] = None,
) -> Optional[Tuple[Any, ...]]:
    """Pick the cache-key salt for a block whose last token index is ``block_end``.

    When ``extra_key_ranges`` is given (segmented VLM cache keying), it wins
    over the plain ``extra_keys``/``extra_key_token_start`` pair: the salt of
    the last range whose start precedes ``block_end`` is returned.  The scan
    stops at the first range starting at or past ``block_end``, so ranges are
    assumed to be sorted by start offset.  Returns ``None`` when no salt
    applies.
    """
    if extra_key_ranges:
        chosen: Optional[Tuple[Any, ...]] = None
        for range_start, range_keys in extra_key_ranges:
            if block_end <= range_start:
                # Ranges are ordered; nothing further can apply.
                break
            chosen = range_keys
        return chosen

    if extra_keys is None:
        return None
    # Without a start offset the salt covers every block; otherwise it only
    # applies once the block extends past the configured start token.
    if extra_key_token_start is None or block_end > extra_key_token_start:
        return extra_keys
    return None


def compute_block_hash(
parent_hash: Optional[BlockHash],
token_ids: List[int],
Expand Down Expand Up @@ -923,6 +949,8 @@ def get_computed_blocks(
self,
token_ids: List[int],
extra_keys: Optional[Tuple[Any, ...]] = None,
extra_key_token_start: Optional[int] = None,
extra_key_ranges: Optional[List[Tuple[int, Tuple[Any, ...]]]] = None,
) -> Tuple[List[CacheBlock], int]:
"""
Find cached blocks for a token prefix (vLLM style).
Expand All @@ -948,11 +976,17 @@ def get_computed_blocks(
start = i * self.block_size
end = start + self.block_size
block_tokens = token_ids[start:end]
block_extra_keys = resolve_block_extra_keys(
end,
extra_keys=extra_keys,
extra_key_token_start=extra_key_token_start,
extra_key_ranges=extra_key_ranges,
)

# Compute expected hash
block_hash = compute_block_hash(
parent_hash, block_tokens,
extra_keys=extra_keys, model_name=self.model_name,
extra_keys=block_extra_keys, model_name=self.model_name,
)

# Look up in cache
Expand Down Expand Up @@ -1103,14 +1137,19 @@ def find_shared_prefix(
self,
tokens: List[int],
extra_keys: Optional[Tuple[Any, ...]] = None,
extra_key_token_start: Optional[int] = None,
extra_key_ranges: Optional[List[Tuple[int, Tuple[Any, ...]]]] = None,
) -> Tuple[List[int], List[int]]:
"""
Find shared prefix blocks for a token sequence.

Uses get_computed_blocks for consistent chain-hash lookup.
"""
cached_blocks, num_cached_tokens = self.get_computed_blocks(
tokens, extra_keys=extra_keys
tokens,
extra_keys=extra_keys,
extra_key_token_start=extra_key_token_start,
extra_key_ranges=extra_key_ranges,
)

shared_block_ids = [b.block_id for b in cached_blocks]
Expand Down
34 changes: 29 additions & 5 deletions omlx/cache/prefix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@

from .interface import CacheManager
from .paged_ssd_cache import PagedSSDCacheManager
from .paged_cache import BlockTable, CacheBlock, PagedCacheManager, compute_block_hash
from .paged_cache import (
BlockTable,
CacheBlock,
PagedCacheManager,
compute_block_hash,
resolve_block_extra_keys,
)
from .stats import BaseCacheStats, PrefixCacheStats
from .type_handlers import CacheType, CacheTypeHandler
from .type_registry import CacheTypeRegistry
Expand Down Expand Up @@ -228,6 +234,8 @@ def fetch_cache(
request_id: str,
tokens: List[int],
extra_keys: Optional[Tuple[Any, ...]] = None,
extra_key_token_start: Optional[int] = None,
extra_key_ranges: Optional[List[Tuple[int, Tuple[Any, ...]]]] = None,
) -> Tuple[Optional[BlockTable], List[int]]:
"""
Find cached prefix blocks for the given tokens.
Expand All @@ -247,7 +255,10 @@ def fetch_cache(

# Try to find shared prefix blocks
shared_block_ids, remaining = self.paged_cache.find_shared_prefix(
tokens, extra_keys=extra_keys
tokens,
extra_keys=extra_keys,
extra_key_token_start=extra_key_token_start,
extra_key_ranges=extra_key_ranges,
)

if shared_block_ids:
Expand Down Expand Up @@ -311,6 +322,8 @@ def store_cache(
model_cache_config: Optional[ModelCacheConfig] = None,
boundary_snapshots: Optional[Dict[int, List[Any]]] = None,
extra_keys: Optional[Tuple[Any, ...]] = None,
extra_key_token_start: Optional[int] = None,
extra_key_ranges: Optional[List[Tuple[int, Tuple[Any, ...]]]] = None,
) -> Optional[BlockTable]:
"""
Store computed cache for future reuse.
Expand Down Expand Up @@ -433,9 +446,20 @@ def store_cache(
if prev_block and prev_block.block_hash:
parent_hash = prev_block.block_hash

block_extra_keys = resolve_block_extra_keys(
global_end,
extra_keys=extra_keys,
extra_key_token_start=extra_key_token_start,
extra_key_ranges=extra_key_ranges,
)

# Check if this block already exists (deduplication)
if len(block_tokens) == self.block_size:
existing_block = self.paged_cache.find_cached_block(block_tokens, parent_hash)
existing_block = self.paged_cache.find_cached_block(
block_tokens,
parent_hash,
extra_keys=block_extra_keys,
)
if existing_block:
# Reuse existing block
self.paged_cache.increment_ref(existing_block.block_id)
Expand All @@ -462,13 +486,13 @@ def store_cache(
# Compute chain hash for this block
block.block_hash = compute_block_hash(
parent_hash, block_tokens,
extra_keys=extra_keys, model_name=self.paged_cache.model_name,
extra_keys=block_extra_keys, model_name=self.paged_cache.model_name,
)

# Register hash for full blocks (for deduplication)
if len(block_tokens) == self.block_size:
self.paged_cache.register_block_hash(
block, block_tokens, parent_hash, extra_keys=extra_keys
block, block_tokens, parent_hash, extra_keys=block_extra_keys
)

# Extract tensor slice and save to paged SSD
Expand Down
Loading