[runtime, ops] feat: LRU prefix KV cache for cross-request prefill reuse #25

Open

Ziyi-Wang wants to merge 5 commits into verl-project:main from Ziyi-Wang:feat/lru-prefix-cache

Conversation

@Ziyi-Wang

PR description

What does this PR do?

Adds a content-addressed, refcounted, LRU-evicted KV prefix cache so requests sharing a leading prefix (system prompt, chat history, etc.) skip prefill on those tokens. On Qwen2.5-7B with a 1024-token shared system prompt, observed +25–38% rps vs. cache off. PP > 1 transparently falls back to the original allocator (no behavior change there).

Checklist Before Starting

  • Related issues/PRs: none
  • PR title follows [{modules}] {type}: {description} ([runtime, ops] feat: ...), validated against scripts/ci/check_pr_title.py.

Test

Qwen2.5-7B, H100×1, 64 ShareGPT requests with a 1024-token shared system prompt:

| shared prefix | output_len | prefix cache off | prefix cache on | rps gain | wall reduction |
|---------------|------------|------------------|-----------------|----------|----------------|
| 1024 tok | 128 (decode-bound) | 8.09 rps, 7.91s | 10.11 rps, 6.33s | +25% | -20% |
| 1024 tok | 8 (prefill-bound) | 13.97 rps, 4.58s | 19.34 rps, 3.31s | +38% | -28% |
  • Hit ratio 79.3% (the first of the 64 requests fills the cache; the other 63 hit).
  • Bitwise identical: cache-hit and cache-miss runs produce byte-identical token IDs under greedy decode — matches vexact's bitwise-aligned contract.
  • No regression: Qwen3-0.6B ShareGPT without shared prefix is within run-to-run noise of master (12.52 → 12.85 rps).
  • Unit tests:
    • tests/batch_invariant_ops/test_kv_cache_manager.py — 26 new tests covering plan / commit / mark_blocks_filled / OOM rollback / refcount sharing / LRU eviction / hash-only invalidation / disabled-cache fallback. All pass.
    • tests/test_scheduler.py: 13/13 pass (one test's setup was updated, since _activate_request now derives num_computed_tokens from a plan handed in by the caller).

Design & Code Changes

Content-addressed LRU prefix cache. Full KV blocks are chain-hashed (h_i = hash(h_{i-1}, tokens_i)); allocation looks up the hash and either increfs the existing block or takes the least-recently-released one from a free-LRU pool. Block lifetime is governed by refcount: in-use blocks are unevictable, freed blocks keep their hash and remain hittable until LRU pressure reclaims them. Hashes are committed at prefill completion, not at allocation, which means two concurrent same-prefix requests will both miss the first time (one fills, the rest hit thereafter) — we trade one duplicate miss for not building a "filling" state machine. The manager exposes three primitives — plan_prefix_cache, commit_prefix_plan, mark_blocks_filled — structured so each block hash is computed exactly once per request and threaded through the scheduler, eliminating the triple-hash an earlier draft had.
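For illustration, a minimal sketch of the chain-hash plus refcounted free-LRU mechanics described above (class and attribute names are illustrative, not the actual vexact implementation; the OOM/rollback path is omitted):

```python
from collections import OrderedDict

_SEED = 0  # illustrative; the real manager uses its own seed constant

def chain_hashes(token_ids: list[int], page_size: int) -> list[int]:
    """h_i = hash(h_{i-1}, tokens_i), computed over full blocks only."""
    hashes, prev = [], _SEED
    for i in range(len(token_ids) // page_size):
        prev = hash((prev, tuple(token_ids[i * page_size : (i + 1) * page_size])))
        hashes.append(prev)
    return hashes

class Block:
    def __init__(self, block_id: int):
        self.block_id = block_id
        self.refcount = 0        # > 0 means in use, unevictable
        self.block_hash = None   # stamped at prefill completion, kept after release

class PrefixPool:
    def __init__(self, num_blocks: int):
        self._index: dict[int, Block] = {}   # hash -> filled block
        self._free_lru = OrderedDict()       # block_id -> Block, least-recently released first
        for i in range(num_blocks):
            self._free_lru[i] = Block(i)

    def acquire(self, block_hash: int) -> Block:
        block = self._index.get(block_hash)
        if block is not None:                 # hit: share the already-filled KV block
            self._free_lru.pop(block.block_id, None)
        else:                                 # miss: reclaim the least-recently-released block
            _, block = self._free_lru.popitem(last=False)
            if block.block_hash is not None:
                self._index.pop(block.block_hash, None)
            block.block_hash = None
        block.refcount += 1
        return block

    def release(self, block: Block) -> None:
        block.refcount -= 1
        if block.refcount == 0:               # hash is kept, so the block stays hittable
            self._free_lru[block.block_id] = block
```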

Integration is driver-local and event-driven. KVCacheManager lives only on the driver; non-driver PP ranks are unchanged. The scheduler calls plan_prefix_cache while sizing chunked-prefill, hands the plan to _activate_request → commit_prefix_plan, and stamps via mark_blocks_filled at prefill completion (idempotent, so no per-request flag). Invalidation is the surgical bit: receive_weights / sleep / load_state_dict call clear_cache_index(), which drops only the hash → block lookup — refcounts and the block pool are untouched, so in-flight requests keep their blocks (no premature free) and future requests cleanly miss instead of hitting stale KV computed under old weights. Observability is a single collective RPC (get_prefix_cache_stats) with a base-class stub so it works regardless of rank topology.
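Sketched end to end, the driver-side call sequence looks roughly like this (a sketch following the description above; apart from the three primitives and clear_cache_index, the helper names and argument shapes are assumptions):

```python
def schedule_and_activate(kv_mgr, request):
    # While sizing chunked-prefill: hash the prompt and count cached-prefix hits.
    block_hashes, num_hit_blocks = kv_mgr.plan_prefix_cache(request.input_ids_list)

    # _activate_request: incref hit blocks, allocate the rest; prefill starts after
    # the hit prefix, so those tokens are never recomputed.
    kv_mgr.commit_prefix_plan(request, block_hashes, num_hit_blocks)

    # At prefill completion: publish this request's newly filled blocks into the
    # hash index (idempotent, so no per-request "already stamped" flag is needed).
    kv_mgr.mark_blocks_filled(request, block_hashes)

def on_weights_changed(kv_mgr):
    # receive_weights / sleep / load_state_dict: drop only the hash -> block lookup.
    # Refcounts and the pool are untouched, so in-flight requests keep their blocks
    # and future requests miss instead of hitting stale KV.
    kv_mgr.clear_cache_index()
```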

Scope is deliberately tight. ~840 LOC across 11 files; the manager rewrite is the bulk and the rest is thin scheduler / RPC plumbing plus a benchmark switch. PP > 1 falls back to the original allocator transparently — the driver-side index is rank-agnostic, but verifying KV-tensor coherence across ranks deserves its own PR (est. < 50 LOC + tests). Two other deferrals: eager stamping (would let concurrent same-prefix requests share the first miss; needs a "filling" state machine, unnecessary for RL-style warm-then-batch workloads) and a stronger hash (xxhash/blake2b; Python's hash() collision probability at 1024 blocks is ~3×10⁻¹⁴, revisit at 10⁵+ blocks).

Checklist Before Submitting

  • Read the Contribute Guide
  • Applied pre-commit checks (pre-commit run)
  • Added/updated documentation — N/A; no user-facing docs affected
  • Added tests — tests/batch_invariant_ops/test_kv_cache_manager.py (26 new tests) covers the manager API surface; tests/test_scheduler.py still green; bitwise-equivalence + throughput A/B validated via benchmarks/throughput.py.

@Ziyi-Wang force-pushed the feat/lru-prefix-cache branch from ea8560a to ea5f3f4 on May 11, 2026 at 23:58

@gemini-code-assist (bot) left a comment


Code Review

This pull request implements a content-addressed prefix cache for the KV cache manager, allowing for token-id sequence reuse across requests to improve performance. The changes include a new hashing mechanism for full blocks, LRU eviction logic, and integration with the scheduler to skip redundant prefill computations. My feedback focuses on a critical race condition between the control channel and the generation loop, potential performance bottlenecks in the scheduling loop due to repeated hashing, and the risks of using Python's built-in hash function for content-addressed caching.

Comment on lines +58 to +81
    def receive_weights(self):
        """Receive new model weights.

        Active requests are preempted (their KV is stale under the new weights
        and decode would read garbage) and the prefix cache index is dropped.
        Re-prefill happens automatically on the next schedule().
        """
        self.scheduler.reset_for_state_change()
        super().receive_weights()

    def sleep(self, tag: str = None):
        """Pause memory-saving regions.

        The KV cache region is allocated with enable_cpu_backup=False, so its
        bytes are dropped on sleep — any active block IDs become invalid. Active
        requests are preempted; the prefix cache index is dropped.
        """
        self.scheduler.reset_for_state_change()
        super().sleep(tag=tag)

    def load_state_dict(self, state_dict):
        """In-process weight load (e.g. tests). Same staleness as receive_weights()."""
        self.scheduler.reset_for_state_change()
        super().load_state_dict(state_dict)

Severity: high

The methods receive_weights, sleep, and load_state_dict introduce a significant race condition. These methods are typically called from the control channel's RPC thread, while the generation loop (_generation_loop) runs in a separate background thread.

Calling self.scheduler.reset_for_state_change() concurrently with self.scheduler.schedule() or self.scheduler.update() can lead to inconsistent state, such as KeyError in the KV cache manager or corrupted request states.

Furthermore, there is a correctness issue: receive_weights clears the cache and preempts requests, but the generation loop continues to run while super().receive_weights() is blocking to receive new weights. This allows the scheduler to re-activate requests and fill the cache with KV tensors computed using the old weights, which will be stale once the weight update completes.

You should implement a synchronization mechanism (e.g., a threading.Lock) to ensure that control commands and the generation loop do not run concurrently, or pause the generation loop during state changes.
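A minimal sketch of that suggestion — a lock serializing control RPCs against the generation loop — follows; class and attribute names here are hypothetical and not part of this PR:

```python
import threading

class EngineWithControlLock:
    def __init__(self, scheduler, model_runner):
        self.scheduler = scheduler
        self.model_runner = model_runner
        self._control_lock = threading.Lock()

    def _generation_step(self):
        # One iteration of the background generation loop: cannot interleave with
        # control-channel RPCs that mutate scheduler / KV-cache-manager state.
        with self._control_lock:
            batch = self.scheduler.schedule()
            outputs = self.model_runner.execute(batch)
            self.scheduler.update(outputs)

    def receive_weights(self, new_weights):
        # Control RPC: holding the lock keeps the generation loop from re-activating
        # requests and refilling the cache with KV computed under the old weights.
        with self._control_lock:
            self.scheduler.reset_for_state_change()
            self.model_runner.load_weights(new_weights)
```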

Comment thread on vexact/core/scheduler.py (Outdated)
# _activate_request → commit_prefix_plan (no rehash) and later to
# mark_blocks_filled (also no rehash). Eliminates the triple-hash
# that the original peek/allocate/stamp paths had.
block_hashes, num_prefix_hit_blocks = self._kv_cache_manager.plan_prefix_cache(request.input_ids_list)

Severity: medium

Calling plan_prefix_cache inside the scheduling loop for every request in the queue can become a performance bottleneck, especially with long prompts and large queues. Since the block hashes for a given input_ids_list are deterministic, they should be cached on the InferenceRequest object. Only the num_prefix_hit_blocks (which depends on the current cache state) needs to be re-evaluated during scheduling.

prev_hash = self._SEED

for i in range(num_full):
    block_hash = hash((prev_hash, tuple(token_ids[i * page_size : (i + 1) * page_size])))

Severity: medium

Using Python's built-in hash() for content-addressed caching is risky because it is not guaranteed to be collision-resistant and its output is randomized across process restarts (unless PYTHONHASHSEED is fixed). While the collision probability for 1024 blocks is low, it increases quadratically with the number of blocks. For a more robust implementation, especially as the cache size grows, consider using a faster non-cryptographic hash like xxhash or a truncated cryptographic hash.
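For reference, a stable 64-bit block hash along the lines suggested, using hashlib.blake2b from the standard library (xxhash would need the third-party xxhash package); this is a sketch, not code from the PR:

```python
import hashlib
import struct

def block_hash(prev_hash: int, block_token_ids: tuple[int, ...]) -> int:
    """Chain-hash one block with blake2b truncated to 64 bits.

    Unlike Python's built-in hash(), this is stable across process restarts
    (no PYTHONHASHSEED dependence) and far harder to collide.
    """
    h = hashlib.blake2b(digest_size=8)
    h.update(struct.pack("<q", prev_hash))
    h.update(struct.pack(f"<{len(block_token_ids)}q", *block_token_ids))
    return int.from_bytes(h.digest(), "little", signed=True)
```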

Ziyi-Wang and others added 3 commits May 12, 2026 09:30
receive_weights / sleep / load_state_dict now call clear_cache_index()
directly instead of routing through scheduler.reset_for_state_change(),
which preempted active requests and reset stats. Drop the preempt-all
helper entirely.

Rationale: the framework (verl) drains all in-flight requests before
triggering a weight update or sleep, so there is nothing to preempt at
that point. clear_cache_index alone is enough to prevent future requests
from hitting KV computed under old weights. Stats become pure lifetime
counters, giving a stable trend line across weight updates instead of a
jittery per-step ratio. Docstrings updated to spell out the drain
assumption.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>

KVCacheManager is a scheduler-level resource manager (refcount + LRU +
prefix cache index), not a batch-invariant kernel. Its tests belong
next to test_scheduler.py, not in tests/batch_invariant_ops/ which is
for flex_attention / matmul / FA4 invariance checks.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>

Block hashes are a pure function of (token_ids, page_size, _SEED), so
recomputing them every scheduling iteration for a requeued request is
wasted work — gets noticeable at long contexts (page_size=16 + 32K
prompts: ~1ms per replan × 100 retries = 100ms).

Split the API:
  - compute_block_hashes(token_ids) → list[int]   (stateless, deterministic)
  - count_prefix_hits(block_hashes) → int         (cheap O(N) dict lookups)

Scheduler caches the hashes on InferenceRequest.prefix_block_hashes after
the first compute; subsequent rounds only re-run count_prefix_hits. The
field already existed for mark_blocks_filled, so no new state is added.
preempt() already clears it (input_ids_list grows on preempt, hashes need
to be recomputed against the new sequence).

plan_prefix_cache (the convenience wrapper) is removed entirely — no
production caller after this change, so keeping it as a test-only
wrapper would just hide that tests aren't exercising the real path.

Addresses Gemini review comment on plan-per-iteration cost.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
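
A sketch of the resulting scheduler-side pattern (the two method names and the prefix_block_hashes field follow the commit message; everything else is illustrative):

```python
def plan_prefix_for(kv_mgr, request):
    # Hashes are a pure function of (token_ids, page_size, seed): compute them once
    # and cache them on the request so requeued requests don't pay the O(prompt) cost again.
    if request.prefix_block_hashes is None:
        request.prefix_block_hashes = kv_mgr.compute_block_hashes(request.input_ids_list)

    # The hit count depends on the current cache contents, so it is re-evaluated
    # every scheduling round; this is just cheap O(num_blocks) dict lookups.
    return kv_mgr.count_prefix_hits(request.prefix_block_hashes)
```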
