[runtime, ops] feat: LRU prefix KV cache for cross-request prefill reuse #25

Ziyi-Wang wants to merge 5 commits into
Conversation
ea8560a to ea5f3f4
Code Review
This pull request implements a content-addressed prefix cache for the KV cache manager, allowing for token-id sequence reuse across requests to improve performance. The changes include a new hashing mechanism for full blocks, LRU eviction logic, and integration with the scheduler to skip redundant prefill computations. My feedback focuses on a critical race condition between the control channel and the generation loop, potential performance bottlenecks in the scheduling loop due to repeated hashing, and the risks of using Python's built-in hash function for content-addressed caching.
```python
def receive_weights(self):
    """Receive new model weights.

    Active requests are preempted (their KV is stale under the new weights
    and decode would read garbage) and the prefix cache index is dropped.
    Re-prefill happens automatically on the next schedule().
    """
    self.scheduler.reset_for_state_change()
    super().receive_weights()

def sleep(self, tag: str = None):
    """Pause memory-saving regions.

    The KV cache region is allocated with enable_cpu_backup=False, so its
    bytes are dropped on sleep — any active block IDs become invalid. Active
    requests are preempted; the prefix cache index is dropped.
    """
    self.scheduler.reset_for_state_change()
    super().sleep(tag=tag)

def load_state_dict(self, state_dict):
    """In-process weight load (e.g. tests). Same staleness as receive_weights()."""
    self.scheduler.reset_for_state_change()
    super().load_state_dict(state_dict)
```
The methods receive_weights, sleep, and load_state_dict introduce a significant race condition. These methods are typically called from the control channel's RPC thread, while the generation loop (_generation_loop) runs in a separate background thread.
Calling self.scheduler.reset_for_state_change() concurrently with self.scheduler.schedule() or self.scheduler.update() can lead to inconsistent state, such as KeyError in the KV cache manager or corrupted request states.
Furthermore, there is a correctness issue: receive_weights clears the cache and preempts requests, but the generation loop continues to run while super().receive_weights() is blocking to receive new weights. This allows the scheduler to re-activate requests and fill the cache with KV tensors computed using the old weights, which will be stale once the weight update completes.
You should implement a synchronization mechanism (e.g., a threading.Lock) to ensure that control commands and the generation loop do not run concurrently, or pause the generation loop during state changes.
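A minimal sketch of the suggested synchronization, not the PR's code: the class and method bodies here are illustrative stand-ins, and the single `threading.Lock` serializes one generation iteration against a control-channel state change so the cache cannot be refilled mid-update.

```python
import threading


class Driver:
    """Illustrative driver skeleton (hypothetical, not from this PR).

    One lock serializes control-channel RPCs (weight updates, sleep)
    against the background generation loop: a state change can only run
    between generation iterations, never interleaved inside one.
    """

    def __init__(self):
        self._state_lock = threading.Lock()
        self.steps = 0  # stand-in for scheduler/cache state

    def generation_step(self):
        # Generation loop body: holds the lock for one full
        # schedule/forward/update cycle.
        with self._state_lock:
            self.steps += 1  # stand-in for scheduler.schedule()/update()

    def receive_weights(self):
        # Control path: the cache reset and the (blocking) weight receive
        # both happen under the lock, so the loop cannot re-activate
        # requests and fill the cache with KV from the old weights.
        with self._state_lock:
            self.reset_for_state_change()
            # ... super().receive_weights() would block here ...

    def reset_for_state_change(self):
        self.steps = 0  # stand-in for preempt-all + clear_cache_index()
```

An alternative with the same effect is a pause flag the generation loop checks each iteration, so the control thread never blocks a long-running step.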
```python
# _activate_request → commit_prefix_plan (no rehash) and later to
# mark_blocks_filled (also no rehash). Eliminates the triple-hash
# that the original peek/allocate/stamp paths had.
block_hashes, num_prefix_hit_blocks = self._kv_cache_manager.plan_prefix_cache(request.input_ids_list)
```
Calling plan_prefix_cache inside the scheduling loop for every request in the queue can become a performance bottleneck, especially with long prompts and large queues. Since the block hashes for a given input_ids_list are deterministic, they should be cached on the InferenceRequest object. Only the num_prefix_hit_blocks (which depends on the current cache state) needs to be re-evaluated during scheduling.
```python
prev_hash = self._SEED

for i in range(num_full):
    block_hash = hash((prev_hash, tuple(token_ids[i * page_size : (i + 1) * page_size])))
```
Using Python's built-in hash() for content-addressed caching is risky because it is not guaranteed to be collision-resistant and its output is randomized across process restarts (unless PYTHONHASHSEED is fixed). While the collision probability for 1024 blocks is low, it increases quadratically with the number of blocks. For a more robust implementation, especially as the cache size grows, consider using a faster non-cryptographic hash like xxhash or a truncated cryptographic hash.
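A sketch of the stronger-hash alternative the review suggests, using stdlib `blake2b` truncated to 64 bits. The function name, seed, and encoding are illustrative assumptions, not the PR's API; the point is that the digest is stable across process restarts and collision-resistant, unlike built-in `hash()`.

```python
import hashlib


def chain_block_hashes(token_ids, page_size, seed=b"kv-prefix-v1"):
    """Hypothetical chain hash over full KV blocks (illustrative sketch).

    Each block hash covers the previous block's digest plus the block's
    token ids, so a hash identifies the entire prefix up to that block.
    Truncated blake2b (8 bytes) is deterministic across processes and far
    more collision-resistant than Python's randomized built-in hash().
    """
    hashes = []
    prev = seed
    for i in range(len(token_ids) // page_size):
        block = token_ids[i * page_size : (i + 1) * page_size]
        h = hashlib.blake2b(digest_size=8)
        h.update(prev)  # chain: digest of blocks 0..i-1
        h.update(b",".join(str(t).encode() for t in block))
        prev = h.digest()
        hashes.append(int.from_bytes(prev, "big"))
    return hashes  # one 64-bit int per full block; partial tail ignored
```

Because the chain threads each digest into the next, two prompts sharing a leading prefix produce identical hashes exactly for the shared full blocks and diverge from the first differing block onward.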
receive_weights / sleep / load_state_dict now call clear_cache_index() directly instead of routing through scheduler.reset_for_state_change(), which preempted active requests and reset stats. Drop the preempt-all helper entirely.

Rationale: the framework (verl) drains all in-flight requests before triggering a weight update or sleep, so there is nothing to preempt at that point. clear_cache_index() alone is enough to prevent future requests from hitting KV computed under old weights. Stats become pure lifetime counters, giving a stable trend line across weight updates instead of a jittery per-step ratio.

Docstrings updated to spell out the drain assumption.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
KVCacheManager is a scheduler-level resource manager (refcount + LRU + prefix cache index), not a batch-invariant kernel. Its tests belong next to test_scheduler.py, not in tests/batch_invariant_ops/, which is for flex_attention / matmul / FA4 invariance checks.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Block hashes are a pure function of (token_ids, page_size, _SEED), so recomputing them every scheduling iteration for a requeued request is wasted work — gets noticeable at long contexts (page_size=16 + 32K prompts: ~1ms per replan × 100 retries = 100ms). Split the API:

- compute_block_hashes(token_ids) → list[int] (stateless, deterministic)
- count_prefix_hits(block_hashes) → int (cheap O(N) dict lookups)

The scheduler caches the hashes on InferenceRequest.prefix_block_hashes after the first compute; subsequent rounds only re-run count_prefix_hits. The field already existed for mark_blocks_filled, so no new state is added. preempt() already clears it (input_ids_list grows on preempt, so hashes need to be recomputed against the new sequence).

plan_prefix_cache (the convenience wrapper) is removed entirely — no production caller after this change, so keeping it as a test-only wrapper would just hide that tests aren't exercising the real path.

Addresses Gemini review comment on plan-per-iteration cost.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
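The split API can be sketched as below. The function names mirror the commit (compute_block_hashes / count_prefix_hits), but the bodies and the class wrapper are assumptions for illustration, not the PR's implementation.

```python
class PrefixIndex:
    """Illustrative sketch of the stateless-hash / stateful-count split."""

    _SEED = 0x51AB5  # arbitrary illustrative seed

    def __init__(self, page_size=16):
        self.page_size = page_size
        self._index = {}  # block hash -> block id, filled blocks only

    def compute_block_hashes(self, token_ids):
        # Stateless and deterministic for fixed (token_ids, page_size,
        # _SEED): safe to compute once and cache on the request object.
        hashes, prev = [], self._SEED
        for i in range(len(token_ids) // self.page_size):
            block = tuple(token_ids[i * self.page_size : (i + 1) * self.page_size])
            prev = hash((prev, block))
            hashes.append(prev)
        return hashes

    def count_prefix_hits(self, block_hashes):
        # Cheap O(N) dict lookups, re-run every scheduling round because
        # the cache contents (not the hashes) change between rounds.
        hits = 0
        for h in block_hashes:
            if h not in self._index:
                break  # prefix hit must be contiguous from block 0
            hits += 1
        return hits
```

Scheduler-side memoization then reduces to one line per round, e.g. computing `request.prefix_block_hashes` only when it is unset and calling `count_prefix_hits` on the cached list.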
PR description
What does this PR do?
Adds a content-addressed, refcounted, LRU-evicted KV prefix cache so requests sharing a leading prefix (system prompt, chat history, etc.) skip prefill on those tokens. On Qwen2.5-7B with a 1024-token shared system prompt, observed +25–38% rps vs. cache off. PP > 1 transparently falls back to the original allocator (no behavior change there).
Checklist Before Starting
[{modules}] {type}: {description} — [runtime, ops] feat: ..., validated against scripts/ci/check_pr_title.py.

Test
Qwen2.5-7B, H100×1, 64 ShareGPT requests with a 1024-token shared system prompt:
tests/batch_invariant_ops/test_kv_cache_manager.py — 26 new tests covering plan / commit / mark_blocks_filled / OOM rollback / refcount sharing / LRU eviction / hash-only invalidation / disabled-cache fallback. All pass.

tests/test_scheduler.py — 13/13 pass (one test's setup updated since _activate_request now derives num_computed_tokens from a plan handed in by the caller).

Design & Code Changes
Content-addressed LRU prefix cache. Full KV blocks are chain-hashed (h_i = hash(h_{i-1}, tokens_i)); allocation looks the hash up and either increfs the existing block or takes the least-recently-released one from a free-LRU pool. Block lifetime is governed by refcount: in-use blocks are unevictable, freed blocks keep their hash and remain hittable until LRU pressure reclaims them. Hashes are committed at prefill completion, not at allocation, which means two concurrent same-prefix requests will both miss the first time (one fills, the rest hit thereafter) — we trade one duplicate miss for not building a "filling" state machine. The manager exposes three primitives — plan_prefix_cache, commit_prefix_plan, mark_blocks_filled — structured so each block hash is computed exactly once per request and threaded through the scheduler, eliminating the triple-hash an earlier draft had.

Integration is driver-local and event-driven. KVCacheManager lives only on the driver; non-driver PP ranks are unchanged. The scheduler calls plan_prefix_cache while sizing chunked-prefill, hands the plan to _activate_request → commit_prefix_plan, and stamps via mark_blocks_filled at prefill completion (idempotent, so no per-request flag). Invalidation is the surgical bit: receive_weights / sleep / load_state_dict call clear_cache_index(), which drops only the hash → block lookup — refcounts and the block pool are untouched, so in-flight requests keep their blocks (no premature free) and future requests cleanly miss instead of hitting stale KV computed under old weights. Observability is a single collective RPC (get_prefix_cache_stats) with a base-class stub so it works regardless of rank topology.

Scope is deliberately tight. ~840 LOC across 11 files; the manager rewrite is the bulk and the rest is thin scheduler / RPC plumbing plus a benchmark switch. PP > 1 falls back to the original allocator transparently — the driver-side index is rank-agnostic, but verifying KV-tensor coherence across ranks deserves its own MR (est. < 50 LOC + tests). Two other deferrals: eager stamping (would let concurrent same-prefix requests share the first miss; needs a "filling" state machine, unnecessary for RL-style warm-then-batch workloads) and a stronger hash (xxhash / blake2b; Python's hash() collision probability at 1024 blocks is ~3×10⁻¹⁴, revisit at 10⁵+ blocks).

Checklist Before Submitting
(pre-commit run); tests/batch_invariant_ops/test_kv_cache_manager.py (26 new tests) covers the manager API surface; tests/test_scheduler.py still green; bitwise-equivalence + throughput A/B validated via benchmarks/throughput.py.