diff --git a/python/tokenspeed/runtime/configs/deepseek_v4_cache_spec.py b/python/tokenspeed/runtime/configs/deepseek_v4_cache_spec.py index 20697c6f8..547597c42 100644 --- a/python/tokenspeed/runtime/configs/deepseek_v4_cache_spec.py +++ b/python/tokenspeed/runtime/configs/deepseek_v4_cache_spec.py @@ -78,17 +78,20 @@ def build_v4_cache_specs( unique_compress_ratios = sorted({int(r) for r in layer_ratio if int(r) > 1}) specs: List[PagedCacheGroupSpec] = [ + # SWA kv: trailing window only -> State family. PagedCacheGroupSpec( group_id=V4_SWA_KV_GROUP_ID, retention="sliding_window", rows_per_page=V4_KERNEL_BLOCK_ROWS, entry_stride_tokens=1, sliding_window_tokens=swa_window, + family="state", ), ] for ratio in unique_compress_ratios: if ratio not in _COMPRESSOR_STATE_WINDOW_TOKENS: raise ValueError(f"unsupported DeepSeek V4 compress_ratio={ratio}") + # Compressor state: tail buffer -> State family. specs.append( PagedCacheGroupSpec( group_id=v4_compressor_state_group_id(ratio), @@ -96,8 +99,10 @@ def build_v4_cache_specs( rows_per_page=_COMPRESSOR_STATE_ROWS_PER_PAGE[ratio], entry_stride_tokens=1, sliding_window_tokens=_COMPRESSOR_STATE_WINDOW_TOKENS[ratio], + family="state", ) ) + # Compressed kv: full-history chain (indexer K shares this group). specs.append( PagedCacheGroupSpec( group_id=v4_compressed_kv_group_id(ratio), @@ -105,9 +110,11 @@ def build_v4_cache_specs( rows_per_page=_compressed_kernel_block_size(ratio), entry_stride_tokens=ratio, sliding_window_tokens=None, + family="history", ) ) if 4 in unique_compress_ratios: + # Indexer compressor state: tail buffer -> State family. 
specs.append( PagedCacheGroupSpec( group_id=V4_INDEXER_COMPRESSOR_STATE_GROUP_ID, @@ -115,6 +122,7 @@ def build_v4_cache_specs( rows_per_page=_COMPRESSOR_STATE_ROWS_PER_PAGE[4], entry_stride_tokens=1, sliding_window_tokens=_COMPRESSOR_STATE_WINDOW_TOKENS[4], + family="state", ) ) return specs diff --git a/python/tokenspeed/runtime/configs/paged_cache_spec.py b/python/tokenspeed/runtime/configs/paged_cache_spec.py index a83dc1b3b..744fd56a9 100644 --- a/python/tokenspeed/runtime/configs/paged_cache_spec.py +++ b/python/tokenspeed/runtime/configs/paged_cache_spec.py @@ -22,6 +22,7 @@ from tokenspeed.runtime.utils.common import ceil_div Retention = Literal["full_history", "sliding_window"] +Family = Literal["history", "state"] @dataclass(frozen=True) @@ -31,6 +32,8 @@ class PagedCacheGroupSpec: rows_per_page: int entry_stride_tokens: int sliding_window_tokens: Optional[int] + # History groups form a chain; State groups only need the trailing window. + family: Family = "history" _PAGED_CACHE_GROUP_DUMMY_PAGES = 1 diff --git a/python/tokenspeed/runtime/engine/event_loop.py b/python/tokenspeed/runtime/engine/event_loop.py index 859002b47..c697364c9 100644 --- a/python/tokenspeed/runtime/engine/event_loop.py +++ b/python/tokenspeed/runtime/engine/event_loop.py @@ -49,6 +49,7 @@ cache_sync_debug_enabled, make_config, pool_to_paged_cache_groups, + pool_to_prefix_cache_adjunct_spec, pop_common_cache_event_payloads, ) from tokenspeed.runtime.execution.distributed_initializer import ( @@ -275,6 +276,12 @@ def __init__( enable_mixed_prefill_decode = ( server_args.enable_mixed_batch and server_args.speculative_algorithm is None ) + # Adjunct enabled only when pool opts in AND prefix-caching switch is on. 
+ paged_cache_groups = pool_to_paged_cache_groups(token_to_kv_pool) + prefix_cache_adjunct = None + required_groups = token_to_kv_pool.prefix_cache_required_group_ids + if required_groups is not None and server_args.enable_prefix_caching: + prefix_cache_adjunct = pool_to_prefix_cache_adjunct_spec(required_groups) scheduler_cfg = make_config( num_device_pages=self.max_total_num_tokens // server_args.block_size, max_scheduled_tokens=server_args.chunked_prefill_size, @@ -295,8 +302,9 @@ def __init__( enable_mamba=has_mamba, mamba_cache_chunk_size=server_args.mamba_cache_chunk_size, mamba_pool_total_chunks=mamba_pool_total_chunks, - paged_cache_groups=pool_to_paged_cache_groups(token_to_kv_pool), + paged_cache_groups=paged_cache_groups, enable_mixed_prefill_decode=enable_mixed_prefill_decode, + prefix_cache_adjunct=prefix_cache_adjunct, ) logger.info( "Scheduler config: page_size=%s num_device_pages=%s " diff --git a/python/tokenspeed/runtime/engine/scheduler_utils.py b/python/tokenspeed/runtime/engine/scheduler_utils.py index 653a1f191..d8687dbcc 100644 --- a/python/tokenspeed/runtime/engine/scheduler_utils.py +++ b/python/tokenspeed/runtime/engine/scheduler_utils.py @@ -30,7 +30,9 @@ ExecutionEvent, ForwardEvent, PagedCacheGroupConfig, + PagedCacheGroupFamily, PagedCacheRetention, + PrefixCacheAdjunctSpec, RequestSpec, SchedulerConfig, ) @@ -67,6 +69,7 @@ def make_config( mamba_pool_total_chunks: int = 0, paged_cache_groups: Sequence["PagedCacheGroupConfig"] | None = None, enable_mixed_prefill_decode: bool = False, + prefix_cache_adjunct: "PrefixCacheAdjunctSpec | None" = None, ) -> SchedulerConfig: cfg = SchedulerConfig() cfg.num_device_pages = num_device_pages @@ -96,12 +99,15 @@ def make_config( cfg.enable_mixed_prefill_decode = enable_mixed_prefill_decode if paged_cache_groups: cfg.paged_cache_groups = list(paged_cache_groups) + # Opt-in; unset means paged-cache groups are transport-only. 
+ if prefix_cache_adjunct is not None: + cfg.prefix_cache_adjunct = prefix_cache_adjunct return cfg def pool_to_paged_cache_groups(pool: Any) -> list: """Convert a KV pool's paged_cache_group_specs to scheduler configs.""" - specs = getattr(pool, "paged_cache_group_specs", ()) + specs = pool.paged_cache_group_specs if not specs: return [] counts = pool.paged_cache_group_page_counts @@ -116,12 +122,23 @@ def pool_to_paged_cache_groups(pool: Any) -> list: f"pool_to_paged_cache_groups: unsupported retention " f"{spec.retention!r} for group {spec.group_id!r}" ) + family_str = getattr(spec, "family", "history") + if family_str == "history": + family = PagedCacheGroupFamily.History + elif family_str == "state": + family = PagedCacheGroupFamily.State + else: + raise ValueError( + f"pool_to_paged_cache_groups: unsupported family " + f"{family_str!r} for group {spec.group_id!r}" + ) kwargs = dict( group_id=spec.group_id, rows_per_page=int(spec.rows_per_page), entry_stride_tokens=int(spec.entry_stride_tokens), total_pages=int(counts[spec.group_id]), retention=retention, + family=family, ) if spec.retention == "sliding_window": kwargs["sliding_window_tokens"] = int(spec.sliding_window_tokens) @@ -129,6 +146,19 @@ def pool_to_paged_cache_groups(pool: Any) -> list: return out +def pool_to_prefix_cache_adjunct_spec( + required_group_ids: Sequence[str], +) -> "PrefixCacheAdjunctSpec": + """Build a PrefixCacheAdjunctSpec from a non-empty required-group-id list.""" + if not required_group_ids: + raise ValueError( + "pool_to_prefix_cache_adjunct_spec: required_group_ids must be non-empty" + ) + spec = PrefixCacheAdjunctSpec() + spec.required_groups = [str(gid) for gid in required_group_ids] + return spec + + def make_extend_result_event(request_id: str, tokens: list[int] = ()) -> None: fe = ForwardEvent.ExtendResult() fe.request_id = request_id diff --git a/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py b/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py index 
da006abce..647664a09 100644 --- a/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py +++ b/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py @@ -224,9 +224,7 @@ def __init__( self.world_size = config.world_size # Backends alias their cache_seqlens buffer. Draft backend aliases # the drafter-owned draft_seq_lens to keep InputBuffers read-only. - paged_cache_group_specs = tuple( - getattr(token_to_kv_pool, "paged_cache_group_specs", ()) or () - ) + paged_cache_group_specs = tuple(token_to_kv_pool.paged_cache_group_specs) try: attn_backend.init_cuda_graph_state( self.max_bs, @@ -241,7 +239,7 @@ def __init__( ) if draft_attn_backend is not None: draft_paged_cache_group_specs = tuple( - getattr(draft_token_to_kv_pool, "paged_cache_group_specs", ()) or () + draft_token_to_kv_pool.paged_cache_group_specs ) try: draft_attn_backend.init_cuda_graph_state( @@ -416,7 +414,7 @@ def run_once(): return graph, out def _capture_paged_cache_block_tables(self, bs: int, pool) -> dict | None: - specs = tuple(getattr(pool, "paged_cache_group_specs", ()) or ()) + specs = tuple(pool.paged_cache_group_specs) if not specs: return None out = {} @@ -447,10 +445,9 @@ def _init_capture_metadata(self, bs: int): bs, self.token_to_kv_pool, ) - if paged_cache_block_tables is not None and getattr( - self.attn_backend, - "uses_paged_cache_groups", - False, + if ( + paged_cache_block_tables is not None + and self.attn_backend.uses_paged_cache_groups ): capture_kwargs["paged_cache_block_tables"] = paged_cache_block_tables self.attn_backend.init_forward_metadata_capture_cuda_graph( @@ -468,10 +465,9 @@ def _init_capture_metadata(self, bs: int): bs, self.draft_token_to_kv_pool, ) - if draft_paged_cache_block_tables is not None and getattr( - self.draft_attn_backend, - "uses_paged_cache_groups", - False, + if ( + draft_paged_cache_block_tables is not None + and self.draft_attn_backend.uses_paged_cache_groups ): draft_kwargs["paged_cache_block_tables"] = ( draft_paged_cache_block_tables @@ 
-581,7 +577,7 @@ def _init_replay_metadata( kwargs["paged_cache_block_table_base_offsets"] = ( paged_cache_block_table_base_offsets ) - if getattr(self.attn_backend, "uses_padded_decode_token_mask", False): + if self.attn_backend.uses_padded_decode_token_mask: kwargs["actual_bs"] = actual_bs self.attn_backend.init_forward_metadata_replay_cuda_graph( padded_bs, @@ -790,7 +786,7 @@ def __call__( if ( bs == 0 and paged_cache_block_tables is None - and getattr(self.attn_backend, "uses_paged_cache_groups", False) + and self.attn_backend.uses_paged_cache_groups ): paged_cache_block_tables = self._capture_paged_cache_block_tables( padded_bs, @@ -855,12 +851,12 @@ def __call__( spec_info=spec_info, paged_cache_block_tables=( paged_cache_block_tables - if getattr(self.attn_backend, "uses_paged_cache_groups", False) + if self.attn_backend.uses_paged_cache_groups else None ), paged_cache_block_table_base_offsets=( paged_cache_block_table_base_offsets - if getattr(self.attn_backend, "uses_paged_cache_groups", False) + if self.attn_backend.uses_paged_cache_groups else None ), **mamba_kwargs, diff --git a/python/tokenspeed/runtime/layers/attention/backends/base.py b/python/tokenspeed/runtime/layers/attention/backends/base.py index 3372d96c1..d73c666e6 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/base.py +++ b/python/tokenspeed/runtime/layers/attention/backends/base.py @@ -36,6 +36,9 @@ class AttentionBackend(ABC): """The base class of attention backends""" + uses_paged_cache_groups: bool = False + uses_padded_decode_token_mask: bool = False + def __init__(self, config: BaseAttnConfig) -> None: self.device = config.device self.num_qo_heads = config.num_attention_heads // config.attn_tp_size diff --git a/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py b/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py index 5383297e1..2ddd0cd1b 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py +++ 
b/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py @@ -1419,8 +1419,8 @@ def _refresh_cuda_graph_base_offsets( ) -> dict[str, torch.Tensor]: """Refresh persistent base-offset buffers from per-step input. - Sliding groups whose key is missing fall back to 0 (legacy - absolute scheduler binding). Returns the [:bs] views keyed by gid. + Sliding groups whose key is missing fall back to 0. Returns the [:bs] + views keyed by gid. """ out: dict[str, torch.Tensor] = {} for gid, buf in self._cuda_graph_paged_cache_base_offsets.items(): diff --git a/python/tokenspeed/runtime/layers/attention/kv_cache/base.py b/python/tokenspeed/runtime/layers/attention/kv_cache/base.py index effabc616..21bb08e78 100644 --- a/python/tokenspeed/runtime/layers/attention/kv_cache/base.py +++ b/python/tokenspeed/runtime/layers/attention/kv_cache/base.py @@ -24,6 +24,7 @@ import torch +from tokenspeed.runtime.configs.paged_cache_spec import PagedCacheGroupSpec from tokenspeed.runtime.layers.paged_attention import PagedAttention from tokenspeed.runtime.utils import get_colorful_logger @@ -36,6 +37,9 @@ class BaseTokenToKVPool: """A memory pool that maps a token location to its kv cache data.""" + paged_cache_group_specs: tuple[PagedCacheGroupSpec, ...] = () + paged_cache_group_page_counts: dict[str, int] = {} + def __init__( self, size: int, @@ -106,6 +110,11 @@ def load_cpu_copy( ) -> None: raise NotImplementedError() + @property + def prefix_cache_required_group_ids(self) -> tuple[str, ...] | None: + """None means adjunct disabled; subclasses return required group ids.""" + return None + # Buffer metadata used by prefill/decode disaggregation. 
def get_contiguous_buf_infos(self): raise NotImplementedError() diff --git a/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py b/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py index d180694c9..7979fa752 100644 --- a/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py +++ b/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py @@ -279,8 +279,7 @@ def _split_paged_cache_block_tables_into_v4_metadata( Returns (swa, {ratio: compressor_state}, indexer_state, swa_base, {ratio: compressor_state_base}, indexer_state_base). Unknown group ids - are ignored. Base offsets are None / missing when the input lacks them - (legacy scheduler binding). + are ignored. Base offsets are None / missing when the input lacks them. """ offsets = paged_cache_block_table_base_offsets or {} swa = paged_cache_block_tables.get(V4_SWA_KV_GROUP_ID) @@ -593,11 +592,12 @@ def deepseek_v4_cache_layout_from_config( class DeepseekV4TokenToKVPool(BaseTokenToKVPool): """DeepSeek V4 fp8_ds_mla cache pool. - TokenSpeed keeps the SWA, compressed, compressor-state, and CSA indexer - caches in one V4-only pool so ordinary MLA models keep their existing cache - contract untouched. Compressed caches currently reuse the request page table; - this is correctness-first and leaves ratio-specific allocation for the - optimized follow-up. + TokenSpeed keeps SWA, compressed, compressor-state, and CSA indexer caches + in dedicated per-group paged pools (see PagedCacheGroup* on the scheduler + side and ``build_v4_cache_specs`` here), keeping ordinary MLA models on + their existing single-pool contract. The ``indexer_kv_buffer`` shares its + page table and page-count budget with the ``v4.c{ratio}a.compressed_kv`` + group rather than owning a separate group of its own. 
""" def __init__( @@ -790,6 +790,11 @@ def _group_pages(group_id: str, default: int) -> int: self.compressed_block_sizes, ) + @property + def prefix_cache_required_group_ids(self) -> tuple[str, ...]: + """All V4 paged-cache groups must be present for a snapshot to be complete.""" + return tuple(str(spec.group_id) for spec in self.paged_cache_group_specs) + def _require( self, buffers: list[torch.Tensor | None], layer_id: int, name: str ) -> torch.Tensor: diff --git a/test/runtime/test_v4_prefix_cache_metadata.py b/test/runtime/test_v4_prefix_cache_metadata.py new file mode 100644 index 000000000..106e4d407 --- /dev/null +++ b/test/runtime/test_v4_prefix_cache_metadata.py @@ -0,0 +1,162 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + +"""V4 prefix-cache metadata tests for per-group block-table / +base-offset wiring exposed by the scheduler when two same-prefix requests +land on the V4 paged-cache adjunct. + +The scheduler is stood up via the nanobind bindings and the assertions read +the V4 metadata observable by Python runtime / backend code. No GPU forwards. 
+""" + +from __future__ import annotations + +import unittest + +import pytest + +try: + from tokenspeed_scheduler import ( # type: ignore + ExecutionEvent, + ForwardEvent, + PagedCacheGroupConfig, + PagedCacheRetention, + PrefixCacheAdjunctSpec, + RequestSpec, + Scheduler, + SchedulerConfig, + ) + + _BINDING_AVAILABLE = True +except ImportError: # pragma: no cover - bindings unbuilt + _BINDING_AVAILABLE = False + + +pytestmark = pytest.mark.skipif( + not _BINDING_AVAILABLE, + reason="tokenspeed_scheduler nanobind bindings unavailable in this env", +) + + +def _make_two_group_config() -> "SchedulerConfig": + cfg = SchedulerConfig() + cfg.page_size = 2 + cfg.num_device_pages = 64 + cfg.num_host_pages = 64 + cfg.max_scheduled_tokens = 16 + cfg.max_batch_size = 8 + cfg.paged_cache_groups = [ + PagedCacheGroupConfig( + group_id="fh", + rows_per_page=4, + entry_stride_tokens=1, + total_pages=32, + retention=PagedCacheRetention.FullHistory, + ), + PagedCacheGroupConfig( + group_id="swa", + rows_per_page=2, + entry_stride_tokens=1, + total_pages=32, + retention=PagedCacheRetention.SlidingWindow, + sliding_window_tokens=8, + ), + ] + # Opt into the prefix-cache adjunct so the scheduler actually builds the + # snapshot chain; without this the attach loop is a no-op and the + # borrowed-prefix path under test never runs. The C++ side derives + # lcm_raw_tokens and sliding_window_per_group from each required group's + # PagedCacheGroupConfig; the Python ABI only declares the group ids. 
+ adjunct = PrefixCacheAdjunctSpec() + adjunct.required_groups = ["fh", "swa"] + cfg.prefix_cache_adjunct = adjunct + return cfg + + +def _post(sched: "Scheduler", payload) -> None: + ev = ExecutionEvent() + ev.add_event(payload) + sched.advance(ev) + + +def _prime_r1(sched: "Scheduler") -> tuple[list[int], list[int]]: + """Prime r1 and capture page ids before finish releases its tables.""" + spec = RequestSpec() + spec.request_id = "r1" + spec.tokens = list(range(1, 13)) + sched.submit_requests([spec]) + sched.next_execution_plan() + er = ForwardEvent.ExtendResult() + er.request_id = "r1" + er.tokens = [99] + _post(sched, er) + sched.next_execution_plan() + r1_fh = list(sched.get_request_paged_cache_page_ids("r1", "fh")) + r1_swa = list(sched.get_request_paged_cache_page_ids("r1", "swa")) + fin = ForwardEvent.Finish() + fin.request_id = "r1" + _post(sched, fin) + sched.next_execution_plan() + return r1_fh, r1_swa + + +def _submit_r2_same_prefix(sched: "Scheduler") -> None: + spec = RequestSpec() + spec.request_id = "r2" + spec.tokens = list(range(1, 13)) + sched.submit_requests([spec]) + sched.next_execution_plan() + + +class TestV4PrefixCacheMetadata(unittest.TestCase): + + def setUp(self) -> None: + self.sched = Scheduler(_make_two_group_config()) + self.r1_fh, self.r1_swa = _prime_r1(self.sched) + self.assertNotEqual( + self.r1_fh, + [], + "r1 fh page ids must be captured before finish releases the request table", + ) + self.assertNotEqual( + self.r1_swa, + [], + "r1 swa page ids must be captured before finish releases the request table", + ) + _submit_r2_same_prefix(self.sched) + + def test_block_table_borrowed_plus_suffix(self) -> None: + """R2's table starts with r1's borrowed prefix page ids.""" + r2_fh = list(self.sched.get_request_paged_cache_page_ids("r2", "fh")) + self.assertGreaterEqual(len(r2_fh), 1) + n = min(len(self.r1_fh), len(r2_fh)) + self.assertGreater( + n, + 0, + "r2 must borrow at least one fh page from r1's prefix snapshot", + ) + 
self.assertEqual(r2_fh[:n], self.r1_fh[:n]) + + def test_base_offsets_sliding_correct(self) -> None: + """For sliding-window groups the per-request base_logical_page + matches the snapshot's base offset; full-history is always 0.""" + swa_base = self.sched.get_request_paged_cache_base_logical_page("r2", "swa") + self.assertGreaterEqual(swa_base, 0) + swa_ids = list(self.sched.get_request_paged_cache_page_ids("r2", "swa")) + self.assertLessEqual(swa_base, len(self.r1_swa) + len(swa_ids)) + fh_base = self.sched.get_request_paged_cache_base_logical_page("r2", "fh") + self.assertEqual(fh_base, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tokenspeed-scheduler/CMakeLists.txt b/tokenspeed-scheduler/CMakeLists.txt index 64b4e2ebe..8675e63a7 100644 --- a/tokenspeed-scheduler/CMakeLists.txt +++ b/tokenspeed-scheduler/CMakeLists.txt @@ -110,6 +110,11 @@ if(TOKENSPEED_SCHEDULER_BUILD_TESTS) tests/cpp/test_batch_scheduling.cpp tests/cpp/test_prefetch.cpp tests/cpp/test_owned_pages.cpp + tests/cpp/test_paged_cache_prefix_match.cpp + tests/cpp/test_paged_cache_attach_loop.cpp + tests/cpp/test_paged_cache_eviction.cpp + tests/cpp/test_paged_cache_family_split.cpp + tests/cpp/test_paged_cache_prefix_hit_commit.cpp tests/cpp/test_retract_abort_pages.cpp tests/cpp/test_mamba_slot.cpp tests/cpp/test_mamba_eviction.cpp diff --git a/tokenspeed-scheduler/bindings/python_module.cpp b/tokenspeed-scheduler/bindings/python_module.cpp index 12f5e7dc6..a10ac2b0b 100644 --- a/tokenspeed-scheduler/bindings/python_module.cpp +++ b/tokenspeed-scheduler/bindings/python_module.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -150,6 +151,10 @@ NB_MODULE(tokenspeed_scheduler_ext, m) { .value("FullHistory", tokenspeed::PagedCacheGroupConfig::Retention::FullHistory) .value("SlidingWindow", tokenspeed::PagedCacheGroupConfig::Retention::SlidingWindow); + nb::enum_(m, "PagedCacheGroupFamily") + .value("History", tokenspeed::PagedCacheGroupFamily::History) 
+ .value("State", tokenspeed::PagedCacheGroupFamily::State); + nb::class_(m, "PagedCacheGroupConfig") .def(nb::init<>()) .def( @@ -157,19 +162,22 @@ NB_MODULE(tokenspeed_scheduler_ext, m) { [](tokenspeed::PagedCacheGroupConfig* self, std::string group_id, std::int32_t rows_per_page, std::int32_t entry_stride_tokens, std::int32_t total_pages, tokenspeed::PagedCacheGroupConfig::Retention retention, - std::optional sliding_window_tokens) { - new (self) tokenspeed::PagedCacheGroupConfig{std::move(group_id), rows_per_page, entry_stride_tokens, - total_pages, retention, sliding_window_tokens}; + std::optional sliding_window_tokens, tokenspeed::PagedCacheGroupFamily family) { + new (self) tokenspeed::PagedCacheGroupConfig{ + std::move(group_id), rows_per_page, entry_stride_tokens, total_pages, retention, + sliding_window_tokens, family}; }, nb::arg("group_id"), nb::arg("rows_per_page"), nb::arg("entry_stride_tokens"), nb::arg("total_pages"), nb::arg("retention") = tokenspeed::PagedCacheGroupConfig::Retention::FullHistory, - nb::arg("sliding_window_tokens") = std::nullopt) + nb::arg("sliding_window_tokens") = std::nullopt, + nb::arg("family") = tokenspeed::PagedCacheGroupFamily::History) .def_rw("group_id", &tokenspeed::PagedCacheGroupConfig::group_id) .def_rw("rows_per_page", &tokenspeed::PagedCacheGroupConfig::rows_per_page) .def_rw("entry_stride_tokens", &tokenspeed::PagedCacheGroupConfig::entry_stride_tokens) .def_rw("total_pages", &tokenspeed::PagedCacheGroupConfig::total_pages) .def_rw("retention", &tokenspeed::PagedCacheGroupConfig::retention) .def_rw("sliding_window_tokens", &tokenspeed::PagedCacheGroupConfig::sliding_window_tokens) + .def_rw("family", &tokenspeed::PagedCacheGroupConfig::family) .def("raw_tokens_per_page", &tokenspeed::PagedCacheGroupConfig::RawTokensPerPage) .def("validate", &tokenspeed::PagedCacheGroupConfig::Validate); @@ -192,6 +200,8 @@ NB_MODULE(tokenspeed_scheduler_ext, m) { .def("page_ids", &tokenspeed::PagedCacheGroupTable::PageIds, 
nb::rv_policy::reference_internal) .def("size", &tokenspeed::PagedCacheGroupTable::Size) .def("active_pages_count", &tokenspeed::PagedCacheGroupTable::ActivePagesCount) + .def("owned_pages_count", &tokenspeed::PagedCacheGroupTable::OwnedPagesCount) + .def("borrowed_pages_count", &tokenspeed::PagedCacheGroupTable::BorrowedPagesCount) .def("released_pages_count", &tokenspeed::PagedCacheGroupTable::ReleasedPagesCount) .def("base_logical_page", &tokenspeed::PagedCacheGroupTable::BaseLogicalPage) .def("raw_token_cursor", &tokenspeed::PagedCacheGroupTable::RawTokenCursor) @@ -201,6 +211,12 @@ NB_MODULE(tokenspeed_scheduler_ext, m) { .def("is_sliding", &tokenspeed::PagedCacheGroupTable::IsSliding) .def("sliding_window_tokens", &tokenspeed::PagedCacheGroupTable::SlidingWindowTokens); + // Python declares the required group ids only. Scheduler derives LCM and + // sliding-window metadata from the matching PagedCacheGroupConfig entries. + nb::class_(m, "PrefixCacheAdjunctSpec") + .def(nb::init<>()) + .def_rw("required_groups", &tokenspeed::PrefixCacheAdjunctSpec::required_groups); + scheduler_config.def(nb::init<>()) .def_rw("page_size", &tokenspeed::SchedulerConfig::page_size) .def_rw("max_scheduled_tokens", &tokenspeed::SchedulerConfig::max_scheduled_tokens) @@ -214,6 +230,7 @@ NB_MODULE(tokenspeed_scheduler_ext, m) { "num_host_pages", [](const tokenspeed::SchedulerConfig& c) { return c.host_allocator.total_pages; }, [](tokenspeed::SchedulerConfig& c, std::int32_t v) { c.host_allocator.total_pages = v; }) .def_rw("paged_cache_groups", &tokenspeed::SchedulerConfig::paged_cache_groups) + .def_rw("prefix_cache_adjunct", &tokenspeed::SchedulerConfig::prefix_cache_adjunct) .def_rw("disable_l2_cache", &tokenspeed::SchedulerConfig::disable_l2_cache) .def_rw("enable_l3_storage", &tokenspeed::SchedulerConfig::enable_l3_storage) .def_rw("prefetch_threshold", &tokenspeed::SchedulerConfig::prefetch_threshold) diff --git 
a/tokenspeed-scheduler/csrc/resource/allocator/paged_cache_group.cpp b/tokenspeed-scheduler/csrc/resource/allocator/paged_cache_group.cpp index d32d69da3..ccda15125 100644 --- a/tokenspeed-scheduler/csrc/resource/allocator/paged_cache_group.cpp +++ b/tokenspeed-scheduler/csrc/resource/allocator/paged_cache_group.cpp @@ -24,16 +24,9 @@ #include #include -namespace tokenspeed { - -namespace { - -std::int32_t CeilDivPositive(std::int32_t numer, std::int32_t denom) { - if (numer <= 0) return 0; - return (numer + denom - 1) / denom; -} +#include "resource/types.h" -} // namespace +namespace tokenspeed { void PagedCacheGroupConfig::Validate() const { if (group_id.empty()) { @@ -84,6 +77,14 @@ void PagedCacheGroupAllocator::Deallocate(const std::vector& pages released_pages_total_ += static_cast(pages.size()); } +void PagedCacheGroupTable::RefreshPageIdsView() { + page_ids_view_.clear(); + page_ids_view_.reserve(borrowed_page_ids_.size() + static_cast(owned_pages_.Size())); + page_ids_view_.insert(page_ids_view_.end(), borrowed_page_ids_.begin(), borrowed_page_ids_.end()); + const auto& owned_ids = owned_pages_.Ids(); + page_ids_view_.insert(page_ids_view_.end(), owned_ids.begin(), owned_ids.end()); +} + void PagedCacheGroupTable::Acquire(std::int32_t target_raw_tokens_exclusive) { if (allocator_ == nullptr) { throw std::logic_error("PagedCacheGroupTable::Acquire: no allocator bound"); @@ -98,24 +99,175 @@ void PagedCacheGroupTable::Acquire(std::int32_t target_raw_tokens_exclusive) { const auto& cfg = allocator_->Config(); const std::int32_t entries = CeilDivPositive(target_raw_tokens_exclusive, cfg.entry_stride_tokens); const std::int32_t pages_needed = (entries + cfg.rows_per_page - 1) / cfg.rows_per_page; - // Absolute pages already covered = base + live size after any - // ReleaseSkipped compaction. Allocate only the delta. - const std::int32_t pages_have = base_logical_page_ + Size(); + // Absolute pages have = base + borrowed + owned; allocate only the delta. 
+ const std::int32_t pages_have = + base_logical_page_ + static_cast(borrowed_page_ids_.size()) + owned_pages_.Size(); const std::int32_t pages_to_allocate = pages_needed - pages_have; if (pages_to_allocate > 0) { OwnedPages fresh = allocator_->AcquireOwned(pages_to_allocate); if (fresh.Size() < pages_to_allocate) { - // fresh dtor returns any partial allocation to pool_. throw std::runtime_error("PagedCacheGroupTable::Acquire: failed to allocate pages for group " + cfg.group_id); } - pages_.Append(std::move(fresh)); + owned_pages_.Append(std::move(fresh)); + RefreshPageIdsView(); } raw_token_cursor_ = target_raw_tokens_exclusive; } +PagedCacheGroupTable::CommitResult PagedCacheGroupTable::CommitHistoryToSnapshot(std::int32_t target_raw_tokens) { + if (allocator_ == nullptr) { + throw std::logic_error("PagedCacheGroupTable::CommitHistoryToSnapshot: no allocator bound"); + } + const auto& cfg = allocator_->Config(); + if (cfg.family != PagedCacheGroupFamily::History) { + throw std::logic_error("PagedCacheGroupTable::CommitHistoryToSnapshot: requires History family; group=" + + cfg.group_id); + } + if (target_raw_tokens <= committed_prefix_len_tokens_) { + return {}; + } + if (target_raw_tokens > raw_token_cursor_) { + throw std::invalid_argument( + "PagedCacheGroupTable::CommitHistoryToSnapshot: target exceeds raw_token_cursor; target=" + + std::to_string(target_raw_tokens) + "; cursor=" + std::to_string(raw_token_cursor_)); + } + const std::int32_t raw_per_page = cfg.RawTokensPerPage(); + if (raw_per_page <= 0) { + throw std::logic_error( + "PagedCacheGroupTable::CommitHistoryToSnapshot: invalid group config (raw_per_page <= 0)"); + } + if (committed_prefix_len_tokens_ % raw_per_page != 0) { + throw std::logic_error( + "PagedCacheGroupTable::CommitHistoryToSnapshot: committed cursor not page-aligned; committed=" + + std::to_string(committed_prefix_len_tokens_) + "; raw_per_page=" + std::to_string(raw_per_page)); + } + if (target_raw_tokens % raw_per_page != 0) { + 
throw std::invalid_argument("PagedCacheGroupTable::CommitHistoryToSnapshot: target not page-aligned; target=" + + std::to_string(target_raw_tokens) + + "; raw_per_page=" + std::to_string(raw_per_page)); + } + + const std::int32_t pages_to_commit = (target_raw_tokens - committed_prefix_len_tokens_) / raw_per_page; + if (pages_to_commit <= 0) { + committed_prefix_len_tokens_ = target_raw_tokens; + RefreshPageIdsView(); + return {}; + } + if (pages_to_commit > owned_pages_.Size()) { + throw std::logic_error("PagedCacheGroupTable::CommitHistoryToSnapshot: not enough owned pages; want=" + + std::to_string(pages_to_commit) + "; have_owned=" + std::to_string(owned_pages_.Size())); + } + + const std::int32_t segment_base_logical_page = committed_prefix_len_tokens_ / raw_per_page; + OwnedPages segment = owned_pages_.TakeFirst(pages_to_commit); + const auto& seg_ids = segment.Ids(); + borrowed_page_ids_.insert(borrowed_page_ids_.end(), seg_ids.begin(), seg_ids.end()); + committed_prefix_len_tokens_ = target_raw_tokens; + RefreshPageIdsView(); + return CommitResult{std::move(segment), segment_base_logical_page}; +} + +PagedCacheGroupTable::CommitResult PagedCacheGroupTable::CheckpointStateToSnapshot(std::int32_t target_raw_tokens) { + if (allocator_ == nullptr) { + throw std::logic_error("PagedCacheGroupTable::CheckpointStateToSnapshot: no allocator bound"); + } + const auto& cfg = allocator_->Config(); + if (cfg.family != PagedCacheGroupFamily::State) { + throw std::logic_error("PagedCacheGroupTable::CheckpointStateToSnapshot: requires State family; group=" + + cfg.group_id); + } + if (!cfg.sliding_window_tokens.has_value() || *cfg.sliding_window_tokens <= 0) { + throw std::logic_error( + "PagedCacheGroupTable::CheckpointStateToSnapshot: State family requires positive" + " sliding_window_tokens; group=" + + cfg.group_id); + } + if (target_raw_tokens <= committed_prefix_len_tokens_) { + return {}; + } + if (target_raw_tokens > raw_token_cursor_) { + throw 
std::invalid_argument( + "PagedCacheGroupTable::CheckpointStateToSnapshot: target exceeds raw_token_cursor; target=" + + std::to_string(target_raw_tokens) + "; cursor=" + std::to_string(raw_token_cursor_)); + } + const std::int32_t raw_per_page = cfg.RawTokensPerPage(); + if (raw_per_page <= 0) { + throw std::logic_error( + "PagedCacheGroupTable::CheckpointStateToSnapshot: invalid group config (raw_per_page <= 0)"); + } + if (target_raw_tokens % raw_per_page != 0) { + throw std::invalid_argument( + "PagedCacheGroupTable::CheckpointStateToSnapshot: target not page-aligned; target=" + + std::to_string(target_raw_tokens) + "; raw_per_page=" + std::to_string(raw_per_page)); + } + + const std::int32_t window = *cfg.sliding_window_tokens; + const std::int32_t live_lower_raw = std::max(0, target_raw_tokens - window); + const std::int32_t live_lower_page = live_lower_raw / raw_per_page; + + // Drop stale borrowed entries: their physical pages live on earlier snapshots, + // so the table just discards the index. Mirrors ReleaseSkipped semantics. + if (live_lower_page > base_logical_page_) { + const std::int32_t borrowed_to_drop = std::min( + live_lower_page - base_logical_page_, static_cast(borrowed_page_ids_.size())); + if (borrowed_to_drop > 0) { + borrowed_page_ids_.erase(borrowed_page_ids_.begin(), borrowed_page_ids_.begin() + borrowed_to_drop); + base_logical_page_ += borrowed_to_drop; + } + } + + // Drop stale owned-prefix pages (below live_lower) back to the pool via RAII. + if (live_lower_page > base_logical_page_) { + const std::int32_t owned_dead_pages = + std::min(live_lower_page - base_logical_page_, owned_pages_.Size()); + if (owned_dead_pages > 0) { + OwnedPages dropped = owned_pages_.TakeFirst(owned_dead_pages); + base_logical_page_ += owned_dead_pages; + // dropped dtor returns pages to pool. + } + } + + // Snapshot stores ONLY this commit step's owned delta (the new LCM segment's + // pages). 
The trailing window is reconstructed across the chain at match + // time by augmentMatchPagedCache::assemble; borrowed pages still belong to + // upstream snapshots and must not be re-stored here. base is the absolute + // logical page where the delta starts: post-Step-1/2 base + remaining borrowed. + OwnedPages segment = owned_pages_.TakeFirst(owned_pages_.Size()); + const std::int32_t segment_base_logical_page = + base_logical_page_ + static_cast(borrowed_page_ids_.size()); + const auto& seg_ids = segment.Ids(); + borrowed_page_ids_.insert(borrowed_page_ids_.end(), seg_ids.begin(), seg_ids.end()); + + committed_prefix_len_tokens_ = target_raw_tokens; + RefreshPageIdsView(); + return CommitResult{std::move(segment), segment_base_logical_page}; +} + +void PagedCacheGroupTable::ImportPrefixBorrowed(std::vector ids, std::int32_t base_logical_page, + std::int32_t raw_tokens_covered) { + if (allocator_ == nullptr) { + throw std::logic_error("PagedCacheGroupTable::ImportPrefixBorrowed: no allocator bound"); + } + if (!(borrowed_page_ids_.empty() && owned_pages_.Empty() && raw_token_cursor_ == 0 && base_logical_page_ == 0 && + committed_prefix_len_tokens_ == 0)) { + throw std::logic_error("PagedCacheGroupTable::ImportPrefixBorrowed: only legal on a fresh-empty table"); + } + if (base_logical_page < 0) { + throw std::invalid_argument("PagedCacheGroupTable::ImportPrefixBorrowed: base_logical_page must be >= 0"); + } + if (raw_tokens_covered < 0) { + throw std::invalid_argument("PagedCacheGroupTable::ImportPrefixBorrowed: raw_tokens_covered must be >= 0"); + } + borrowed_page_ids_ = std::move(ids); + base_logical_page_ = base_logical_page; + raw_token_cursor_ = raw_tokens_covered; + committed_prefix_len_tokens_ = raw_tokens_covered; + RefreshPageIdsView(); +} + std::vector PagedCacheGroupTable::ReleaseSkipped(std::int32_t window_lower_bound) { - if (allocator_ == nullptr || pages_.Empty() || window_lower_bound <= 0) { + if (allocator_ == nullptr || Size() == 0 || 
window_lower_bound <= 0) { return {}; } const auto& cfg = allocator_->Config(); @@ -126,8 +278,6 @@ std::vector PagedCacheGroupTable::ReleaseSkipped(std::int32_t wind if (raw_per_page <= 0) { return {}; } - // Absolute logical-page index (exclusive) below which entries fall out of - // the active window. const std::int32_t target = window_lower_bound / raw_per_page; if (target <= base_logical_page_) { return {}; @@ -136,18 +286,45 @@ std::vector PagedCacheGroupTable::ReleaseSkipped(std::int32_t wind if (to_drop <= 0) { return {}; } - OwnedPages dropped = pages_.TakeFirst(to_drop); - std::vector released = dropped.Ids(); + std::vector released; + released.reserve(static_cast(to_drop)); + + // Drop from FRONT: borrowed first, then owned. + const std::int32_t borrowed_drop = std::min(to_drop, static_cast(borrowed_page_ids_.size())); + if (borrowed_drop > 0) { + released.insert(released.end(), borrowed_page_ids_.begin(), borrowed_page_ids_.begin() + borrowed_drop); + borrowed_page_ids_.erase(borrowed_page_ids_.begin(), borrowed_page_ids_.begin() + borrowed_drop); + // Borrowed pages stay owned by their TreeNode snapshot; only shrink index. + } + const std::int32_t owned_drop = to_drop - borrowed_drop; + if (owned_drop > 0) { + OwnedPages dropped = owned_pages_.TakeFirst(owned_drop); + const auto& dropped_ids = dropped.Ids(); + released.insert(released.end(), dropped_ids.begin(), dropped_ids.end()); + // dropped dtor returns pages to pool. + } base_logical_page_ += to_drop; - // dropped goes out of scope: OwnedPages dtor returns the pages to pool_. + RefreshPageIdsView(); + // committed_prefix_len_tokens_ intentionally untouched (logical vs physical). 
return released; } std::vector PagedCacheGroupTable::ReleaseAll() { - OwnedPages dropped = pages_.TakeFirst(pages_.Size()); - std::vector released = dropped.Ids(); + std::vector released; + released.reserve(borrowed_page_ids_.size() + static_cast(owned_pages_.Size())); + released.insert(released.end(), borrowed_page_ids_.begin(), borrowed_page_ids_.end()); + borrowed_page_ids_.clear(); + // Borrowed pages stay owned by their snapshots. + + OwnedPages dropped = owned_pages_.TakeFirst(owned_pages_.Size()); + const auto& dropped_ids = dropped.Ids(); + released.insert(released.end(), dropped_ids.begin(), dropped_ids.end()); + // dropped dtor returns pages to pool. + raw_token_cursor_ = 0; base_logical_page_ = 0; + committed_prefix_len_tokens_ = 0; + RefreshPageIdsView(); return released; } diff --git a/tokenspeed-scheduler/csrc/resource/allocator/paged_cache_group.h b/tokenspeed-scheduler/csrc/resource/allocator/paged_cache_group.h index efc5a44ee..2dc562ce3 100644 --- a/tokenspeed-scheduler/csrc/resource/allocator/paged_cache_group.h +++ b/tokenspeed-scheduler/csrc/resource/allocator/paged_cache_group.h @@ -30,8 +30,23 @@ namespace tokenspeed { -// One model-defined paged cache group. The scheduler treats group_id as opaque: -// V4 uses ids like "v4.c4a.compressed_kv" and "v4.swa_kv". +// Positive-only ceiling division; returns 0 for non-positive numerators. +// Lives here because paged-cache admission/table math is its only caller. +inline std::int32_t CeilDivPositive(std::int32_t numer, std::int32_t denom) { + if (numer <= 0) return 0; + return (numer + denom - 1) / denom; +} + +// Paged-cache group families for V4 prefix-cache reuse. +// History: every page on [0, P) required (chain). +// State: only the trailing window at the hit depth required. +enum class PagedCacheGroupFamily { History, State }; + +// Phase 1 only has kSnapshotRequired; kReplayTail / kCheckpointThenReplay +// reserved for Phase 2 RecoveryPlan work. 
+enum class StateRestorePolicy { kSnapshotRequired }; + +// One model-defined paged cache group; scheduler treats group_id as opaque. struct PagedCacheGroupConfig { enum class Retention { FullHistory, @@ -44,16 +59,16 @@ struct PagedCacheGroupConfig { std::int32_t total_pages{}; Retention retention{Retention::FullHistory}; std::optional sliding_window_tokens{}; + // History groups form a chain; State groups only need the trailing window. + PagedCacheGroupFamily family{PagedCacheGroupFamily::History}; std::int32_t RawTokensPerPage() const { return rows_per_page * entry_stride_tokens; } void Validate() const; }; -// Group-level allocator. Composes a PageAllocator that owns the free list, and -// adds group config + cumulative counters. PagedCacheGroupTable acquires pages -// through this allocator and stores them as OwnedPages bound to the inner -// pool, so all release paths go directly to the pool via RAII. +// Group-level allocator: wraps PageAllocator + config + counters. Releases run +// via OwnedPages RAII directly to the pool. class PagedCacheGroupAllocator { public: explicit PagedCacheGroupAllocator(PagedCacheGroupConfig config); @@ -71,17 +86,14 @@ class PagedCacheGroupAllocator { std::int32_t AvailablePages() const { return pool_.AvailablePages(); } std::int64_t AllocatedPagesTotal() const { return allocated_pages_total_; } - // Counts pages explicitly returned via Deallocate(). Pages released through - // PagedCacheGroupTable RAII (destructor / ReleaseSkipped / ReleaseAll) go - // directly to the inner pool and are not counted here. + // Only counts explicit Deallocate(); RAII releases bypass this counter. std::int64_t ReleasedPagesTotal() const { return released_pages_total_; } std::int64_t FailedAllocCount() const { return failed_alloc_count_; } private: friend class PagedCacheGroupTable; - // Allocate a fresh batch as OwnedPages bound to pool_. Bumps stats; returns - // an empty OwnedPages on insufficient capacity (and bumps failed counter). 
+ // Empty OwnedPages on insufficient capacity (bumps failed_alloc_count_). OwnedPages AcquireOwned(std::int32_t num_pages); PagedCacheGroupConfig config_; @@ -91,12 +103,13 @@ class PagedCacheGroupAllocator { std::int64_t failed_alloc_count_{0}; }; -// One per request, per group. Stores live pages as OwnedPages so destruction -// (and any TakeFirst-based release) automatically returns them to the pool. -// Acquire grows the cursor; ReleaseSkipped peels expired pages off the front -// (sliding-window groups only) and bumps the base logical page index so -// PageIds() always exposes only live entries; absolute logical page for -// column c is BaseLogicalPage() + c. +// One per request, per group. Two storage segments (no refcounts): +// - `borrowed_page_ids_`: page ids only; physical ownership lives in a +// TreeNode's PagedCacheSnapshot, pinned via the request's DeviceNodeRef. +// - `owned_pages_`: RAII back to the allocator on release or moved to a +// snapshot via CommitHistoryToSnapshot / CheckpointStateToSnapshot. +// PageIds() = borrowed ++ owned, where column c == absolute logical page +// BaseLogicalPage() + c. ReleaseSkipped peels expired front pages (sliding only). class PagedCacheGroupTable { public: PagedCacheGroupTable() = default; @@ -108,28 +121,52 @@ class PagedCacheGroupTable { PagedCacheGroupTable(PagedCacheGroupTable&&) noexcept = default; PagedCacheGroupTable& operator=(PagedCacheGroupTable&&) noexcept = default; + // Grow pages to cover [base*RawTokensPerPage, target_raw_tokens_exclusive). void Acquire(std::int32_t target_raw_tokens_exclusive); - // Returns physical ids of pages whose covered raw range is strictly below - // `window_lower_bound`. Compacts the in-memory table and bumps the base - // logical page so PageIds() always contains only live entries. Idempotent. - // No-op (returns empty) for full-history groups. 
+ // segment_base_logical_page is captured BEFORE the commit cursor advances + // (sliding ReleaseSkipped may have already moved BaseLogicalPage() forward). + struct CommitResult { + OwnedPages pages; + std::int32_t segment_base_logical_page{0}; + }; + + // History append: move owned [committed, target) out and mirror ids to + // borrowed_page_ids_. Throws for non-History family groups. + CommitResult CommitHistoryToSnapshot(std::int32_t target_raw_tokens); + + // State checkpoint: snapshot the live trailing window [max(0,target-W), + // target); drop stale prefix from both owned (back to pool) and borrowed + // (index drop only; physical pages live on earlier snapshots). Throws for + // non-State family groups or when sliding_window_tokens is missing/non-positive. + CommitResult CheckpointStateToSnapshot(std::int32_t target_raw_tokens); + + // Adopt borrowed page ids from a prefix-cache hit on a fresh-empty table. + // Throws std::logic_error if called after Acquire/Import/Commit. + void ImportPrefixBorrowed(std::vector ids, std::int32_t base_logical_page, + std::int32_t raw_tokens_covered); + + // Sliding-only: drop front pages strictly below `window_lower_bound`. + // Bumps base_logical_page_; commit cursor untouched. Idempotent. std::vector ReleaseSkipped(std::int32_t window_lower_bound); - // Returns all live physical ids and clears the table. Used by - // finish/abort/retract. + // Release everything; owned via RAII, borrowed by clearing. Used by finish/abort/retract. std::vector ReleaseAll(); - // Compact view: column c here represents absolute logical page - // BaseLogicalPage() + c. Released-from-front pages are NOT present. - const std::vector& PageIds() const { return pages_.Ids(); } - std::int32_t Size() const { return pages_.Size(); } + // Compact: PageIds()[c] = absolute logical page BaseLogicalPage() + c. 
+ const std::vector& PageIds() const { return page_ids_view_; } + std::int32_t Size() const { return static_cast(borrowed_page_ids_.size()) + owned_pages_.Size(); } std::int32_t ActivePagesCount() const { return Size(); } + std::int32_t OwnedPagesCount() const { return owned_pages_.Size(); } + std::int32_t BorrowedPagesCount() const { return static_cast(borrowed_page_ids_.size()); } std::int32_t ReleasedPagesCount() const { return base_logical_page_; } std::int32_t BaseLogicalPage() const { return base_logical_page_; } std::int32_t RawTokenCursor() const { return raw_token_cursor_; } - bool IsEmpty() const { return allocator_ == nullptr || pages_.Empty(); } + // Independent of base_logical_page_; sliding ReleaseSkipped does not move this. + std::int32_t CommittedPrefixLenTokens() const { return committed_prefix_len_tokens_; } + + bool IsEmpty() const { return allocator_ == nullptr || Size() == 0; } std::int32_t RowsPerPage() const; std::int32_t EntryStrideTokens() const; std::int32_t RawTokensPerPage() const; @@ -137,11 +174,17 @@ class PagedCacheGroupTable { std::int32_t SlidingWindowTokens() const; private: + // Must be called after every mutation of borrowed_page_ids_ or owned_pages_. + void RefreshPageIdsView(); + PagedCacheGroupAllocator* allocator_{nullptr}; - OwnedPages pages_; + OwnedPages owned_pages_; + std::vector borrowed_page_ids_; std::int32_t raw_token_cursor_{0}; - // Absolute logical-page index of pages_.Ids()[0]. Bumped by ReleaseSkipped. std::int32_t base_logical_page_{0}; + std::int32_t committed_prefix_len_tokens_{0}; + // Cached borrowed ++ owned, exposed by PageIds() as const ref for ABI shape. 
+ std::vector page_ids_view_; }; } // namespace tokenspeed diff --git a/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.cpp b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.cpp index 4157c6b1c..4e6599bb6 100644 --- a/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.cpp +++ b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.cpp @@ -20,9 +20,21 @@ #include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h" #include "resource/allocator/mamba_chunk_allocator.h" +#include "resource/allocator/paged_cache_group.h" +#include "resource/radix_tree/paged_cache_snapshot.h" +#include "resource/radix_tree/radix_tree.h" +#include "resource/radix_tree/tree_node.h" +#include "scheduler/operations/forward.h" +#include "utils.h" +#include + +#include #include +#include +#include #include +#include namespace tokenspeed { @@ -33,11 +45,10 @@ HybridPrefixCache::HybridPrefixCache(KVPrefixCache& kv_prefix_cache, MambaChunkA mamba_eviction_manager_{mamba_allocator}, mamba_cache_chunk_size_{mamba_cache_chunk_size} {} -// The HybridPrefixCache::Match is a wrapper of the KVPrefixCache::Match. -// And it's not used in fact, because the HybridPrefixCache is always used with the KVPrefixCache now. 
MatchResult HybridPrefixCache::Match(const token_vec_t& token_ids, MatchIntent intent) { auto match = kv_prefix_cache_.Match(token_ids, intent); augmentMatch(match); + augmentMatchPagedCache(match); return match; } @@ -45,10 +56,12 @@ MatchResult HybridPrefixCache::Match(const std::vectorIsRoot()) return; @@ -96,11 +109,15 @@ TreeNode* HybridPrefixCache::FindLastMambaNode(TreeNode* from) const { } bool HybridPrefixCache::EnsureMambaCapacityByEvict(std::int32_t num_slots, TreeNode* protected_node) { + if (mamba_allocator_ == nullptr) return num_slots <= 0; return mamba_eviction_manager_.EnsureCapacity(num_slots, protected_node); } void HybridPrefixCache::InsertMamba(TreeNode* terminal_node, std::unique_ptr slot) { if (terminal_node == nullptr || slot == nullptr) return; + if (mamba_allocator_ == nullptr) { + throw std::logic_error("HybridPrefixCache::InsertMamba: mamba adjunct not enabled"); + } const std::int32_t page_size = kv_prefix_cache_.PageSize(); if (page_size <= 0 || terminal_node->DepthInTokens() % static_cast(page_size) != 0) { throw std::logic_error("HybridPrefixCache::InsertMamba: terminal node is not block-aligned"); @@ -109,17 +126,856 @@ void HybridPrefixCache::InsertMamba(TreeNode* terminal_node, std::unique_ptr snapshot) { + if (node == nullptr || snapshot == nullptr) return false; + // Compute completeness from what is present. The policy-driven "snapshot + // must be full" invariant is enforced upstream by CommitChunk, which only + // attaches full snapshots; direct callers (tests, future restore paths) + // may attach history-only or state-only snapshots without policy gating. 
+ snapshot->complete_families.clear(); + bool history_complete = !paged_cache_history_groups_.empty(); + for (const auto& gid : paged_cache_history_groups_) { + if (snapshot->groups.find(gid) == snapshot->groups.end()) { + history_complete = false; + break; + } + } + if (history_complete) { + snapshot->complete_families.insert(PagedCacheGroupFamily::History); + } + bool state_complete = !paged_cache_state_groups_.empty(); + for (const auto& gid : paged_cache_state_groups_) { + if (snapshot->groups.find(gid) == snapshot->groups.end()) { + state_complete = false; + break; + } + } + if (state_complete) { + snapshot->complete_families.insert(PagedCacheGroupFamily::State); + } + node->AttachPagedCacheSnapshot(std::move(snapshot)); + paged_cache_snapshot_nodes_.insert(node); + return true; +} + +std::unique_ptr HybridPrefixCache::DetachPagedCacheSnapshotFromNode(TreeNode* node) { + if (node == nullptr) return nullptr; + paged_cache_snapshot_nodes_.erase(node); + return node->DetachPagedCacheSnapshot(); +} + void HybridPrefixCache::OnKVEvict(TreeNode* node) { - if (node == nullptr || !node->HasMamba()) return; - mamba_eviction_manager_.UntrackNode(node); - node->DetachMamba(); - if (node->Parent() != nullptr) { - mamba_eviction_manager_.UpdateLeaf(node->Parent()); + if (node == nullptr) return; + if (mamba_allocator_ != nullptr && node->HasMamba()) { + mamba_eviction_manager_.UntrackNode(node); + node->DetachMamba(); + if (node->Parent() != nullptr) { + mamba_eviction_manager_.UpdateLeaf(node->Parent()); + } + } + // Passive paged-cache detach on KV LRU drop: returns OwnedPages via RAII; + // the chain scan sees the gap because `HasPagedCacheSnapshot()` is false. + // Route through DetachPagedCacheSnapshotFromNode to keep membership set in sync. 
+ if (node->HasPagedCacheSnapshot()) { + DetachPagedCacheSnapshotFromNode(node); } } std::int32_t HybridPrefixCache::AvailableSlots() const { + if (mamba_allocator_ == nullptr) return 0; return mamba_allocator_->AvailableSlots(); } +void HybridPrefixCache::RegisterPagedCacheGroup(std::unique_ptr allocator) { + if (allocator == nullptr) { + throw std::invalid_argument("HybridPrefixCache::RegisterPagedCacheGroup: null allocator"); + } + std::string gid = allocator->Config().group_id; + if (paged_cache_allocators_.find(gid) != paged_cache_allocators_.end()) { + throw std::invalid_argument("HybridPrefixCache::RegisterPagedCacheGroup: duplicate group_id: " + gid); + } + paged_cache_allocators_.emplace(std::move(gid), std::move(allocator)); +} + +void HybridPrefixCache::EnablePagedCacheAdjunct(std::vector required_groups, + std::unordered_map sliding_window_per_group, + StateRestorePolicy policy) { + if (required_groups.empty()) { + throw std::invalid_argument("HybridPrefixCache::EnablePagedCacheAdjunct: required_groups must be non-empty"); + } + std::vector history_gids; + std::vector state_gids; + std::vector required_sliding_gids; + history_gids.reserve(required_groups.size()); + state_gids.reserve(required_groups.size()); + required_sliding_gids.reserve(required_groups.size()); + + // Partition required groups by family; collect sliding-group entries for + // post-validation against `sliding_window_per_group`. 
+ for (const auto& gid : required_groups) { + auto it = paged_cache_allocators_.find(gid); + if (it == paged_cache_allocators_.end() || it->second == nullptr) { + throw std::invalid_argument("HybridPrefixCache::EnablePagedCacheAdjunct: required group '" + gid + + "' missing from registered allocators"); + } + const auto& cfg = it->second->Config(); + const std::int32_t raw_per_page = cfg.RawTokensPerPage(); + if (raw_per_page <= 0) { + throw std::invalid_argument("HybridPrefixCache::EnablePagedCacheAdjunct: required group '" + gid + + "' has non-positive RawTokensPerPage"); + } + if (cfg.family == PagedCacheGroupFamily::History) { + history_gids.push_back(gid); + } else { + state_gids.push_back(gid); + } + if (cfg.retention == PagedCacheGroupConfig::Retention::SlidingWindow) { + auto win_it = sliding_window_per_group.find(gid); + if (win_it == sliding_window_per_group.end() || win_it->second <= 0) { + throw std::invalid_argument("HybridPrefixCache::EnablePagedCacheAdjunct: sliding group '" + gid + + "' missing positive sliding_window entry"); + } + required_sliding_gids.push_back(gid); + } + } + if (history_gids.empty()) { + throw std::invalid_argument( + "HybridPrefixCache::EnablePagedCacheAdjunct: at least one History-family group required"); + } + if (sliding_window_per_group.size() != required_sliding_gids.size()) { + throw std::invalid_argument( + "HybridPrefixCache::EnablePagedCacheAdjunct: sliding_window_per_group keys must exactly " + "match the set of required groups whose retention is SlidingWindow"); + } + + // History alignment = LCM(raw_per_page) across History-family groups. + std::int32_t history_alignment = 1; + for (const auto& gid : history_gids) { + const auto& cfg = paged_cache_allocators_.find(gid)->second->Config(); + history_alignment = std::lcm(history_alignment, cfg.RawTokensPerPage()); + } + // Phase 1: state groups must align with the history alignment (so trailing + // segments are themselves page-aligned). 
Phase 2 will relax this via replay. + if (policy == StateRestorePolicy::kSnapshotRequired) { + for (const auto& gid : state_gids) { + const auto& cfg = paged_cache_allocators_.find(gid)->second->Config(); + const std::int32_t raw_per_page = cfg.RawTokensPerPage(); + if (history_alignment % raw_per_page != 0) { + throw std::invalid_argument("HybridPrefixCache::EnablePagedCacheAdjunct: state group '" + gid + + "' RawTokensPerPage=" + std::to_string(raw_per_page) + + " does not divide history_alignment=" + std::to_string(history_alignment)); + } + } + } + + paged_cache_history_alignment_tokens_ = history_alignment; + paged_cache_required_groups_ = std::move(required_groups); + paged_cache_sliding_window_per_group_ = std::move(sliding_window_per_group); + paged_cache_history_groups_ = std::move(history_gids); + paged_cache_state_groups_ = std::move(state_gids); + paged_cache_history_group_set_ = + std::unordered_set(paged_cache_history_groups_.begin(), paged_cache_history_groups_.end()); + paged_cache_state_group_set_ = + std::unordered_set(paged_cache_state_groups_.begin(), paged_cache_state_groups_.end()); + paged_cache_state_policy_ = policy; +} + +namespace { + +// Ancestor path (excluding root), reversed so element 0 is closest to root. 
+std::vector CollectAncestorPathRootToLeaf(TreeNode* from) { + std::vector path; + for (TreeNode* n = from; n != nullptr && !n->IsRoot(); n = n->Parent()) { + path.push_back(n); + } + std::reverse(path.begin(), path.end()); + return path; +} + +} // namespace + +void HybridPrefixCache::augmentMatchPagedCache(MatchResult& match) const { + if (!HasPagedCacheAdjunct()) return; + if (match.device.last_node == nullptr) return; + + const std::int32_t align = paged_cache_history_alignment_tokens_; + + auto cap_to_root = [&]() { + TreeNode* root = match.device.last_node; + while (root != nullptr && !root->IsRoot()) root = root->Parent(); + match.device.last_node = root; + if (match.host.last_node != nullptr) { + TreeNode* h = match.host.last_node; + while (h != nullptr && !h->IsRoot()) h = h->Parent(); + match.host.last_node = h; + } + }; + + std::vector path = CollectAncestorPathRootToLeaf(match.device.last_node); + + // Phase A: history chain. Walk root→leaf, advance only on contiguous + // History-family completeness at every k*align boundary. + TreeNode* deepest_history = nullptr; + std::vector history_chain; + std::int32_t expected_depth = align; + for (TreeNode* n : path) { + const std::int32_t d = static_cast(n->DepthInTokens()); + if (d < expected_depth) continue; + if (d > expected_depth) break; + const auto* snap = n->GetPagedCacheSnapshot(); + if (snap == nullptr) break; + if (!snap->IsCompleteFor(PagedCacheGroupFamily::History)) break; + deepest_history = n; + history_chain.push_back(n); + expected_depth += align; + } + if (deepest_history == nullptr) { + cap_to_root(); + return; + } + + // Phase B: state window. `segments_needed` is the worst-case trailing + // coverage across state groups (so every state group is satisfied at the + // chosen depth). Walk back through history_chain, pick the deepest D' + // whose trailing `segments_needed` history_chain entries all have State + // complete. 
+ std::int32_t worst_window = 0; + for (const auto& gid : paged_cache_state_groups_) { + auto it = paged_cache_sliding_window_per_group_.find(gid); + if (it != paged_cache_sliding_window_per_group_.end()) { + worst_window = std::max(worst_window, it->second); + } + } + const std::int32_t segments_needed = worst_window > 0 ? (worst_window + align - 1) / align : 1; + + TreeNode* usable_node = nullptr; + if (paged_cache_state_groups_.empty()) { + usable_node = deepest_history; + } else { + for (std::int32_t end_idx = static_cast(history_chain.size()) - 1; end_idx >= 0; --end_idx) { + const std::int32_t start_idx = std::max(0, end_idx - segments_needed + 1); + bool ok = true; + for (std::int32_t i = start_idx; i <= end_idx; ++i) { + const auto* snap = history_chain[i]->GetPagedCacheSnapshot(); + if (snap == nullptr || !snap->IsCompleteFor(PagedCacheGroupFamily::State)) { + ok = false; + break; + } + } + if (ok) { + usable_node = history_chain[end_idx]; + break; + } + } + } + if (usable_node == nullptr) { + cap_to_root(); + return; + } + + const std::int32_t usable = static_cast(usable_node->DepthInTokens()); + // Trim history_chain to ancestors up to and including usable_node. + while (!history_chain.empty() && static_cast(history_chain.back()->DepthInTokens()) > usable) { + history_chain.pop_back(); + } + + // Phase C: per-group page-id assembly. History groups take the full chain; + // State groups share a trailing-window slice computed once. 
+ match.paged_cache.last_node = usable_node; + match.paged_cache.prefix_len_tokens = usable; + match.paged_cache.per_group_page_ids.clear(); + match.paged_cache.per_group_base_logical_page.clear(); + + auto assemble = [&](const std::string& gid, std::span chain, bool is_sliding) { + std::vector page_ids; + std::int32_t base_logical_page = 0; + if (!chain.empty()) { + const PagedCacheSnapshot* earliest_snap = chain.front()->GetPagedCacheSnapshot(); + if (earliest_snap != nullptr && is_sliding) { + auto git = earliest_snap->groups.find(gid); + if (git != earliest_snap->groups.end()) { + base_logical_page = git->second.base_logical_page; + } + } + for (TreeNode* anc : chain) { + const PagedCacheSnapshot* snap = anc->GetPagedCacheSnapshot(); + if (snap == nullptr) continue; + auto git = snap->groups.find(gid); + if (git == snap->groups.end()) continue; + const auto& seg_ids = git->second.pages.Ids(); + page_ids.insert(page_ids.end(), seg_ids.begin(), seg_ids.end()); + } + } + match.paged_cache.per_group_page_ids[gid] = std::move(page_ids); + match.paged_cache.per_group_base_logical_page[gid] = base_logical_page; + }; + + const std::span history_span{history_chain}; + for (const auto& gid : paged_cache_history_groups_) { + const bool is_sliding = + paged_cache_sliding_window_per_group_.find(gid) != paged_cache_sliding_window_per_group_.end(); + assemble(gid, history_span, is_sliding); + } + if (!paged_cache_state_groups_.empty()) { + const std::size_t take = std::min(history_chain.size(), static_cast(segments_needed)); + const std::span state_span = history_span.last(take); + for (const auto& gid : paged_cache_state_groups_) { + const bool is_sliding = + paged_cache_sliding_window_per_group_.find(gid) != paged_cache_sliding_window_per_group_.end(); + assemble(gid, state_span, is_sliding); + } + } + + // Cap device/host match nodes to the paged-cache usable depth. 
+ match.device.last_node = usable_node; + if (match.host.last_node != nullptr && static_cast(match.host.last_node->DepthInTokens()) > usable) { + TreeNode* h = match.host.last_node; + while (h != nullptr && !h->IsRoot() && static_cast(h->DepthInTokens()) > usable) { + h = h->Parent(); + } + match.host.last_node = h; + } + + match.paged_cache.restore_kind = MatchResult::PagedCache::RestoreKind::kSnapshotComplete; + match.paged_cache.replay_start_tokens = 0; +} + +std::vector HybridPrefixCache::PagedCacheGroupIds() const { + std::vector ids; + ids.reserve(paged_cache_allocators_.size()); + for (const auto& [gid, _] : paged_cache_allocators_) { + ids.push_back(gid); + } + return ids; +} + +std::int32_t HybridPrefixCache::PagedCacheGroupTotalPages(const std::string& group_id) const { + auto it = paged_cache_allocators_.find(group_id); + if (it == paged_cache_allocators_.end()) { + throw std::out_of_range("HybridPrefixCache::PagedCacheGroupTotalPages: group_id not configured"); + } + return it->second->TotalPages(); +} + +std::int32_t HybridPrefixCache::PagedCacheGroupAvailablePages(const std::string& group_id) const { + auto it = paged_cache_allocators_.find(group_id); + if (it == paged_cache_allocators_.end()) { + throw std::out_of_range("HybridPrefixCache::PagedCacheGroupAvailablePages: group_id not configured"); + } + return it->second->AvailablePages(); +} + +std::int64_t HybridPrefixCache::PagedCacheGroupFailedAllocCount(const std::string& group_id) const { + auto it = paged_cache_allocators_.find(group_id); + if (it == paged_cache_allocators_.end()) { + throw std::out_of_range("HybridPrefixCache::PagedCacheGroupFailedAllocCount: group_id not configured"); + } + return it->second->FailedAllocCount(); +} + +std::vector HybridPrefixCache::GetRequestPagedCachePageIds(const std::string& request_id, + const std::string& group_id) const { + if (paged_cache_allocators_.find(group_id) == paged_cache_allocators_.end()) { + throw 
std::out_of_range("HybridPrefixCache::GetRequestPagedCachePageIds: group_id not configured"); + } + auto req_it = request_paged_cache_tables_.find(request_id); + if (req_it == request_paged_cache_tables_.end()) { + return {}; + } + auto group_it = req_it->second.find(group_id); + if (group_it == req_it->second.end()) { + return {}; + } + return group_it->second.PageIds(); +} + +std::int32_t HybridPrefixCache::GetRequestPagedCacheBaseLogicalPage(const std::string& request_id, + const std::string& group_id) const { + if (paged_cache_allocators_.find(group_id) == paged_cache_allocators_.end()) { + throw std::out_of_range("HybridPrefixCache::GetRequestPagedCacheBaseLogicalPage: group_id not configured"); + } + auto req_it = request_paged_cache_tables_.find(request_id); + if (req_it == request_paged_cache_tables_.end()) { + return 0; + } + auto group_it = req_it->second.find(group_id); + if (group_it == req_it->second.end()) { + return 0; + } + return group_it->second.BaseLogicalPage(); +} + +std::map HybridPrefixCache::InitialSimulatedFree() const { + std::map out; + for (const auto& [gid, allocator] : paged_cache_allocators_) { + out[gid] = allocator->AvailablePages(); + } + return out; +} + +void HybridPrefixCache::AcquireForRequest(const std::string& request_id, std::int32_t first_raw_position_of_op, + std::int32_t target_raw_tokens_exclusive, + const MatchResult::PagedCache& paged_cache_hit) { + if (paged_cache_allocators_.empty()) return; + auto& tables = request_paged_cache_tables_[request_id]; + const bool has_hit = (paged_cache_hit.last_node != nullptr) && (paged_cache_hit.prefix_len_tokens > 0); + for (const auto& [group_id, allocator] : paged_cache_allocators_) { + auto it = tables.find(group_id); + const bool fresh_table = (it == tables.end()); + if (fresh_table) { + it = tables.emplace(group_id, PagedCacheGroupTable(allocator.get())).first; + // Import borrowed-prefix BEFORE ReleaseSkipped/Acquire on a fresh table. 
+ if (has_hit) { + auto pid_it = paged_cache_hit.per_group_page_ids.find(group_id); + if (pid_it != paged_cache_hit.per_group_page_ids.end() && !pid_it->second.empty()) { + std::int32_t base_logical_page = 0; + auto base_it = paged_cache_hit.per_group_base_logical_page.find(group_id); + if (base_it != paged_cache_hit.per_group_base_logical_page.end()) { + base_logical_page = base_it->second; + } + std::vector page_ids_copy = pid_it->second; + it->second.ImportPrefixBorrowed(std::move(page_ids_copy), base_logical_page, + paged_cache_hit.prefix_len_tokens); + } + } + } + const auto& cfg = allocator->Config(); + if (cfg.retention == PagedCacheGroupConfig::Retention::SlidingWindow && cfg.sliding_window_tokens.has_value()) { + const std::int32_t lower = std::max(0, first_raw_position_of_op - *cfg.sliding_window_tokens + 1); + it->second.ReleaseSkipped(lower); + } + it->second.Acquire(target_raw_tokens_exclusive); + } +} + +void HybridPrefixCache::ReleaseRequest(const std::string& request_id) { + auto it = request_paged_cache_tables_.find(request_id); + if (it == request_paged_cache_tables_.end()) return; + for (auto& [_, table] : it->second) { + table.ReleaseAll(); + } + request_paged_cache_tables_.erase(it); +} + +void HybridPrefixCache::PopulateOp(ForwardOperationBase& op_base) const { + if (paged_cache_allocators_.empty()) return; + auto req_it = request_paged_cache_tables_.find(op_base.request_id); + for (const auto& [gid, allocator] : paged_cache_allocators_) { + std::vector pages; + std::int32_t base_offset = 0; + if (req_it != request_paged_cache_tables_.end()) { + auto table_it = req_it->second.find(gid); + if (table_it != req_it->second.end()) { + pages = table_it->second.PageIds(); + base_offset = table_it->second.BaseLogicalPage(); + } + } + op_base.paged_cache_pages[gid] = std::move(pages); + if (allocator->Config().retention == PagedCacheGroupConfig::Retention::SlidingWindow) { + op_base.paged_cache_page_base_offsets[gid] = base_offset; + } + } +} + 
+HybridPrefixCache::PagedCacheGroupAdmission HybridPrefixCache::checkPagedCacheGroupAdmission( + const std::string& request_id, std::int32_t first_raw_position_of_op, std::int32_t target_raw_tokens_exclusive, + const std::map& simulated_free, const MatchResult::PagedCache& paged_cache_hit, + const PagedCacheAdmissionContext& context) const { + PagedCacheGroupAdmission result; + if (paged_cache_allocators_.empty() || target_raw_tokens_exclusive < 0) { + return result; + } + + auto req_it = + context.fresh_table_view ? request_paged_cache_tables_.end() : request_paged_cache_tables_.find(request_id); + const bool has_hit = (paged_cache_hit.last_node != nullptr) && (paged_cache_hit.prefix_len_tokens > 0); + for (const auto& [gid, allocator] : paged_cache_allocators_) { + const auto& cfg = allocator->Config(); + const std::int32_t raw_per_page = cfg.RawTokensPerPage(); + if (cfg.entry_stride_tokens <= 0 || cfg.rows_per_page <= 0 || raw_per_page <= 0) { + continue; + } + + const std::int32_t entries = CeilDivPositive(target_raw_tokens_exclusive, cfg.entry_stride_tokens); + const std::int32_t required = (entries + cfg.rows_per_page - 1) / cfg.rows_per_page; + + std::int32_t current_size = 0; + std::int32_t current_active = 0; + std::int32_t borrowed_in_table = 0; + std::int32_t owned_in_table = 0; + std::int32_t already_released = 0; + bool table_exists = false; + if (req_it != request_paged_cache_tables_.end()) { + auto t_it = req_it->second.find(gid); + if (t_it != req_it->second.end()) { + table_exists = true; + current_size = t_it->second.Size(); + current_active = t_it->second.ActivePagesCount(); + borrowed_in_table = t_it->second.BorrowedPagesCount(); + owned_in_table = t_it->second.OwnedPagesCount(); + already_released = t_it->second.ReleasedPagesCount(); + } + } + + std::int32_t borrowed_count = 0; + std::int32_t borrowed_base = 0; + if (has_hit && !table_exists) { + auto pid_it = paged_cache_hit.per_group_page_ids.find(gid); + if (pid_it != 
paged_cache_hit.per_group_page_ids.end()) { + borrowed_count = static_cast(pid_it->second.size()); + } + auto base_it = paged_cache_hit.per_group_base_logical_page.find(gid); + if (base_it != paged_cache_hit.per_group_base_logical_page.end()) { + borrowed_base = base_it->second; + } + } + + std::int32_t releasable_total = 0; + std::int32_t releasable_owned = 0; + if (cfg.retention == PagedCacheGroupConfig::Retention::SlidingWindow && cfg.sliding_window_tokens.has_value()) { + const std::int32_t lower = std::max(0, first_raw_position_of_op - *cfg.sliding_window_tokens + 1); + const std::int32_t target_releases = lower / raw_per_page; + const std::int32_t logical_released_base = table_exists ? already_released : borrowed_base; + releasable_total = std::max(0, target_releases - logical_released_base); + releasable_total = std::min(releasable_total, current_active + borrowed_count); + + // Borrowed pages drop the index only (no pool credit); only the + // owned-prefix slice contributes to releasable_owned. + const std::int32_t borrowed_present_total = table_exists ? borrowed_in_table : borrowed_count; + releasable_owned = releasable_total - std::min(releasable_total, borrowed_present_total); + if (table_exists) { + releasable_owned = std::min(releasable_owned, owned_in_table); + } + } + + const std::int32_t absolute_have = + table_exists ? 
(already_released + current_size) : (borrowed_base + borrowed_count); + const std::int32_t new_pages = std::max(0, required - absolute_have); + std::int32_t free = allocator->AvailablePages(); + auto sf_it = simulated_free.find(gid); + if (sf_it != simulated_free.end()) { + free = sf_it->second; + } + auto credit_it = context.owned_release_credit.find(gid); + if (credit_it != context.owned_release_credit.end()) { + free += credit_it->second; + } + + result.releasable_owned_pages[gid] = releasable_owned; + result.new_pages_needed[gid] = new_pages; + if (free + releasable_owned < new_pages) { + result.ok = false; + } + } + return result; +} + +void HybridPrefixCache::applyPagedCacheGroupAdmissionDebit(std::map& simulated_free, + const PagedCacheGroupAdmission& admission) { + for (const auto& [gid, releasable_owned] : admission.releasable_owned_pages) { + simulated_free[gid] += releasable_owned; + } + for (const auto& [gid, new_pages] : admission.new_pages_needed) { + simulated_free[gid] -= new_pages; + } +} + +HybridPrefixCache::AdmissionFailureKind HybridPrefixCache::ClassifyAdmissionFailure( + const PagedCacheGroupAdmission& admission) const { + if (admission.ok) return AdmissionFailureKind::kNone; + bool history_starved = false; + bool state_starved = false; + for (const auto& [gid, needed] : admission.new_pages_needed) { + if (needed <= 0) continue; + if (paged_cache_history_group_set_.find(gid) != paged_cache_history_group_set_.end()) { + history_starved = true; + } + if (paged_cache_state_group_set_.find(gid) != paged_cache_state_group_set_.end()) { + state_starved = true; + } + } + if (history_starved && state_starved) return AdmissionFailureKind::kBothStarved; + if (history_starved) return AdmissionFailureKind::kHistoryStarved; + if (state_starved) return AdmissionFailureKind::kStateStarved; + return AdmissionFailureKind::kNone; +} + +void HybridPrefixCache::refreshPagedCacheSimulatedFree(std::map& simulated_free) const { + for (const auto& [gid, allocator] : 
paged_cache_allocators_) { + simulated_free[gid] = allocator->AvailablePages(); + } +} + +bool HybridPrefixCache::admitPagedCacheChunk(const std::string& request_id, std::int32_t first_raw_position_of_op, + std::int32_t target_raw_tokens_exclusive, + std::map& simulated_free, + const MatchResult::PagedCache& paged_cache_hit, + const PagedCacheAdmissionContext& context) { + PagedCacheGroupAdmission admission = checkPagedCacheGroupAdmission( + request_id, first_raw_position_of_op, target_raw_tokens_exclusive, simulated_free, paged_cache_hit, context); + const std::size_t prune_budget = paged_cache_snapshot_nodes_.size(); + for (std::size_t pruned = 0; !admission.ok && pruned < prune_budget; ++pruned) { + AdmissionFailureKind kind = ClassifyAdmissionFailure(admission); + if (kind == AdmissionFailureKind::kNone) break; + if (!tryPrunePagedCacheSnapshot(kind)) break; + refreshPagedCacheSimulatedFree(simulated_free); + admission = checkPagedCacheGroupAdmission(request_id, first_raw_position_of_op, target_raw_tokens_exclusive, + simulated_free, paged_cache_hit, context); + } + if (!admission.ok) return false; + for (const auto& [gid, credit] : context.owned_release_credit) { + simulated_free[gid] += credit; + } + applyPagedCacheGroupAdmissionDebit(simulated_free, admission); + return true; +} + +bool HybridPrefixCache::DetachStateSnapshotFromNode(TreeNode* node) { + if (node == nullptr) return false; + PagedCacheSnapshot* snap = node->GetPagedCacheSnapshotMut(); + if (snap == nullptr) return false; + bool removed_any = false; + for (const auto& gid : paged_cache_state_groups_) { + auto it = snap->groups.find(gid); + if (it != snap->groups.end()) { + snap->groups.erase(it); + removed_any = true; + } + } + if (!removed_any) return false; + snap->complete_families.erase(PagedCacheGroupFamily::State); + // If nothing remains, fall through to full detach to keep invariants tidy. 
+ if (snap->groups.empty()) { + DetachPagedCacheSnapshotFromNode(node); + } + return true; +} + +bool HybridPrefixCache::tryPrunePagedCacheSnapshot(AdmissionFailureKind kind) { + if (!HasPagedCacheAdjunct()) return false; + if (kind == AdmissionFailureKind::kNone) return false; + + auto is_pinned = [](TreeNode* node) { + for (TreeNode* cur = node; cur != nullptr && !cur->IsRoot(); cur = cur->Parent()) { + if (!cur->OnDevice()) continue; + if (cur->Device().RefCount() > 0) return true; + } + return false; + }; + + // Sort once and share between branches: oldest first, then deepest within + // same Time(). Both try_state_only and try_full walk this same order. + std::vector candidates; + candidates.reserve(paged_cache_snapshot_nodes_.size()); + for (TreeNode* node : paged_cache_snapshot_nodes_) { + if (node == nullptr) continue; + if (!node->HasPagedCacheSnapshot()) continue; + candidates.push_back(node); + } + std::sort(candidates.begin(), candidates.end(), [](TreeNode* a, TreeNode* b) { + if (a->Time() != b->Time()) return a->Time() < b->Time(); + return a->DepthInTokens() > b->DepthInTokens(); + }); + + auto try_state_only = [&]() { + for (TreeNode* node : candidates) { + if (is_pinned(node)) continue; + const auto* snap = node->GetPagedCacheSnapshot(); + if (snap == nullptr) continue; + if (!snap->IsCompleteFor(PagedCacheGroupFamily::State)) continue; + if (DetachStateSnapshotFromNode(node)) return true; + } + return false; + }; + + auto try_full = [&]() { + TreeNode* victim = nullptr; + for (TreeNode* node : candidates) { + if (is_pinned(node)) continue; + victim = node; + break; + } + if (victim == nullptr) return false; + const std::size_t victim_depth = victim->DepthInTokens(); + auto primary = DetachPagedCacheSnapshotFromNode(victim); + (void)primary; + std::vector descendants; + for (TreeNode* node : paged_cache_snapshot_nodes_) { + if (node == nullptr || node == victim) continue; + if (!node->HasPagedCacheSnapshot()) continue; + if (node->DepthInTokens() 
<= victim_depth) continue; + for (TreeNode* cur = node->Parent(); cur != nullptr && !cur->IsRoot(); cur = cur->Parent()) { + if (cur == victim) { + descendants.push_back(node); + break; + } + } + } + for (TreeNode* d : descendants) { + if (is_pinned(d)) continue; + auto cascaded = DetachPagedCacheSnapshotFromNode(d); + (void)cascaded; + } + return true; + }; + + // kBothStarved: state-only cannot solve history shortage; go straight to + // full. The outer admit loop will re-classify if state still needs more. + switch (kind) { + case AdmissionFailureKind::kStateStarved: + return try_state_only(); + case AdmissionFailureKind::kHistoryStarved: + case AdmissionFailureKind::kBothStarved: + return try_full(); + case AdmissionFailureKind::kNone: + return false; + } + return false; +} + +bool HybridPrefixCache::AdmitChunk(const std::string& request_id, std::int32_t first_raw_position_of_op, + std::int32_t target_raw_tokens_exclusive, + std::map& simulated_free, + const MatchResult::PagedCache& paged_cache_hit) { + return admitPagedCacheChunk(request_id, first_raw_position_of_op, target_raw_tokens_exclusive, simulated_free, + paged_cache_hit, {}); +} + +bool HybridPrefixCache::AdmitChunkFromRetracted(const std::string& request_id, std::int32_t target_raw_tokens_exclusive, + std::map& simulated_free, + const MatchResult::PagedCache& paged_cache_hit) { + PagedCacheAdmissionContext context{.fresh_table_view = true}; + auto req_it = request_paged_cache_tables_.find(request_id); + if (req_it != request_paged_cache_tables_.end()) { + for (const auto& [gid, table] : req_it->second) { + context.owned_release_credit[gid] = table.OwnedPagesCount(); + } + } + return admitPagedCacheChunk(request_id, 0, target_raw_tokens_exclusive, simulated_free, paged_cache_hit, context); +} + +void HybridPrefixCache::CommitChunk(const std::string& request_id, TreeNode* terminal) { + if (!HasPagedCacheAdjunct()) return; + if (terminal == nullptr) return; + + auto tables_it = 
request_paged_cache_tables_.find(request_id); + if (tables_it == request_paged_cache_tables_.end()) return; + auto& tables = tables_it->second; + + const std::int32_t lcm = paged_cache_history_alignment_tokens_; + if (lcm <= 0) return; + const auto& required_groups = paged_cache_required_groups_; + if (required_groups.empty()) return; + + auto canonical_it = tables.find(required_groups.front()); + if (canonical_it == tables.end()) return; + std::int32_t last_committed = canonical_it->second.CommittedPrefixLenTokens(); + + const std::int32_t chunk_depth = static_cast(terminal->DepthInTokens()); + if (chunk_depth <= 0) return; + + while (last_committed + lcm <= chunk_depth) { + const std::int32_t target = last_committed + lcm; + + TreeNode* attach_node = kv_prefix_cache_.GetRadixTree().SplitAt(terminal, target); + if (attach_node == nullptr) break; + + if (attach_node->HasPagedCacheSnapshot()) { + bool covered = true; + for (const auto& gid : required_groups) { + auto t_it = tables.find(gid); + if (t_it == tables.end()) { + covered = false; + break; + } + if (t_it->second.CommittedPrefixLenTokens() < target) { + covered = false; + break; + } + } + if (!covered) { + spdlog::warn( + "[HybridPrefixCache] CommitChunk: target depth {} already has a paged-cache " + "snapshot but request {} has uncommitted owned pages in [{}, {}); leaving " + "existing snapshot intact", + target, request_id, last_committed, target); + break; + } + last_committed = target; + continue; + } + + bool preflight_ok = true; + for (const auto& gid : required_groups) { + auto t_it = tables.find(gid); + if (t_it == tables.end()) { + preflight_ok = false; + break; + } + const auto& table = t_it->second; + const std::int32_t raw_per_page = table.RawTokensPerPage(); + if (raw_per_page <= 0) { + preflight_ok = false; + break; + } + if (table.CommittedPrefixLenTokens() % raw_per_page != 0) { + preflight_ok = false; + break; + } + if (target % raw_per_page != 0) { + preflight_ok = false; + break; + } + if 
(target <= table.CommittedPrefixLenTokens()) { + preflight_ok = false; + break; + } + if (target > table.RawTokenCursor()) { + preflight_ok = false; + break; + } + } + if (!preflight_ok) { + spdlog::warn( + "[HybridPrefixCache] CommitChunk: preflight failed for request {} at target " + "depth {}; leaving prior commits intact", + request_id, target); + break; + } + + auto snapshot = std::make_unique(); + snapshot->prefix_len_tokens = target; + for (const auto& gid : required_groups) { + auto& table = tables.find(gid)->second; + auto group_alloc_it = paged_cache_allocators_.find(gid); + const auto& cfg = group_alloc_it->second->Config(); + auto result = cfg.family == PagedCacheGroupFamily::History ? table.CommitHistoryToSnapshot(target) + : table.CheckpointStateToSnapshot(target); + PagedCacheGroupSnapshot group_snap{}; + group_snap.pages = std::move(result.pages); + group_snap.base_logical_page = result.segment_base_logical_page; + group_snap.raw_token_cursor = table.RawTokenCursor(); + group_snap.sliding = table.IsSliding(); + snapshot->groups.emplace(gid, std::move(group_snap)); + } + + bool snapshot_complete = true; + for (const auto& gid : required_groups) { + if (snapshot->groups.find(gid) == snapshot->groups.end()) { + snapshot_complete = false; + break; + } + } + _assert(snapshot_complete, + "HybridPrefixCache::CommitChunk: built snapshot missing a required group after " + "preflight+commit; invariant violated"); + const bool attached = AttachPagedCacheSnapshotToNode(attach_node, std::move(snapshot)); + _assert(attached, + "HybridPrefixCache::CommitChunk: attach rejected a non-null snapshot on a non-null " + "node; invariant violated"); + + last_committed = target; + } +} + } // namespace tokenspeed diff --git a/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.h b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.h index 89247674c..e69336fc2 100644 --- 
a/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.h +++ b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.h @@ -20,11 +20,18 @@ #pragma once +#include #include +#include #include #include +#include +#include +#include +#include #include +#include "resource/allocator/paged_cache_group.h" #include "resource/hybrid_prefix_cache/mamba_eviction_manager.h" #include "resource/radix_tree/mamba_slot.h" #include "resource/kv_prefix_cache/kv_prefix_cache.h" @@ -33,9 +40,11 @@ namespace tokenspeed { class MambaChunkAllocator; +class ForwardOperationBase; class HybridPrefixCache { public: + // `mamba_allocator` may be null; paged-cache adjunct is enabled separately. HybridPrefixCache(KVPrefixCache& prefix_cache, MambaChunkAllocator* allocator, std::int32_t mamba_cache_chunk_size); MatchResult Match(const token_vec_t& token_ids, MatchIntent intent = MatchIntent::PrefixReuse); @@ -47,19 +56,153 @@ class HybridPrefixCache { std::int32_t AlignMambaCacheSeqlen(std::int32_t seqlen) const; TreeNode* FindLastMambaNode(TreeNode* from) const; - // CallBack on KV Prefix Cache Eviction + // Takes ownership. Duplicate group_id throws std::invalid_argument. + void RegisterPagedCacheGroup(std::unique_ptr allocator); + + // History alignment is the LCM of RawTokensPerPage() over the History-family + // groups; state groups only need the trailing window. Sliding groups must + // have a window entry; full-history groups must not. 
+ void EnablePagedCacheAdjunct(std::vector required_groups, + std::unordered_map sliding_window_per_group, + StateRestorePolicy policy = StateRestorePolicy::kSnapshotRequired); + + bool HasMambaAdjunct() const { return mamba_allocator_ != nullptr; } + bool HasPagedCacheAdjunct() const { return paged_cache_history_alignment_tokens_ > 0; } + std::int32_t PagedCacheHistoryAlignmentTokens() const { return paged_cache_history_alignment_tokens_; } + const std::vector& PagedCacheRequiredGroups() const { return paged_cache_required_groups_; } + + // Group introspection: throws std::out_of_range on unknown group_id. + std::vector PagedCacheGroupIds() const; + std::int32_t PagedCacheGroupTotalPages(const std::string& group_id) const; + std::int32_t PagedCacheGroupAvailablePages(const std::string& group_id) const; + std::int64_t PagedCacheGroupFailedAllocCount(const std::string& group_id) const; + + // Per-request introspection: unknown group_id throws; unknown request_id returns empty. + std::vector GetRequestPagedCachePageIds(const std::string& request_id, + const std::string& group_id) const; + std::int32_t GetRequestPagedCacheBaseLogicalPage(const std::string& request_id, const std::string& group_id) const; + + // Unified paged-cache lifecycle surface used by the Scheduler. All methods + // below are no-ops when no paged-cache groups are registered. + + // Initial per-group simulated_free budget mirroring live allocator state. + std::map InitialSimulatedFree() const; + + // Ensure tables exist and cover [first_raw_position_of_op, target_raw_tokens_exclusive). + // Borrowed prefix is imported BEFORE any fresh allocation on a fresh table. + void AcquireForRequest(const std::string& request_id, std::int32_t first_raw_position_of_op, + std::int32_t target_raw_tokens_exclusive, + const MatchResult::PagedCache& paged_cache_hit = {}); + + // Owned pages return to the pool via OwnedPages RAII; borrowed ids are dropped. 
+ void ReleaseRequest(const std::string& request_id); + + // Fill op.paged_cache_pages / op.paged_cache_page_base_offsets from the tables. + void PopulateOp(ForwardOperationBase& op_base) const; + + // Run admission against `simulated_free`; prunes evictable snapshots on + // group-pool pressure, then applies the debit on success. + bool AdmitChunk(const std::string& request_id, std::int32_t first_raw_position_of_op, + std::int32_t target_raw_tokens_exclusive, std::map& simulated_free, + const MatchResult::PagedCache& paged_cache_hit = {}); + + // Retract-decode variant: admission uses a fresh-table view and credits + // pages owned by the stale table before it is released. + bool AdmitChunkFromRetracted(const std::string& request_id, std::int32_t target_raw_tokens_exclusive, + std::map& simulated_free, + const MatchResult::PagedCache& paged_cache_hit); + + // Commit newly-written full LCM segments into TreeNode PagedCacheSnapshots. + void CommitChunk(const std::string& request_id, TreeNode* terminal); + + // Attach a snapshot to `node`, computing `complete_families` from which + // required-per-family group ids are present and registering the node in + // `paged_cache_snapshot_nodes_`. Returns false when either argument is + // null (defensive no-op). Accepts partial snapshots; the per-policy + // "snapshot must be full" invariant is enforced upstream by CommitChunk. + bool AttachPagedCacheSnapshotToNode(TreeNode* node, std::unique_ptr snapshot); + + // Drops `node` from the membership set, then detaches and returns the snapshot. + std::unique_ptr DetachPagedCacheSnapshotFromNode(TreeNode* node); + + // Callback from KV prefix-cache eviction. void OnKVEvict(TreeNode* node); std::int32_t AvailableSlots() const; KVPrefixCache& GetKVPrefixCache() { return kv_prefix_cache_; } private: + friend class HybridPrefixCacheTestPeer; + + // Per-family classification of admission failure; drives state-only vs + // full prune strategy. 
+ enum class AdmissionFailureKind { kNone, kHistoryStarved, kStateStarved, kBothStarved }; + + struct PagedCacheGroupAdmission { + bool ok{true}; + std::map releasable_owned_pages{}; + std::map new_pages_needed{}; + }; + + struct PagedCacheAdmissionContext { + bool fresh_table_view{false}; + std::map owned_release_credit{}; + }; + + // Classify which family caused `admission.ok == false`. + AdmissionFailureKind ClassifyAdmissionFailure(const PagedCacheGroupAdmission& admission) const; + + // Drop only state-family groups from `node`'s snapshot; history portion + // remains and the node stays registered. Returns true iff state groups removed. + bool DetachStateSnapshotFromNode(TreeNode* node); + void augmentMatch(MatchResult& match) const; + void augmentMatchPagedCache(MatchResult& match) const; + + // Detach oldest evictable snapshot to free pool pages. State-only path is + // used only on kStateStarved; history/both go to full cascade. + bool tryPrunePagedCacheSnapshot(AdmissionFailureKind kind); + + bool admitPagedCacheChunk(const std::string& request_id, std::int32_t first_raw_position_of_op, + std::int32_t target_raw_tokens_exclusive, + std::map& simulated_free, + const MatchResult::PagedCache& paged_cache_hit, + const PagedCacheAdmissionContext& context); + + // Build admission record without mutating any table. + PagedCacheGroupAdmission checkPagedCacheGroupAdmission(const std::string& request_id, + std::int32_t first_raw_position_of_op, + std::int32_t target_raw_tokens_exclusive, + const std::map& simulated_free, + const MatchResult::PagedCache& paged_cache_hit, + const PagedCacheAdmissionContext& context) const; + + // Owned releases credit, new-page needs debit. 
+ static void applyPagedCacheGroupAdmissionDebit(std::map& simulated_free, + const PagedCacheGroupAdmission& admission); + void refreshPagedCacheSimulatedFree(std::map& simulated_free) const; KVPrefixCache& kv_prefix_cache_; MambaChunkAllocator* mamba_allocator_; MambaEvictionManager mamba_eviction_manager_; std::int32_t mamba_cache_chunk_size_; + + // `paged_cache_history_alignment_tokens_ == 0` means adjunct disabled; tables still work. + std::map> paged_cache_allocators_; + std::unordered_map> request_paged_cache_tables_; + std::int32_t paged_cache_history_alignment_tokens_{0}; + std::vector paged_cache_required_groups_; + std::unordered_map paged_cache_sliding_window_per_group_; + // Subset of `paged_cache_required_groups_` partitioned by family. + std::vector paged_cache_history_groups_; + std::vector paged_cache_state_groups_; + // Fast hot-path lookup mirrors of the above (filled in EnablePagedCacheAdjunct). + std::unordered_set paged_cache_history_group_set_; + std::unordered_set paged_cache_state_group_set_; + StateRestorePolicy paged_cache_state_policy_{StateRestorePolicy::kSnapshotRequired}; + + // TODO(snapshot-lru-perf): O(N log N) per prune; swap in LRU index if profiling shows it matters. + std::unordered_set paged_cache_snapshot_nodes_; }; } // namespace tokenspeed diff --git a/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/kv_prefix_cache.h b/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/kv_prefix_cache.h index aaba00dae..06ec7f1c4 100644 --- a/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/kv_prefix_cache.h +++ b/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/kv_prefix_cache.h @@ -79,6 +79,11 @@ class KVPrefixCache { std::int32_t PageSize() const { return tree_.PageSize(); } DeviceManager& GetDeviceManager() { return device_; } + // Adjunct managers may need to materialize boundary nodes via SplitAt. + // The tree's lifetime remains owned by this KVPrefixCache. 
+ RadixTree& GetRadixTree() { return tree_; } + const RadixTree& GetRadixTree() const { return tree_; } + private: MatchResult RootMatch() const; diff --git a/tokenspeed-scheduler/csrc/resource/radix_tree/paged_cache_snapshot.h b/tokenspeed-scheduler/csrc/resource/radix_tree/paged_cache_snapshot.h new file mode 100644 index 000000000..4b24aa02b --- /dev/null +++ b/tokenspeed-scheduler/csrc/resource/radix_tree/paged_cache_snapshot.h @@ -0,0 +1,49 @@ +// Copyright (c) 2026 LightSeek Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + +#pragma once + +#include +#include +#include +#include + +#include "resource/allocator/owned_pages.h" +#include "resource/allocator/paged_cache_group.h" + +namespace tokenspeed { + +// Per-group snapshot held by a TreeNode. RAII returns pages to the allocator. +struct PagedCacheGroupSnapshot { + OwnedPages pages; + std::int32_t base_logical_page{0}; + std::int32_t raw_token_cursor{0}; + bool sliding{false}; +}; + +// Snapshot for a TreeNode at a history-alignment-aligned raw-token boundary; +// completeness is tracked per family. 
+struct PagedCacheSnapshot { + std::int32_t prefix_len_tokens{0}; + std::map groups; + // Filled by HybridPrefixCache::AttachPagedCacheSnapshotToNode based on + // which group ids landed in `groups` vs required-per-family lists. + std::set complete_families; + + bool IsCompleteFor(PagedCacheGroupFamily f) const { return complete_families.find(f) != complete_families.end(); } +}; + +} // namespace tokenspeed diff --git a/tokenspeed-scheduler/csrc/resource/radix_tree/radix_tree.cpp b/tokenspeed-scheduler/csrc/resource/radix_tree/radix_tree.cpp index afd597f24..fd9fc7072 100644 --- a/tokenspeed-scheduler/csrc/resource/radix_tree/radix_tree.cpp +++ b/tokenspeed-scheduler/csrc/resource/radix_tree/radix_tree.cpp @@ -89,6 +89,42 @@ TreeNode* RadixTree::PruneEmptyByNode(TreeNode* node) { return current; } +TreeNode* RadixTree::SplitAt(TreeNode* descendant, std::int32_t depth_in_tokens) { + if (descendant == nullptr) { + return nullptr; + } + if (depth_in_tokens <= 0 || depth_in_tokens % page_size_ != 0) { + return nullptr; + } + if (depth_in_tokens > static_cast(descendant->DepthInTokens())) { + return nullptr; + } + + // Find the ancestor range covering depth_in_tokens. + // Exact match returns the node; an interior split returns the prefix. + TreeNode* current = descendant; + while (current != nullptr && !current->IsRoot()) { + const std::int32_t this_depth = static_cast(current->DepthInTokens()); + const std::int32_t parent_depth = this_depth - static_cast(current->Tokens().size()); + if (depth_in_tokens == this_depth) { + return current; + } + if (depth_in_tokens > parent_depth && depth_in_tokens < this_depth) { + // Refuse to split a snapshot-bearing node (would dangle borrowed ids). 
+ if (current->HasPagedCacheSnapshot()) { + return nullptr; + } + TreeNode* parent = current->Parent(); + const token_vec_t child_key = getFirstPage(current->Tokens(), page_size_); + const std::size_t prefix_pages = static_cast((depth_in_tokens - parent_depth) / page_size_); + SplitResult sr = splitChild(parent, child_key, prefix_pages); + return sr.prefix; + } + current = current->Parent(); + } + return nullptr; +} + WalkResult RadixTree::WalkDownUtilMismatch(token_slice aligned_tokens, TreeNode::timestamp_t access_time, TreeNode* start_node) { TreeNode* current = (start_node != nullptr) ? start_node : root_.get(); @@ -131,6 +167,10 @@ WalkResult RadixTree::WalkDownUtilMismatch(token_slice aligned_tokens, TreeNode: break; } if (matched_num_pages != static_cast(child->Tokens().size() / page_size_)) { + // Refuse to split a snapshot-bearing node; borrowed ids rely on it. + if (child->HasPagedCacheSnapshot()) { + break; + } SplitResult split = splitChild(current, walk_key_cache, matched_num_pages); child = split.prefix; } diff --git a/tokenspeed-scheduler/csrc/resource/radix_tree/radix_tree.h b/tokenspeed-scheduler/csrc/resource/radix_tree/radix_tree.h index 9be0d0033..3514bb4e7 100644 --- a/tokenspeed-scheduler/csrc/resource/radix_tree/radix_tree.h +++ b/tokenspeed-scheduler/csrc/resource/radix_tree/radix_tree.h @@ -58,6 +58,11 @@ class RadixTree { TreeNode* PruneEmptyByNode(TreeNode* node); + // Find or create the node at depth_in_tokens on descendant's root path. + // depth_in_tokens must be page-aligned and within descendant's depth. + // Returns nullptr for root depth or invalid inputs. 
+ TreeNode* SplitAt(TreeNode* descendant, std::int32_t depth_in_tokens); + private: SplitResult splitChild(TreeNode* parent, const token_vec_t& child_key, std::size_t prefix_pages); diff --git a/tokenspeed-scheduler/csrc/resource/radix_tree/tree_node.cpp b/tokenspeed-scheduler/csrc/resource/radix_tree/tree_node.cpp index 0805e711e..7a9e25dfa 100644 --- a/tokenspeed-scheduler/csrc/resource/radix_tree/tree_node.cpp +++ b/tokenspeed-scheduler/csrc/resource/radix_tree/tree_node.cpp @@ -92,7 +92,20 @@ void TreeNode::SplitSelfInto(TreeNode& prefix, std::size_t prefix_pages, std::in std::int32_t ref_count = host_resource_->RefCount(); prefix.AttachResource(std::make_unique(host_resource_->SplitFirst(prefix_pages), ref_count)); } - // Mamba resources stay in suffix node, no special handling needed + // Mamba stays in suffix. + // Invariant: snapshot-bearing nodes are never split (RadixTree refuses). + // A split here would dangle borrowed ids in active requests. + _assert(paged_cache_snapshot_ == nullptr, + "TreeNode::SplitSelfInto called on a node with an attached paged-cache snapshot; " + "splitting would invalidate borrowed page id references in active requests"); +} + +void TreeNode::AttachPagedCacheSnapshot(std::unique_ptr snapshot) { + paged_cache_snapshot_ = std::move(snapshot); +} + +std::unique_ptr TreeNode::DetachPagedCacheSnapshot() { + return std::move(paged_cache_snapshot_); } void TreeNode::SetPersisted(bool persisted) { diff --git a/tokenspeed-scheduler/csrc/resource/radix_tree/tree_node.h b/tokenspeed-scheduler/csrc/resource/radix_tree/tree_node.h index 9da8095fa..e4f91203d 100644 --- a/tokenspeed-scheduler/csrc/resource/radix_tree/tree_node.h +++ b/tokenspeed-scheduler/csrc/resource/radix_tree/tree_node.h @@ -32,6 +32,7 @@ #include #include "resource/radix_tree/mamba_slot.h" +#include "resource/radix_tree/paged_cache_snapshot.h" #include "resource/radix_tree/tree_resource.h" #include "resource/types.h" @@ -106,6 +107,11 @@ class TreeNode { void 
AttachMamba(std::unique_ptr slot) { mamba_slot_ = std::move(slot); } std::unique_ptr DetachMamba() { return std::move(mamba_slot_); } + // Paged-cache snapshot adjunct. Completeness is now per-family on the + // snapshot itself (see `PagedCacheSnapshot::IsCompleteFor`). + bool HasPagedCacheSnapshot() const { return paged_cache_snapshot_ != nullptr; } + const PagedCacheSnapshot* GetPagedCacheSnapshot() const { return paged_cache_snapshot_.get(); } + std::optional CacheOpId() const; void SetPersisted(bool persisted = true); @@ -117,6 +123,14 @@ class TreeNode { void SplitSelfInto(TreeNode& prefix, std::size_t prefix_pages, std::int32_t page_size); +private: + // Private so all attach/detach routes through HybridPrefixCache and keeps + // its `paged_cache_snapshot_nodes_` membership set in sync. + friend class HybridPrefixCache; + void AttachPagedCacheSnapshot(std::unique_ptr snapshot); + std::unique_ptr DetachPagedCacheSnapshot(); + PagedCacheSnapshot* GetPagedCacheSnapshotMut() { return paged_cache_snapshot_.get(); } + private: TreeNode* parent_{}; children_map_t children_{}; @@ -129,6 +143,7 @@ class TreeNode { std::unique_ptr device_resource_{}; std::unique_ptr host_resource_{}; std::unique_ptr mamba_slot_{}; + std::unique_ptr paged_cache_snapshot_{}; }; template diff --git a/tokenspeed-scheduler/csrc/resource/types.h b/tokenspeed-scheduler/csrc/resource/types.h index bc874933f..aca10a7b7 100644 --- a/tokenspeed-scheduler/csrc/resource/types.h +++ b/tokenspeed-scheduler/csrc/resource/types.h @@ -22,6 +22,7 @@ #include #include +#include #include #include #include @@ -69,6 +70,22 @@ struct MatchResult { // Mamba extension (default: no mamba cache, -1 = inactive) std::int32_t mamba_branching_seqlen{-1}; std::int32_t mamba_cow_src_index{-1}; + + // Paged-cache adjunct hit. Null last_node or zero prefix means no hit. + // When hit, device/host last_node also sit at or before prefix_len_tokens. 
+ // base_logical_page is 0 for full-history groups; > 0 for sliding windows. + // TODO(match-result-pagedcache-zero-copy): return snapshot+depth and walk on demand. + struct PagedCache { + // Phase 1 hit kinds; Phase 2 will add kReplay variants. + enum class RestoreKind { kSnapshotComplete }; + TreeNode* last_node{nullptr}; + std::int32_t prefix_len_tokens{0}; + std::map> per_group_page_ids; + std::map per_group_base_logical_page; + RestoreKind restore_kind{RestoreKind::kSnapshotComplete}; + // Phase 2 placeholder; Phase 1 always 0. + std::int32_t replay_start_tokens{0}; + } paged_cache; }; struct InsertResult { diff --git a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp index c4f3d4a73..192e5d251 100644 --- a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp @@ -53,15 +53,6 @@ #include "scheduler/types.h" #include "utils.h" -namespace { - -std::int32_t CeilDivPositive(std::int32_t numer, std::int32_t denom) { - if (numer <= 0) return 0; - return (numer + denom - 1) / denom; -} - -} // namespace - namespace tokenspeed { std::optional Scheduler::schedulePrefillFirstChunk( @@ -87,7 +78,7 @@ std::optional Scheduler::schedulePrefillFir } std::int32_t tokens_this_round = std::min(remaining, unscheduled); - if (hybrid_prefix_cache_ && match_result.mamba_branching_seqlen == -1) { + if (hybrid_prefix_cache_ && hybrid_prefix_cache_->HasMambaAdjunct() && match_result.mamba_branching_seqlen == -1) { const std::int32_t aligned = hybrid_prefix_cache_->AlignMambaCacheSeqlen(tokens_this_round); if (aligned > 0) { match_result.mamba_branching_seqlen = aligned; @@ -99,22 +90,22 @@ std::optional Scheduler::schedulePrefillFir std::unique_ptr temp_lock = std::make_unique(match_result.device.last_node); - // Eviction happens Here: evicts unlocked prefix-cache nodes to free device_pages_needed pages. 
+ // Evict unlocked prefix-cache nodes before allocating request-local pages. if (!(kv_prefix_cache_.EnsureCapacityByEvict(device_pages_needed))) { return {}; } - if (hybrid_prefix_cache_ && !hybrid_prefix_cache_->EnsureMambaCapacityByEvict(2)) { + if (hybrid_prefix_cache_ && hybrid_prefix_cache_->HasMambaAdjunct() && + !hybrid_prefix_cache_->EnsureMambaCapacityByEvict(2)) { return {}; } const std::int32_t first_pos = request->PrefillSize() - unscheduled; const std::int32_t target = first_pos + tokens_this_round; - auto admission = checkPagedCacheGroupAdmission(request->Id(), first_pos, target, simulated_free); - if (!admission.ok) { + if (hybrid_prefix_cache_ && + !hybrid_prefix_cache_->AdmitChunk(request->Id(), first_pos, target, simulated_free, match_result.paged_cache)) { return {}; } - applyPagedCacheGroupAdmissionDebit(simulated_free, admission); return fsm::SchedulePrefillFirstChunkEvent{ tokens_this_round, @@ -143,17 +134,16 @@ std::optional Scheduler::schedulePrefill( return {}; } - if (hybrid_prefix_cache_ && !hybrid_prefix_cache_->EnsureMambaCapacityByEvict(1)) { + if (hybrid_prefix_cache_ && hybrid_prefix_cache_->HasMambaAdjunct() && + !hybrid_prefix_cache_->EnsureMambaCapacityByEvict(1)) { return {}; } const std::int32_t first_pos = request->PrefillSize() - unscheduled; const std::int32_t target = first_pos + tokens_this_round; - auto admission = checkPagedCacheGroupAdmission(request->Id(), first_pos, target, simulated_free); - if (!admission.ok) { + if (hybrid_prefix_cache_ && !hybrid_prefix_cache_->AdmitChunk(request->Id(), first_pos, target, simulated_free)) { return {}; } - applyPagedCacheGroupAdmissionDebit(simulated_free, admission); return fsm::SchedulePrefillEvent{tokens_this_round, reserve_num_tokens_in_next_schedule_event, hybrid_prefix_cache_ ? 
&*hybrid_prefix_cache_ : nullptr}; @@ -171,11 +161,9 @@ std::optional Scheduler::scheduleDecode(Request* reque const std::int32_t first_pos = request->TokenSize(); const std::int32_t target = first_pos + config_.decode_input_tokens; - auto admission = checkPagedCacheGroupAdmission(request->Id(), first_pos, target, simulated_free); - if (!admission.ok) { + if (hybrid_prefix_cache_ && !hybrid_prefix_cache_->AdmitChunk(request->Id(), first_pos, target, simulated_free)) { return {}; } - applyPagedCacheGroupAdmissionDebit(simulated_free, admission); return fsm::ScheduleDecodeEvent{config_.decode_input_tokens, hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr}; @@ -185,7 +173,10 @@ std::optional Scheduler::scheduleDecodeFr Request* request, std::map& simulated_free) { if (req_pool_allocator_.AvailableSlots() == 0) return {}; - MatchResult match_result = kv_prefix_cache_.Match(request->GetFullPagedTokens(true), MatchIntent::StateRecovery); + MatchResult match_result = + hybrid_prefix_cache_ + ? 
hybrid_prefix_cache_->Match(request->GetFullPagedTokens(true), MatchIntent::StateRecovery) + : kv_prefix_cache_.Match(request->GetFullPagedTokens(true), MatchIntent::StateRecovery); std::vector loadback_diff = match_result.NodesWithout(); TreeNode* mamba_recovery_node = nullptr; if (hybrid_prefix_cache_ && mamba_allocator_) { @@ -224,42 +215,11 @@ std::optional Scheduler::scheduleDecodeFr } } - std::map released_back; - auto req_it = request_paged_cache_tables_.find(request->Id()); - if (req_it != request_paged_cache_tables_.end()) { - for (const auto& [gid, table] : req_it->second) { - released_back[gid] = table.ActivePagesCount(); - } - } - auto simulated_after_release = simulated_free; - for (const auto& [gid, n] : released_back) { - simulated_after_release[gid] += n; - } - - PagedCacheGroupAdmission admission; const std::int32_t target = request->TokenSize(); - if (!paged_cache_allocators_.empty() && target >= 0) { - for (const auto& [gid, allocator] : paged_cache_allocators_) { - const auto& cfg = allocator->Config(); - if (cfg.entry_stride_tokens <= 0 || cfg.rows_per_page <= 0) continue; - const std::int32_t entries = CeilDivPositive(target, cfg.entry_stride_tokens); - const std::int32_t required = (entries + cfg.rows_per_page - 1) / cfg.rows_per_page; - const std::int32_t free = - simulated_after_release.count(gid) ? 
simulated_after_release.at(gid) : allocator->AvailablePages(); - admission.releasable_pages[gid] = 0; - admission.new_pages_needed[gid] = required; - if (free < required) { - admission.ok = false; - } - } - } - if (!admission.ok) { + if (hybrid_prefix_cache_ && !hybrid_prefix_cache_->AdmitChunkFromRetracted(request->Id(), target, simulated_free, + match_result.paged_cache)) { return {}; } - for (const auto& [gid, n] : released_back) { - simulated_free[gid] += n; - } - applyPagedCacheGroupAdmissionDebit(simulated_free, admission); return fsm::ScheduleDecodeFromRetractedEvent{ config_.decode_input_tokens, @@ -322,20 +282,16 @@ LoadBackOperation GenerateLoadBackOp(const std::vector& diff, cache_o std::optional Scheduler::applyEventAndGenerateOp(Request* request, fsm::ScheduleRetractEvent event) { - // ScheduleRetractEvent::operator() already builds the (device_page, host_page) pairs - // inside the state transition (consistent with FinishEvent→Draining path). - // We just apply the event and read back the pre-computed pairs. + // Event applier builds the (device_page, host_page) pairs. request->Apply(std::move(event)); const auto& pages_to_transfer = request->GetPagesToTransfer(); if (pages_to_transfer.empty()) { - // device.matched == host.matched: no device→host copy needed. - // Fire WriteBackDoneEvent immediately so the request transitions - // Retracting → Retracted without registering a dangling op_id. + // No copy needed; advance Retracting to Retracted without an op_id. request->Apply(fsm::WriteBackDoneEvent{}); return std::nullopt; } - // Register in cache_op_tracker_ so WriteBackDone can route back to the request. + // Register op_id so WriteBackDone can route back. 
cache_op_id op_id = kv_prefix_cache_.AllocateCacheOpId(); CacheOpSpec spec; spec.request_id = request->Id(); @@ -395,17 +351,31 @@ static PrefillOperation applyPrefillEvent(Request* request, Event event) { PrefillOperation Scheduler::applyEventAndGenerateOp(Request* request, fsm::SchedulePrefillFirstChunkEvent event) { auto match = event.GetMatchResult(); auto op = applyPrefillEvent(request, std::move(event)); - op.mamba_cow_src_idx = match.mamba_cow_src_index; - op.mamba_branching_seqlen = match.mamba_branching_seqlen; - acquirePagedCachePagesForRequest(op.request_id, op.extend_prefix_len, op.extend_prefix_len + op.input_length); - populatePagedCachePagesForOp(op); + // Mamba fields only when adjunct is active. + if (hybrid_prefix_cache_ && hybrid_prefix_cache_->HasMambaAdjunct()) { + op.mamba_cow_src_idx = match.mamba_cow_src_index; + op.mamba_branching_seqlen = match.mamba_branching_seqlen; + } + // Order: attach, acquire, populate. Attach before acquire so prior-chunk + // tail pages commit into snapshots before Acquire's ReleaseSkipped frees them. + if (hybrid_prefix_cache_) { + hybrid_prefix_cache_->CommitChunk(op.request_id, const_cast(request->GetDeviceNode())); + hybrid_prefix_cache_->AcquireForRequest(op.request_id, op.extend_prefix_len, + op.extend_prefix_len + op.input_length, match.paged_cache); + hybrid_prefix_cache_->PopulateOp(op); + } return op; } PrefillOperation Scheduler::applyEventAndGenerateOp(Request* request, fsm::SchedulePrefillEvent event) { auto op = applyPrefillEvent(request, std::move(event)); - acquirePagedCachePagesForRequest(op.request_id, op.extend_prefix_len, op.extend_prefix_len + op.input_length); - populatePagedCachePagesForOp(op); + // Order: attach, acquire, populate (see SchedulePrefillFirstChunkEvent). 
+ if (hybrid_prefix_cache_) { + hybrid_prefix_cache_->CommitChunk(op.request_id, const_cast(request->GetDeviceNode())); + hybrid_prefix_cache_->AcquireForRequest(op.request_id, op.extend_prefix_len, + op.extend_prefix_len + op.input_length); + hybrid_prefix_cache_->PopulateOp(op); + } return op; } @@ -443,18 +413,26 @@ DecodeOperation Scheduler::applyEventAndGenerateOp(Request* request, fsm::Schedu const bool need_bootstrap_token = request->Is() && config_.role == Role::kD; std::int32_t bootstrap_token = need_bootstrap_token ? request->GetLastToken() : -1; const std::int32_t first_pos = request->TokenSize(); + const bool came_from_prefill_done = request->Is(); auto op = applyDecodeEvent(request, std::move(event), config_.decode_input_tokens); if (need_bootstrap_token) { op.decode_input_id = bootstrap_token; } - acquirePagedCachePagesForRequest(op.request_id, first_pos, first_pos + op.input_length); - populatePagedCachePagesForOp(op); + // Order: attach, acquire, populate. + if (hybrid_prefix_cache_) { + if (came_from_prefill_done) { + hybrid_prefix_cache_->CommitChunk(op.request_id, const_cast(request->GetDeviceNode())); + } + hybrid_prefix_cache_->AcquireForRequest(op.request_id, first_pos, first_pos + op.input_length); + hybrid_prefix_cache_->PopulateOp(op); + } return op; } DecodeOperation Scheduler::applyEventAndGenerateOp(Request* request, fsm::ScheduleDecodeFromRetractedEvent event) { const std::int32_t mamba_cow_src_index = event.GetMatchResult().mamba_cow_src_index; + auto paged_cache_hit = event.GetMatchResult().paged_cache; request->Apply(std::move(event)); if (!request->Is()) { throw std::logic_error( @@ -483,9 +461,11 @@ DecodeOperation Scheduler::applyEventAndGenerateOp(Request* request, fsm::Schedu } } - releasePagedCachePagesForRequest(op.request_id); - acquirePagedCachePagesForRequest(op.request_id, 0, request->TokenSize()); - populatePagedCachePagesForOp(op); + if (hybrid_prefix_cache_) { + hybrid_prefix_cache_->ReleaseRequest(op.request_id); + 
hybrid_prefix_cache_->AcquireForRequest(op.request_id, 0, request->TokenSize(), paged_cache_hit); + hybrid_prefix_cache_->PopulateOp(op); + } return op; } @@ -514,7 +494,8 @@ Scheduler::newForwardOperation(std::vector candidates) { [](const ForwardOperation& op) { return std::holds_alternative(op); }); }; std::vector loadback_ops; - auto simulated_free = initialPagedCacheGroupSimulatedFree(); + auto simulated_free = + hybrid_prefix_cache_ ? hybrid_prefix_cache_->InitialSimulatedFree() : std::map{}; for (Request* request : candidates) { if (token_budget <= 0 || config_.max_batch_size == ops.size()) break; diff --git a/tokenspeed-scheduler/csrc/scheduler/scheduler.cpp b/tokenspeed-scheduler/csrc/scheduler/scheduler.cpp index 7fcd4de8a..bf4e6cd68 100644 --- a/tokenspeed-scheduler/csrc/scheduler/scheduler.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/scheduler.cpp @@ -20,12 +20,14 @@ #include "scheduler/scheduler.h" +#include #include #include #include #include #include #include +#include #include #include #include @@ -41,6 +43,7 @@ #include "fsm/forward_events.h" #include "fsm/forward_states.h" #include "resource/kv_prefix_cache/kv_prefix_cache.h" +#include "resource/radix_tree/radix_tree.h" #include "resource/radix_tree/tree_node.h" #include "scheduler/execution_event.h" #include "scheduler/operations/cache.h" @@ -49,15 +52,6 @@ #include "scheduler/request_spec.h" #include "scheduler/types.h" -namespace { - -std::int32_t CeilDivPositive(std::int32_t numer, std::int32_t denom) { - if (numer <= 0) return 0; - return (numer + denom - 1) / denom; -} - -} // namespace - namespace tokenspeed { Scheduler::Scheduler(SchedulerConfig config) @@ -76,23 +70,57 @@ Scheduler::Scheduler(SchedulerConfig config) if (config_.enable_kv_cache_events) { kv_prefix_cache_.SetKvEventSink([this](KvCacheEvent event) { kv_events_.push_back(std::move(event)); }); } - if (config_.enable_mamba && config_.mamba_pool_total_chunks > 0) { + const bool has_mamba_pool = config_.enable_mamba && 
config_.mamba_pool_total_chunks > 0; + if (has_mamba_pool) { mamba_allocator_.emplace(config_.mamba_pool_total_chunks); - if (config_.role != Role::kD) { - hybrid_prefix_cache_.emplace(kv_prefix_cache_, &*mamba_allocator_, config_.mamba_cache_chunk_size); - kv_prefix_cache_.GetDeviceManager().SetEvictionCallback( - [this](TreeNode* node) { hybrid_prefix_cache_->OnKVEvict(node); }); - } } - for (const auto& cfg : config_.paged_cache_groups) { - PagedCacheGroupConfig copy = cfg; - copy.Validate(); - std::string gid = copy.group_id; - auto [_, inserted] = - paged_cache_allocators_.emplace(gid, std::make_unique(std::move(copy))); - if (!inserted) { - throw std::invalid_argument("Scheduler: duplicate paged cache group_id: " + gid); + // Construct HybridPrefixCache when any adjunct/paged-cache feature is configured. + // Role::kD skips Mamba but still participates in paged-cache transport. + const bool has_mamba_adjunct = has_mamba_pool && config_.role != Role::kD; + const bool has_prefix_cache_adjunct = config_.prefix_cache_adjunct.has_value(); + const bool has_paged_cache_groups = !config_.paged_cache_groups.empty(); + if (has_mamba_adjunct || has_prefix_cache_adjunct || has_paged_cache_groups) { + MambaChunkAllocator* mamba_ptr = has_mamba_adjunct ? 
&*mamba_allocator_ : nullptr; + hybrid_prefix_cache_.emplace(kv_prefix_cache_, mamba_ptr, config_.mamba_cache_chunk_size); + kv_prefix_cache_.GetDeviceManager().SetEvictionCallback( + [this](TreeNode* node) { hybrid_prefix_cache_->OnKVEvict(node); }); + + for (const auto& cfg : config_.paged_cache_groups) { + PagedCacheGroupConfig copy = cfg; + copy.Validate(); + hybrid_prefix_cache_->RegisterPagedCacheGroup(std::make_unique(std::move(copy))); + } + + if (has_prefix_cache_adjunct) { + const auto& spec = *config_.prefix_cache_adjunct; + if (spec.required_groups.empty()) { + throw std::invalid_argument("Scheduler: prefix_cache_adjunct.required_groups must be non-empty"); + } + // HybridPrefixCache derives history alignment from the registered + // group configs; we still build the sliding-window map here. + std::unordered_map sliding_window_per_group; + for (const auto& gid : spec.required_groups) { + const PagedCacheGroupConfig* cfg = nullptr; + for (const auto& g : config_.paged_cache_groups) { + if (g.group_id == gid) { + cfg = &g; + break; + } + } + if (cfg == nullptr) { + throw std::invalid_argument("Scheduler: prefix_cache_adjunct required group_id '" + gid + + "' not found in paged_cache_groups"); + } + if (cfg->retention == PagedCacheGroupConfig::Retention::SlidingWindow) { + if (!cfg->sliding_window_tokens.has_value() || *cfg->sliding_window_tokens <= 0) { + throw std::invalid_argument("Scheduler: prefix_cache_adjunct sliding group '" + gid + + "' must declare positive sliding_window_tokens"); + } + sliding_window_per_group.emplace(gid, *cfg->sliding_window_tokens); + } + } + hybrid_prefix_cache_->EnablePagedCacheAdjunct(spec.required_groups, std::move(sliding_window_per_group)); } } } @@ -190,196 +218,45 @@ std::size_t Scheduler::ActiveKvPages() const { } std::vector Scheduler::PagedCacheGroupIds() const { - std::vector ids; - ids.reserve(paged_cache_allocators_.size()); - for (const auto& [gid, _] : paged_cache_allocators_) { - ids.push_back(gid); - } - 
return ids; + if (!hybrid_prefix_cache_) return {}; + return hybrid_prefix_cache_->PagedCacheGroupIds(); } std::int32_t Scheduler::PagedCacheGroupTotalPages(const std::string& group_id) const { - auto it = paged_cache_allocators_.find(group_id); - if (it == paged_cache_allocators_.end()) { + if (!hybrid_prefix_cache_) { throw std::out_of_range("Scheduler::PagedCacheGroupTotalPages: group_id not configured"); } - return it->second->TotalPages(); + return hybrid_prefix_cache_->PagedCacheGroupTotalPages(group_id); } std::int32_t Scheduler::PagedCacheGroupAvailablePages(const std::string& group_id) const { - auto it = paged_cache_allocators_.find(group_id); - if (it == paged_cache_allocators_.end()) { + if (!hybrid_prefix_cache_) { throw std::out_of_range("Scheduler::PagedCacheGroupAvailablePages: group_id not configured"); } - return it->second->AvailablePages(); + return hybrid_prefix_cache_->PagedCacheGroupAvailablePages(group_id); } std::int64_t Scheduler::PagedCacheGroupFailedAllocCount(const std::string& group_id) const { - auto it = paged_cache_allocators_.find(group_id); - if (it == paged_cache_allocators_.end()) { + if (!hybrid_prefix_cache_) { throw std::out_of_range("Scheduler::PagedCacheGroupFailedAllocCount: group_id not configured"); } - return it->second->FailedAllocCount(); + return hybrid_prefix_cache_->PagedCacheGroupFailedAllocCount(group_id); } std::vector Scheduler::GetRequestPagedCachePageIds(const std::string& request_id, const std::string& group_id) const { - if (paged_cache_allocators_.find(group_id) == paged_cache_allocators_.end()) { + if (!hybrid_prefix_cache_) { throw std::out_of_range("Scheduler::GetRequestPagedCachePageIds: group_id not configured"); } - auto req_it = request_paged_cache_tables_.find(request_id); - if (req_it == request_paged_cache_tables_.end()) { - return {}; - } - auto group_it = req_it->second.find(group_id); - if (group_it == req_it->second.end()) { - return {}; - } - return group_it->second.PageIds(); + return 
hybrid_prefix_cache_->GetRequestPagedCachePageIds(request_id, group_id); } std::int32_t Scheduler::GetRequestPagedCacheBaseLogicalPage(const std::string& request_id, const std::string& group_id) const { - if (paged_cache_allocators_.find(group_id) == paged_cache_allocators_.end()) { + if (!hybrid_prefix_cache_) { throw std::out_of_range("Scheduler::GetRequestPagedCacheBaseLogicalPage: group_id not configured"); } - auto req_it = request_paged_cache_tables_.find(request_id); - if (req_it == request_paged_cache_tables_.end()) { - return 0; - } - auto group_it = req_it->second.find(group_id); - if (group_it == req_it->second.end()) { - return 0; - } - return group_it->second.BaseLogicalPage(); -} - -void Scheduler::acquirePagedCachePagesForRequest(const std::string& request_id, std::int32_t first_raw_position_of_op, - std::int32_t target_raw_tokens_exclusive) { - if (paged_cache_allocators_.empty()) return; - auto& tables = request_paged_cache_tables_[request_id]; - for (const auto& [group_id, allocator] : paged_cache_allocators_) { - auto it = tables.find(group_id); - if (it == tables.end()) { - it = tables.emplace(group_id, PagedCacheGroupTable(allocator.get())).first; - } - const auto& cfg = allocator->Config(); - if (cfg.retention == PagedCacheGroupConfig::Retention::SlidingWindow && cfg.sliding_window_tokens.has_value()) { - const std::int32_t lower = std::max(0, first_raw_position_of_op - *cfg.sliding_window_tokens + 1); - it->second.ReleaseSkipped(lower); - } - it->second.Acquire(target_raw_tokens_exclusive); - } -} - -PagedCacheGroupAdmission Scheduler::checkPagedCacheGroupAdmission( - const std::string& request_id, std::int32_t first_raw_position_of_op, std::int32_t target_raw_tokens_exclusive, - const std::map& simulated_free) const { - PagedCacheGroupAdmission result; - if (paged_cache_allocators_.empty() || target_raw_tokens_exclusive < 0) { - return result; - } - - auto req_it = request_paged_cache_tables_.find(request_id); - for (const auto& [gid, 
allocator] : paged_cache_allocators_) { - const auto& cfg = allocator->Config(); - const std::int32_t raw_per_page = cfg.RawTokensPerPage(); - if (cfg.entry_stride_tokens <= 0 || cfg.rows_per_page <= 0 || raw_per_page <= 0) { - continue; - } - - const std::int32_t entries = CeilDivPositive(target_raw_tokens_exclusive, cfg.entry_stride_tokens); - const std::int32_t required = (entries + cfg.rows_per_page - 1) / cfg.rows_per_page; - - std::int32_t current_size = 0; - std::int32_t current_active = 0; - std::int32_t already_released = 0; - if (req_it != request_paged_cache_tables_.end()) { - auto t_it = req_it->second.find(gid); - if (t_it != req_it->second.end()) { - current_size = t_it->second.Size(); - current_active = t_it->second.ActivePagesCount(); - already_released = t_it->second.ReleasedPagesCount(); - } - } - - std::int32_t releasable = 0; - if (cfg.retention == PagedCacheGroupConfig::Retention::SlidingWindow && cfg.sliding_window_tokens.has_value()) { - const std::int32_t lower = std::max(0, first_raw_position_of_op - *cfg.sliding_window_tokens + 1); - const std::int32_t target_releases = lower / raw_per_page; - releasable = std::max(0, target_releases - already_released); - releasable = std::min(releasable, current_active); - } - - // Absolute coverage = already_released (base) + live size. 
- const std::int32_t absolute_have = already_released + current_size; - const std::int32_t new_pages = std::max(0, required - absolute_have); - std::int32_t free = allocator->AvailablePages(); - auto sf_it = simulated_free.find(gid); - if (sf_it != simulated_free.end()) { - free = sf_it->second; - } - - result.releasable_pages[gid] = releasable; - result.new_pages_needed[gid] = new_pages; - if (free + releasable < new_pages) { - result.ok = false; - } - } - return result; -} - -std::map Scheduler::initialPagedCacheGroupSimulatedFree() const { - std::map out; - for (const auto& [gid, allocator] : paged_cache_allocators_) { - out[gid] = allocator->AvailablePages(); - } - return out; -} - -void Scheduler::applyPagedCacheGroupAdmissionDebit(std::map& simulated_free, - const PagedCacheGroupAdmission& admission) { - for (const auto& [gid, releasable] : admission.releasable_pages) { - simulated_free[gid] += releasable; - } - for (const auto& [gid, new_pages] : admission.new_pages_needed) { - simulated_free[gid] -= new_pages; - } -} - -void Scheduler::releasePagedCachePagesForRequest(const std::string& request_id) { - auto it = request_paged_cache_tables_.find(request_id); - if (it == request_paged_cache_tables_.end()) return; - for (auto& [_, table] : it->second) { - table.ReleaseAll(); - } - request_paged_cache_tables_.erase(it); -} - -// Snapshot the per-group page ids the request currently owns into op. -// For sliding groups page_ids are compact (live-only) and a base -// logical-page offset is emitted alongside; full-history groups omit the -// offset (implicit 0). 
-void Scheduler::populatePagedCachePagesForOp(ForwardOperationBase& op_base) const { - if (paged_cache_allocators_.empty()) { - return; - } - auto req_it = request_paged_cache_tables_.find(op_base.request_id); - for (const auto& [gid, allocator] : paged_cache_allocators_) { - std::vector pages; - std::int32_t base_offset = 0; - if (req_it != request_paged_cache_tables_.end()) { - auto table_it = req_it->second.find(gid); - if (table_it != req_it->second.end()) { - pages = table_it->second.PageIds(); - base_offset = table_it->second.BaseLogicalPage(); - } - } - op_base.paged_cache_pages[gid] = std::move(pages); - if (allocator->Config().retention == PagedCacheGroupConfig::Retention::SlidingWindow) { - op_base.paged_cache_page_base_offsets[gid] = base_offset; - } - } + return hybrid_prefix_cache_->GetRequestPagedCacheBaseLogicalPage(request_id, group_id); } std::int32_t Scheduler::GetRequestTokenSize(const std::string& id) const { @@ -421,9 +298,11 @@ ExecutionPlan Scheduler::NextExecutionPlan() { std::vector write_back_ops; write_back_ops = std::move(newWriteBackOperation(requests_)); - for (const auto& [id, req] : requests_) { - if (req->Is()) { - releasePagedCachePagesForRequest(id); + if (hybrid_prefix_cache_) { + for (const auto& [id, req] : requests_) { + if (req->Is()) { + hybrid_prefix_cache_->ReleaseRequest(id); + } } } std::erase_if(requests_, [](const auto& req) { return req.second->template Is(); }); diff --git a/tokenspeed-scheduler/csrc/scheduler/scheduler.h b/tokenspeed-scheduler/csrc/scheduler/scheduler.h index e64a12c11..0e5407c97 100644 --- a/tokenspeed-scheduler/csrc/scheduler/scheduler.h +++ b/tokenspeed-scheduler/csrc/scheduler/scheduler.h @@ -48,12 +48,6 @@ #include "fsm/pd_events.h" namespace tokenspeed { -struct PagedCacheGroupAdmission { - bool ok{true}; - std::map releasable_pages{}; - std::map new_pages_needed{}; -}; - class Scheduler { public: explicit Scheduler(SchedulerConfig config); @@ -79,9 +73,7 @@ class Scheduler { std::int64_t 
PagedCacheGroupFailedAllocCount(const std::string& group_id) const; std::vector GetRequestPagedCachePageIds(const std::string& request_id, const std::string& group_id) const; - // Compact-view base logical-page offset (column 0 of PageIds() = absolute - // logical page returned here). 0 for full-history groups and unseen - // request/group pairs. Tests use this to address compact tables. + // Compact-view base logical-page offset; 0 for full-history / unseen. std::int32_t GetRequestPagedCacheBaseLogicalPage(const std::string& request_id, const std::string& group_id) const; private: @@ -137,18 +129,6 @@ class Scheduler { private: SchedulerConfig config_; -private: - void acquirePagedCachePagesForRequest(const std::string& request_id, std::int32_t first_raw_position_of_op, - std::int32_t target_raw_tokens_exclusive); - PagedCacheGroupAdmission checkPagedCacheGroupAdmission( - const std::string& request_id, std::int32_t first_raw_position_of_op, std::int32_t target_raw_tokens_exclusive, - const std::map& simulated_free) const; - std::map initialPagedCacheGroupSimulatedFree() const; - static void applyPagedCacheGroupAdmissionDebit(std::map& simulated_free, - const PagedCacheGroupAdmission& admission); - void releasePagedCachePagesForRequest(const std::string& request_id); - void populatePagedCachePagesForOp(ForwardOperationBase& op_base) const; - private: PageAllocator device_allocator_; PageAllocator host_allocator_; @@ -156,8 +136,6 @@ class Scheduler { KVPrefixCache kv_prefix_cache_; ReqPoolAllocator req_pool_allocator_; std::optional hybrid_prefix_cache_{}; - std::map> paged_cache_allocators_; - std::unordered_map> request_paged_cache_tables_; private: std::unordered_map> requests_; diff --git a/tokenspeed-scheduler/csrc/scheduler/types.h b/tokenspeed-scheduler/csrc/scheduler/types.h index e48d042bf..9d7db855d 100644 --- a/tokenspeed-scheduler/csrc/scheduler/types.h +++ b/tokenspeed-scheduler/csrc/scheduler/types.h @@ -21,6 +21,7 @@ #pragma once #include +#include 
#include #include #include @@ -41,6 +42,8 @@ enum class DisaggregationMode { kPrefill, kDecode, }; +// `PagedCacheGroupFamily` and `StateRestorePolicy` are defined in +// resource/allocator/paged_cache_group.h (transitively included above). template class NodeRef; @@ -64,6 +67,12 @@ struct SchedulerStats { std::int64_t active_requests = 0; }; +// Opt-in spec for the paged-cache prefix-cache adjunct. Unset means paged-cache +// groups are transport-only (no snapshot chain, no prefix-cache reuse). +struct PrefixCacheAdjunctSpec { + std::vector required_groups{}; +}; + struct SchedulerConfig { std::int32_t page_size{}; struct { @@ -76,6 +85,9 @@ struct SchedulerConfig { std::vector paged_cache_groups{}; + // Unset means paged-cache groups are transport-only. + std::optional prefix_cache_adjunct{}; + std::int32_t max_scheduled_tokens{}; std::int32_t max_batch_size{}; std::int32_t decode_input_tokens{1}; diff --git a/tokenspeed-scheduler/python/tokenspeed_scheduler/__init__.py b/tokenspeed-scheduler/python/tokenspeed_scheduler/__init__.py index dc87ada88..f2070be85 100644 --- a/tokenspeed-scheduler/python/tokenspeed_scheduler/__init__.py +++ b/tokenspeed-scheduler/python/tokenspeed_scheduler/__init__.py @@ -27,8 +27,10 @@ ExecutionPlan, PagedCacheGroupAllocator, PagedCacheGroupConfig, + PagedCacheGroupFamily, PagedCacheGroupTable, PagedCacheRetention, + PrefixCacheAdjunctSpec, RequestSpec, Scheduler, SchedulerConfig, @@ -71,7 +73,9 @@ def _flat_forward_op_repr(self): "PagedCacheRetention", "PagedCacheGroupConfig", "PagedCacheGroupAllocator", + "PagedCacheGroupFamily", "PagedCacheGroupTable", + "PrefixCacheAdjunctSpec", # Execution plan & operations "ExecutionPlan", "Forward", diff --git a/tokenspeed-scheduler/tests/cpp/hybrid_prefix_cache_test_peer.h b/tokenspeed-scheduler/tests/cpp/hybrid_prefix_cache_test_peer.h new file mode 100644 index 000000000..2fc5ba956 --- /dev/null +++ b/tokenspeed-scheduler/tests/cpp/hybrid_prefix_cache_test_peer.h @@ -0,0 +1,29 @@ +// 
Copyright (c) 2026 LightSeek Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + +#pragma once + +// Test-only friend of HybridPrefixCache, intended to expose hooks for prune paths +// that are non-trivial to reach via AdmitChunk; currently an empty placeholder. 
+ +#include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h" +#include "resource/radix_tree/tree_node.h" + +namespace tokenspeed { + +class HybridPrefixCacheTestPeer {}; + +} // namespace tokenspeed diff --git a/tokenspeed-scheduler/tests/cpp/paged_cache_test_fixture.h b/tokenspeed-scheduler/tests/cpp/paged_cache_test_fixture.h new file mode 100644 index 000000000..981cfcd6b --- /dev/null +++ b/tokenspeed-scheduler/tests/cpp/paged_cache_test_fixture.h @@ -0,0 +1,188 @@ +// Copyright (c) 2026 LightSeek Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + +#pragma once + +// Shared fixture for HybridPrefixCache + two-group paged-cache tests. 
+ +#include + +#include +#include +#include +#include +#include + +#include "resource/allocator/owned_pages.h" +#include "resource/allocator/page_allocator.h" +#include "resource/allocator/paged_cache_group.h" +#include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h" +#include "resource/kv_prefix_cache/kv_prefix_cache.h" +#include "resource/radix_tree/paged_cache_snapshot.h" +#include "resource/radix_tree/radix_tree.h" +#include "resource/radix_tree/tree_node.h" +#include "resource/types.h" +#include "unit_test_helper.h" + +namespace tokenspeed::test { + +struct PagedCacheFixtureParams { + std::int32_t page_size; + std::int32_t device_pages; + std::int32_t lcm_raw_tokens; + std::int32_t sliding_window_tokens; + std::int32_t fh_rows_per_page; + std::int32_t fh_stride; + std::int32_t swa_rows_per_page; + std::int32_t swa_stride; + std::int32_t group_total_pages; +}; + +template +class PagedCacheTestFixtureT : public ::testing::Test { +protected: + static constexpr std::int32_t kPageSize = kParams.page_size; + static constexpr std::int32_t kDevicePages = kParams.device_pages; + static constexpr std::int32_t kLcm = kParams.lcm_raw_tokens; + static constexpr std::int32_t kSlidingWindow = kParams.sliding_window_tokens; + + void SetUp() override { + device_alloc_ = std::make_unique(kPageSize, kDevicePages); + kv_cache_ = std::make_unique(device_alloc_.get(), /*host=*/nullptr); + + auto fh_owner = std::make_unique(MakeGroupConfig( + "fh", kParams.fh_rows_per_page, kParams.fh_stride, PagedCacheGroupConfig::Retention::FullHistory, + /*window=*/0, PagedCacheGroupFamily::History)); + auto swa_owner = std::make_unique(MakeGroupConfig( + "swa", kParams.swa_rows_per_page, kParams.swa_stride, PagedCacheGroupConfig::Retention::SlidingWindow, + kSlidingWindow, PagedCacheGroupFamily::State)); + fh_alloc_ = fh_owner.get(); + swa_alloc_ = swa_owner.get(); + + hybrid_ = std::make_unique(*kv_cache_, /*mamba=*/nullptr, + /*mamba_chunk_size=*/0); + 
hybrid_->RegisterPagedCacheGroup(std::move(fh_owner)); + hybrid_->RegisterPagedCacheGroup(std::move(swa_owner)); + std::unordered_map sliding{{"swa", kSlidingWindow}}; + hybrid_->EnablePagedCacheAdjunct(/*required=*/{"fh", "swa"}, std::move(sliding)); + kv_cache_->GetDeviceManager().SetEvictionCallback([this](TreeNode* node) { hybrid_->OnKVEvict(node); }); + } + + // Insert pages from `start_node` (nullptr=root); returns terminal node. + TreeNode* InsertDevicePages(std::int32_t num_pages, token_t token_start, TreeNode* start_node = nullptr) { + auto tokens = MakeAlignedTokens(num_pages, kPageSize, token_start); + OwnedPages pages = device_alloc_->Allocate(num_pages); + auto res = kv_cache_->Insert(tokens, /*prefix_pages=*/{}, std::move(pages), + /*page_hashes=*/{}, start_node); + return res.last_node; + } + + // Build a complete snapshot covering one LCM segment ending at prefix_len_tokens. + std::unique_ptr MakeCompleteSnapshot(std::int32_t prefix_len_tokens, + std::int32_t swa_base_logical_page = 0) { + auto snap = std::make_unique(); + snap->prefix_len_tokens = prefix_len_tokens; + snap->groups.emplace("fh", BuildGroupSnap(fh_alloc_, prefix_len_tokens, + /*base=*/0, /*sliding=*/false)); + snap->groups.emplace("swa", + BuildGroupSnap(swa_alloc_, prefix_len_tokens, swa_base_logical_page, /*sliding=*/true)); + return snap; + } + + // History-only snapshot (state group omitted); used for fallback tests. + std::unique_ptr MakeHistoryOnlySnapshot(std::int32_t prefix_len_tokens) { + auto snap = std::make_unique(); + snap->prefix_len_tokens = prefix_len_tokens; + snap->groups.emplace("fh", BuildGroupSnap(fh_alloc_, prefix_len_tokens, + /*base=*/0, /*sliding=*/false)); + return snap; + } + + // Detach and reattach without the state group; re-attach recomputes + // `complete_families` and leaves only History present. 
+ void DowngradeSnapshotToHistoryOnly(TreeNode* node) { + auto snap = hybrid_->DetachPagedCacheSnapshotFromNode(node); + ASSERT_NE(snap, nullptr); + snap->groups.erase("swa"); + hybrid_->AttachPagedCacheSnapshotToNode(node, std::move(snap)); + } + + std::unique_ptr device_alloc_; + std::unique_ptr kv_cache_; + PagedCacheGroupAllocator* fh_alloc_{nullptr}; + PagedCacheGroupAllocator* swa_alloc_{nullptr}; + std::unique_ptr hybrid_; + +protected: + static PagedCacheGroupConfig MakeGroupConfig(std::string group_id, std::int32_t rows_per_page, std::int32_t stride, + PagedCacheGroupConfig::Retention retention, std::int32_t window, + PagedCacheGroupFamily family) { + PagedCacheGroupConfig cfg{}; + cfg.group_id = std::move(group_id); + cfg.rows_per_page = rows_per_page; + cfg.entry_stride_tokens = stride; + cfg.total_pages = kParams.group_total_pages; + cfg.retention = retention; + cfg.sliding_window_tokens = window; + cfg.family = family; + return cfg; + } + +private: + PagedCacheGroupSnapshot BuildGroupSnap(PagedCacheGroupAllocator* alloc, std::int32_t prefix_len_tokens, + std::int32_t base_logical_page, bool sliding) { + PagedCacheGroupTable t{alloc}; + t.Acquire(kLcm); + // Caller chooses absolute base; fresh table commits at 0. + auto committed = sliding ? 
t.CheckpointStateToSnapshot(kLcm) : t.CommitHistoryToSnapshot(kLcm); + PagedCacheGroupSnapshot g{}; + g.pages = std::move(committed.pages); + g.base_logical_page = base_logical_page; + g.raw_token_cursor = prefix_len_tokens; + g.sliding = sliding; + return g; + } +}; + +inline constexpr PagedCacheFixtureParams kSmallFixtureParams{ + /*page_size=*/2, /*device_pages=*/8, + /*lcm_raw_tokens=*/4, /*sliding_window_tokens=*/8, + /*fh_rows_per_page=*/4, /*fh_stride=*/1, + /*swa_rows_per_page=*/2, /*swa_stride=*/1, + /*group_total_pages=*/16, +}; +using PagedCacheSmallFixture = PagedCacheTestFixtureT; + +inline constexpr PagedCacheFixtureParams kLargeFixtureParams{ + /*page_size=*/64, /*device_pages=*/64, + /*lcm_raw_tokens=*/256, /*sliding_window_tokens=*/128, + /*fh_rows_per_page=*/64, /*fh_stride=*/4, + /*swa_rows_per_page=*/64, /*swa_stride=*/1, + /*group_total_pages=*/32, +}; +using PagedCacheLargeFixture = PagedCacheTestFixtureT; + +// Wide-window variant: state window > history alignment so `segments_needed=2`. 
+inline constexpr PagedCacheFixtureParams kWideWindowFixtureParams{ + /*page_size=*/64, /*device_pages=*/64, + /*lcm_raw_tokens=*/256, /*sliding_window_tokens=*/512, + /*fh_rows_per_page=*/64, /*fh_stride=*/4, + /*swa_rows_per_page=*/64, /*swa_stride=*/1, + /*group_total_pages=*/64, +}; +using PagedCacheWideWindowFixture = PagedCacheTestFixtureT; + +} // namespace tokenspeed::test diff --git a/tokenspeed-scheduler/tests/cpp/test_paged_cache_attach_loop.cpp b/tokenspeed-scheduler/tests/cpp/test_paged_cache_attach_loop.cpp new file mode 100644 index 000000000..9d250edf5 --- /dev/null +++ b/tokenspeed-scheduler/tests/cpp/test_paged_cache_attach_loop.cpp @@ -0,0 +1,110 @@ +// Copyright (c) 2026 LightSeek Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + +// Coverage: end-to-end borrowed-prefix re-import on a fully-cached prefill. + +#include + +#include + +#include "integration_test_helper.h" + +namespace tokenspeed::test { +namespace { + +// page=2, LCM=4 raw tokens (2 KV pages per segment); 12-token prompt spans 3 segments. 
+class PagedCacheAttachLoopTest : public SchedulerTestSuite { +protected: + static constexpr std::int32_t kLcmRawTokens = 4; + + SchedulerConfig MakeConfig() override { + auto cfg = SchedulerTestSuite::MakeConfig(); + cfg.page_size = 2; + cfg.device_allocator.total_pages = 64; + cfg.host_allocator.total_pages = 64; + cfg.max_scheduled_tokens = 16; + cfg.max_batch_size = 8; + cfg.enable_l3_storage = false; + + PagedCacheGroupConfig fh{}; + fh.group_id = "fh"; + fh.rows_per_page = 4; + fh.entry_stride_tokens = 1; + fh.total_pages = 32; + fh.retention = PagedCacheGroupConfig::Retention::FullHistory; + cfg.paged_cache_groups.push_back(fh); + + PagedCacheGroupConfig swa{}; + swa.group_id = "swa"; + swa.rows_per_page = 2; + swa.entry_stride_tokens = 1; + swa.total_pages = 32; + swa.retention = PagedCacheGroupConfig::Retention::SlidingWindow; + swa.sliding_window_tokens = 8; + cfg.paged_cache_groups.push_back(swa); + + // Enable prefix-cache adjunct (LCM and sliding window derived from groups). + PrefixCacheAdjunctSpec spec{}; + spec.required_groups = {"fh", "swa"}; + cfg.prefix_cache_adjunct = spec; + + return cfg; + } + + static const FlatForwardOperation* GetForwardOp(const ExecutionPlan& plan) { + for (const auto& op : plan.Operations()) { + if (auto* f = std::get_if(&op)) return f; + } + return nullptr; + } +}; + +} // namespace + +// R1 primes 12 tokens; R2's same prefix must skip commit, re-import borrowed +// pages, and populate FlatForwardOperation.paged_cache_block_tables. +TEST_F(PagedCacheAttachLoopTest, FullyCachedPrefillBorrowedPrefixReimported) { + // R1 primes the cache with 12 tokens. + Submit(MakeRequestSpec("r1", /*num_pages=*/6, /*start=*/1)); + PlanOnce(); + SendForwardDone("r1", {99}); + PlanOnce(); + SendFinish("r1"); + PlanOnce(); + + // R2 uses the same prefix and should import all 3 LCM segments as + // borrowed pages. 
+ Submit(MakeRequestSpec("r2", /*num_pages=*/6, /*start=*/1)); + auto plan = PlanOnce(); + auto* fwd = GetForwardOp(plan); + ASSERT_NE(fwd, nullptr); + + // (a) prefix-hit covers at least one LCM segment. + EXPECT_GE(fwd->extend_prefix_lens[0], kLcmRawTokens); + + // (b) per-group tables already contain borrowed pages. + auto fh_ids = scheduler_->GetRequestPagedCachePageIds("r2", "fh"); + EXPECT_GE(fh_ids.size(), 1u) << "borrowed fh prefix must be imported"; + + // (c) paged_cache_block_tables populated for the executor. + EXPECT_FALSE(fwd->paged_cache_block_tables.empty()); + auto fh_it = fwd->paged_cache_block_tables.find("fh"); + ASSERT_NE(fh_it, fwd->paged_cache_block_tables.end()); + EXPECT_FALSE(fh_it->second.empty()); + EXPECT_FALSE(fh_it->second[0].empty()) << "fh block table row must not be empty for cached prefill"; +} + +} // namespace tokenspeed::test diff --git a/tokenspeed-scheduler/tests/cpp/test_paged_cache_eviction.cpp b/tokenspeed-scheduler/tests/cpp/test_paged_cache_eviction.cpp new file mode 100644 index 000000000..178a400e3 --- /dev/null +++ b/tokenspeed-scheduler/tests/cpp/test_paged_cache_eviction.cpp @@ -0,0 +1,77 @@ +// Copyright (c) 2026 LightSeek Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + +// Coverage: passive snapshot detach on KV LRU eviction returns pages via RAII. + +#include "paged_cache_test_fixture.h" + +namespace tokenspeed::test { + +using PagedCacheEvictionTest = PagedCacheSmallFixture; + +// Two branches, evict A with snapshot attached; pages must return via RAII. +// +// The contract under test is purely observable: after the snapshot-bearing KV +// node is evicted, every paged-cache page the snapshot was holding must be +// back in its group allocator's free list. We deliberately avoid hard-coding +// "how many pages a snapshot consumes" (that depends on commit-time semantics +// such as State-window trim) — instead, we capture the allocator state right +// before attach and require it to be restored after eviction. +TEST_F(PagedCacheEvictionTest, PassiveEvictionReleasesPagedCachePages) { + InsertDevicePages(/*num_pages=*/2, /*token_start=*/1); // branch A + auto* leaf_b = InsertDevicePages(/*num_pages=*/2, /*token_start=*/100); // branch B + ASSERT_NE(leaf_b, nullptr); + + TreeNode* attach_a = kv_cache_->GetRadixTree().SplitAt( + kv_cache_->Match(MakeAlignedTokens(2, kPageSize, /*start=*/1)).device.last_node, kLcm); + ASSERT_NE(attach_a, nullptr); + + // Baseline: paged-cache pools fully free, no snapshot attached. + const std::int32_t fh_before = fh_alloc_->AvailablePages(); + const std::int32_t swa_before = swa_alloc_->AvailablePages(); + + hybrid_->AttachPagedCacheSnapshotToNode(attach_a, MakeCompleteSnapshot(kLcm)); + EXPECT_TRUE(attach_a->HasPagedCacheSnapshot()); + // The snapshot must hold *some* pages from each group, otherwise the test + // below ("eviction returns them") is vacuous. We do NOT assert the exact + // count — that depends on snapshot-build semantics (e.g. 
State-window + // trim) and is covered by dedicated build/commit tests. + EXPECT_LT(fh_alloc_->AvailablePages(), fh_before); + EXPECT_LT(swa_alloc_->AvailablePages(), swa_before); + + // Pin branch B so eviction targets A. Without this lock the LRU policy + // could evict either branch. + auto match_b = kv_cache_->Match(MakeAlignedTokens(2, kPageSize, /*start=*/100)); + DeviceNodeRef ref_b{match_b.device.last_node}; + + // Force eviction of branch A by demanding one more page than the device + // allocator currently has free. Branch B's 2 device pages are pinned by + // `ref_b`, so the LRU must drop branch A — which carries our snapshot. + const std::int32_t target_available = device_alloc_->AvailablePages() + 1; + const bool ok = kv_cache_->EnsureCapacityByEvict(target_available); + EXPECT_TRUE(ok); + // Note: after eviction `attach_a` may be freed by tree pruning, so we do + // not dereference it. The observable proof that OnKVEvict detached the + // snapshot is that the paged-cache allocator pools are restored below. + + // Observable contract: every paged-cache page the snapshot held is now + // back in its allocator's free list (OwnedPages RAII via OnKVEvict -> + // DetachPagedCacheSnapshotFromNode). 
+ EXPECT_EQ(fh_alloc_->AvailablePages(), fh_before); + EXPECT_EQ(swa_alloc_->AvailablePages(), swa_before); +} + +} // namespace tokenspeed::test diff --git a/tokenspeed-scheduler/tests/cpp/test_paged_cache_family_split.cpp b/tokenspeed-scheduler/tests/cpp/test_paged_cache_family_split.cpp new file mode 100644 index 000000000..aee02ae3b --- /dev/null +++ b/tokenspeed-scheduler/tests/cpp/test_paged_cache_family_split.cpp @@ -0,0 +1,120 @@ +// Copyright (c) 2026 LightSeek Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + +// Coverage: V4 family-split prefix match. History (chain) and State +// (trailing window) families are now scanned independently; a missing +// State snapshot falls back to a shallower depth without killing the +// History chain. + +#include "hybrid_prefix_cache_test_peer.h" +#include "paged_cache_test_fixture.h" + +namespace tokenspeed::test { + +using PagedCacheFamilySplitTest = PagedCacheLargeFixture; +using PagedCacheFamilyWideWindowTest = PagedCacheWideWindowFixture; + +// kSlidingWindow=128 < kLcm=256 -> segments_needed=1. +// Dropping state-completeness at the deepest boundary falls back one segment. 
+TEST_F(PagedCacheFamilySplitTest, HistoryCompleteStateMissingFallback) { + const std::int32_t num_pages = 768 / kPageSize; // 12 pages + TreeNode* terminal = InsertDevicePages(num_pages, /*token_start=*/1); + ASSERT_NE(terminal, nullptr); + + TreeNode* n256 = kv_cache_->GetRadixTree().SplitAt(terminal, 256); + TreeNode* n512 = kv_cache_->GetRadixTree().SplitAt(terminal, 512); + TreeNode* n768 = kv_cache_->GetRadixTree().SplitAt(terminal, 768); + ASSERT_NE(n256, nullptr); + ASSERT_NE(n512, nullptr); + ASSERT_NE(n768, nullptr); + + hybrid_->AttachPagedCacheSnapshotToNode(n256, MakeCompleteSnapshot(256)); + hybrid_->AttachPagedCacheSnapshotToNode(n512, MakeCompleteSnapshot(512)); + hybrid_->AttachPagedCacheSnapshotToNode(n768, MakeCompleteSnapshot(768)); + + // Downgrade only the deepest snapshot: history-only at 768. + DowngradeSnapshotToHistoryOnly(n768); + ASSERT_TRUE(n768->HasPagedCacheSnapshot()); + EXPECT_TRUE(n768->GetPagedCacheSnapshot()->IsCompleteFor(PagedCacheGroupFamily::History)); + EXPECT_FALSE(n768->GetPagedCacheSnapshot()->IsCompleteFor(PagedCacheGroupFamily::State)); + + auto match = hybrid_->Match(MakeAlignedTokens(num_pages, kPageSize, /*start=*/1)); + ASSERT_NE(match.paged_cache.last_node, nullptr); + // History chain reaches 768 but state at 768 is missing; segments_needed=1 + // forces fallback to 512. + EXPECT_EQ(match.paged_cache.last_node, n512); + EXPECT_EQ(match.paged_cache.prefix_len_tokens, 512); +} + +// segments_needed=2 (window=512, align=256). State missing at 512 breaks +// both end_idx=2 (trailing 512+768) and end_idx=1 (trailing 256+512); only +// end_idx=0 (single segment 256) remains. 
+TEST_F(PagedCacheFamilyWideWindowTest, StateWindowDiscontinuityFallback) { + const std::int32_t num_pages = 768 / kPageSize; + TreeNode* terminal = InsertDevicePages(num_pages, /*token_start=*/1); + ASSERT_NE(terminal, nullptr); + + TreeNode* n256 = kv_cache_->GetRadixTree().SplitAt(terminal, 256); + TreeNode* n512 = kv_cache_->GetRadixTree().SplitAt(terminal, 512); + TreeNode* n768 = kv_cache_->GetRadixTree().SplitAt(terminal, 768); + ASSERT_NE(n256, nullptr); + ASSERT_NE(n512, nullptr); + ASSERT_NE(n768, nullptr); + + hybrid_->AttachPagedCacheSnapshotToNode(n256, MakeCompleteSnapshot(256)); + hybrid_->AttachPagedCacheSnapshotToNode(n512, MakeCompleteSnapshot(512)); + hybrid_->AttachPagedCacheSnapshotToNode(n768, MakeCompleteSnapshot(768)); + + DowngradeSnapshotToHistoryOnly(n512); + + auto match = hybrid_->Match(MakeAlignedTokens(num_pages, kPageSize, /*start=*/1)); + ASSERT_NE(match.paged_cache.last_node, nullptr); + EXPECT_EQ(match.paged_cache.last_node, n256); + EXPECT_EQ(match.paged_cache.prefix_len_tokens, 256); +} + +// segments_needed=1: detaching state at mid-chain does not break the history +// chain; deepest state-complete boundary (768) remains usable. 
+TEST_F(PagedCacheFamilySplitTest, StateDetachDoesNotBreakHistoryChain) { + const std::int32_t num_pages = 768 / kPageSize; + TreeNode* terminal = InsertDevicePages(num_pages, /*token_start=*/1); + ASSERT_NE(terminal, nullptr); + + TreeNode* n256 = kv_cache_->GetRadixTree().SplitAt(terminal, 256); + TreeNode* n512 = kv_cache_->GetRadixTree().SplitAt(terminal, 512); + TreeNode* n768 = kv_cache_->GetRadixTree().SplitAt(terminal, 768); + ASSERT_NE(n256, nullptr); + ASSERT_NE(n512, nullptr); + ASSERT_NE(n768, nullptr); + + hybrid_->AttachPagedCacheSnapshotToNode(n256, MakeCompleteSnapshot(256)); + hybrid_->AttachPagedCacheSnapshotToNode(n512, MakeCompleteSnapshot(512)); + hybrid_->AttachPagedCacheSnapshotToNode(n768, MakeCompleteSnapshot(768)); + + DowngradeSnapshotToHistoryOnly(n512); + ASSERT_TRUE(n512->HasPagedCacheSnapshot()); + EXPECT_TRUE(n512->GetPagedCacheSnapshot()->IsCompleteFor(PagedCacheGroupFamily::History)); + EXPECT_FALSE(n512->GetPagedCacheSnapshot()->IsCompleteFor(PagedCacheGroupFamily::State)); + EXPECT_TRUE(n768->GetPagedCacheSnapshot()->IsCompleteFor(PagedCacheGroupFamily::State)); + + auto match = hybrid_->Match(MakeAlignedTokens(num_pages, kPageSize, /*start=*/1)); + ASSERT_NE(match.paged_cache.last_node, nullptr); + // History chain unbroken; state at 768 (only the trailing segment) is fine. 
+ EXPECT_EQ(match.paged_cache.last_node, n768); + EXPECT_EQ(match.paged_cache.prefix_len_tokens, 768); +} + +} // namespace tokenspeed::test diff --git a/tokenspeed-scheduler/tests/cpp/test_paged_cache_prefix_hit_commit.cpp b/tokenspeed-scheduler/tests/cpp/test_paged_cache_prefix_hit_commit.cpp new file mode 100644 index 000000000..049113942 --- /dev/null +++ b/tokenspeed-scheduler/tests/cpp/test_paged_cache_prefix_hit_commit.cpp @@ -0,0 +1,98 @@ +// Copyright (c) 2026 LightSeek Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + +// Coverage: a State-family CheckpointStateToSnapshot following a prefix-cache +// hit on a wide window (window >= LCM) must not throw. Each State snapshot +// stores only its own LCM segment's owned delta; the trailing window is +// reconstructed across the snapshot chain at match time. Regression for the +// "not enough owned pages for window" overflow that incorrectly conflated +// "this commit's delta" with "the whole trailing window". + +#include "paged_cache_test_fixture.h" + +namespace tokenspeed::test { + +using PagedCachePrefixHitCommitTest = PagedCacheWideWindowFixture; + +// kSlidingWindow=512, kLcm=256 -> window = 2*LCM. 
With a 256-token prefix-hit, +// the import covers exactly the first LCM segment as borrowed; the second LCM +// segment must commit through CheckpointStateToSnapshot without trying to +// claim the borrowed half from owned_pages_. +TEST_F(PagedCachePrefixHitCommitTest, PrefixHitFollowedByCheckpointDoesNotOverflowWindow) { + static_assert(kSlidingWindow >= kLcm, "this test exercises window >= LCM"); + + // Seed a chain reaching 512 tokens and attach a complete snapshot at 256 + // so the second request gets a one-LCM-segment prefix-cache hit. + const std::int32_t num_pages = 512 / kPageSize; // 8 pages + TreeNode* terminal = InsertDevicePages(num_pages, /*token_start=*/1); + ASSERT_NE(terminal, nullptr); + + TreeNode* n256 = kv_cache_->GetRadixTree().SplitAt(terminal, 256); + ASSERT_NE(n256, nullptr); + hybrid_->AttachPagedCacheSnapshotToNode(n256, MakeCompleteSnapshot(256)); + ASSERT_TRUE(n256->HasPagedCacheSnapshot()); + + const auto tokens = MakeAlignedTokens(num_pages, kPageSize, /*start=*/1); + + // The second request: prefix-cache match returns the depth-256 hit. + auto pre_match = hybrid_->Match(tokens); + ASSERT_NE(pre_match.paged_cache.last_node, nullptr); + EXPECT_EQ(pre_match.paged_cache.last_node, n256); + EXPECT_EQ(pre_match.paged_cache.prefix_len_tokens, 256); + + // Import borrowed prefix + acquire fresh pages for the remaining LCM segment. + const std::string request_id = "r-prefix-hit"; + hybrid_->AcquireForRequest(request_id, + /*first_raw_position_of_op=*/256, + /*target_raw_tokens_exclusive=*/512, pre_match.paged_cache); + + // Trigger CheckpointStateToSnapshot at the next LCM boundary. Pre-fix this + // throws std::logic_error("not enough owned pages for window"); post-fix it + // commits only the new LCM segment's delta to the snapshot. + ASSERT_NO_THROW(hybrid_->CommitChunk(request_id, terminal)); + + // After commit, n512 (=terminal) must hold a complete snapshot covering + // both required families. 
+ ASSERT_TRUE(terminal->HasPagedCacheSnapshot()); + const auto* committed_snap = terminal->GetPagedCacheSnapshot(); + ASSERT_NE(committed_snap, nullptr); + EXPECT_TRUE(committed_snap->IsCompleteFor(PagedCacheGroupFamily::History)); + EXPECT_TRUE(committed_snap->IsCompleteFor(PagedCacheGroupFamily::State)); + + // Observable: a fresh Match now reconstructs the full trailing window + // (state_span = [n256, n512]) and exposes window/raw_per_page page ids + // for the sliding "swa" group. + auto post_match = hybrid_->Match(tokens); + ASSERT_NE(post_match.paged_cache.last_node, nullptr); + EXPECT_EQ(post_match.paged_cache.prefix_len_tokens, 512); + + auto swa_it = post_match.paged_cache.per_group_page_ids.find("swa"); + ASSERT_NE(swa_it, post_match.paged_cache.per_group_page_ids.end()); + const auto& swa_ids = swa_it->second; + ASSERT_FALSE(swa_ids.empty()); + + const PagedCacheGroupSnapshot& swa_at_256 = n256->GetPagedCacheSnapshot()->groups.at("swa"); + const std::int32_t raw_per_page = swa_at_256.pages.Size() > 0 ? (kLcm / swa_at_256.pages.Size()) : 0; + ASSERT_GT(raw_per_page, 0); + const std::int32_t committed_depth = post_match.paged_cache.prefix_len_tokens; + const std::int32_t expected_state_pages = std::min(kSlidingWindow / raw_per_page, committed_depth / raw_per_page); + EXPECT_EQ(static_cast(swa_ids.size()), expected_state_pages); + + // Clean up the request tables; owned pages return via RAII / ReleaseAll. 
+    hybrid_->ReleaseRequest(request_id);
+}
+
+}  // namespace tokenspeed::test
diff --git a/tokenspeed-scheduler/tests/cpp/test_paged_cache_prefix_match.cpp b/tokenspeed-scheduler/tests/cpp/test_paged_cache_prefix_match.cpp
new file mode 100644
index 000000000..0c2f7de84
--- /dev/null
+++ b/tokenspeed-scheduler/tests/cpp/test_paged_cache_prefix_match.cpp
@@ -0,0 +1,93 @@
+// Copyright (c) 2026 LightSeek Foundation
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+
+// Coverage: HybridPrefixCache::Match paged-cache adjunct branch.
+
+#include "paged_cache_test_fixture.h"
+
+namespace tokenspeed::test {
+
+using PagedCachePrefixMatchTest = PagedCacheLargeFixture;
+
+// 320-token path: with no snapshot Match caps device/host to root; a complete snapshot at 256 caps it to 256.
+TEST_F(PagedCachePrefixMatchTest, CapVsNoCap320) {
+    const std::int32_t num_pages = 320 / kPageSize;  // 5 pages
+    TreeNode* terminal = InsertDevicePages(num_pages, /*token_start=*/1);
+    ASSERT_NE(terminal, nullptr);
+    EXPECT_EQ(terminal->DepthInTokens(), 320u);
+    // Query tokens built with the same page count and start value used for the insert above.
+    const auto tokens = MakeAlignedTokens(num_pages, kPageSize, /*start=*/1);
+
+    // No snapshot attached yet: the paged_cache result is empty and the device/host terminals are root.
+    auto match = hybrid_->Match(tokens);
+    EXPECT_EQ(match.paged_cache.last_node, nullptr);
+    EXPECT_EQ(match.paged_cache.prefix_len_tokens, 0);
+    ASSERT_NE(match.device.last_node, nullptr);
+    EXPECT_TRUE(match.device.last_node->IsRoot())
+        << "device terminal must be capped to root when adjunct is enabled but no snapshot exists";
+    ASSERT_NE(match.host.last_node, nullptr);
+    EXPECT_TRUE(match.host.last_node->IsRoot())
+        << "host terminal must be capped to root when adjunct is enabled but no snapshot exists";
+
+    // Split the path at depth 256 and attach a complete snapshot there; Match should now cap to 256.
+    TreeNode* boundary_256 = kv_cache_->GetRadixTree().SplitAt(terminal, 256);
+    ASSERT_NE(boundary_256, nullptr);
+    EXPECT_EQ(boundary_256->DepthInTokens(), 256u);
+    hybrid_->AttachPagedCacheSnapshotToNode(boundary_256, MakeCompleteSnapshot(256));
+    ASSERT_TRUE(boundary_256->HasPagedCacheSnapshot());
+    EXPECT_TRUE(boundary_256->GetPagedCacheSnapshot()->IsCompleteFor(PagedCacheGroupFamily::History));
+    EXPECT_TRUE(boundary_256->GetPagedCacheSnapshot()->IsCompleteFor(PagedCacheGroupFamily::State));
+    // NOTE(review): only the device terminal is re-checked below; the host terminal is not — confirm intent.
+    match = hybrid_->Match(tokens);
+    ASSERT_NE(match.paged_cache.last_node, nullptr);
+    EXPECT_EQ(match.paged_cache.last_node, boundary_256);
+    EXPECT_EQ(match.paged_cache.prefix_len_tokens, 256);
+    ASSERT_NE(match.device.last_node, nullptr);
+    EXPECT_EQ(match.device.last_node->DepthInTokens(), 256u)
+        << "device terminal must be capped to the deepest contiguous paged-cache node";
+}
+
+// Snapshots at 256/512/768; detaching 512 makes Match fall back to 256.
+TEST_F(PagedCachePrefixMatchTest, ContiguousChainBreakMid) {
+    const std::int32_t num_pages = 768 / kPageSize;  // 12 pages
+    TreeNode* terminal = InsertDevicePages(num_pages, /*token_start=*/1);
+    ASSERT_NE(terminal, nullptr);
+    // Carve node boundaries at 256, 512 and 768 tokens along the inserted path.
+    TreeNode* n256 = kv_cache_->GetRadixTree().SplitAt(terminal, 256);
+    TreeNode* n512 = kv_cache_->GetRadixTree().SplitAt(terminal, 512);
+    TreeNode* n768 = kv_cache_->GetRadixTree().SplitAt(terminal, 768);  // presumably terminal's own depth — TODO confirm
+    ASSERT_NE(n256, nullptr);
+    ASSERT_NE(n512, nullptr);
+    ASSERT_NE(n768, nullptr);
+    // Attach a complete snapshot at every boundary before one of them is removed again.
+    hybrid_->AttachPagedCacheSnapshotToNode(n256, MakeCompleteSnapshot(256));
+    hybrid_->AttachPagedCacheSnapshotToNode(n512, MakeCompleteSnapshot(512));
+    hybrid_->AttachPagedCacheSnapshotToNode(n768, MakeCompleteSnapshot(768));
+
+    // Drop the middle snapshot; chain scan must stop at the gap.
+    auto dropped = hybrid_->DetachPagedCacheSnapshotFromNode(n512);
+    EXPECT_TRUE(dropped != nullptr);
+    EXPECT_FALSE(n512->HasPagedCacheSnapshot());
+    ASSERT_TRUE(n768->HasPagedCacheSnapshot());  // the deeper snapshot is still attached
+    EXPECT_TRUE(n768->GetPagedCacheSnapshot()->IsCompleteFor(PagedCacheGroupFamily::History));
+    // Expect n256 rather than n768: the snapshot at 768 still exists but lies beyond the gap at 512.
+    auto match = hybrid_->Match(MakeAlignedTokens(num_pages, kPageSize, /*start=*/1));
+    ASSERT_NE(match.paged_cache.last_node, nullptr);
+    EXPECT_EQ(match.paged_cache.last_node, n256);
+    EXPECT_EQ(match.paged_cache.prefix_len_tokens, 256);
+}
+
+}  // namespace tokenspeed::test