Skip to content

Refactored attention module and move KV Caching outside model#85

Merged
prajwal1210 merged 7 commits into
developfrom
kvfix
Feb 14, 2026
Merged

Refactored attention module and move KV Caching outside model#85
prajwal1210 merged 7 commits into
developfrom
kvfix

Conversation

@prajwal1210

Copy link
Copy Markdown
Collaborator

To support online inference and for better separation of concerns, the model should not be managing the KV cache.

Top Row - Current Branch
Bottom Row - pre GPU KV-caching branch
Screenshot 2025-12-14 at 13 57 00

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This pull request refactors the attention module to move KV cache management outside the model layer for better separation of concerns and to support online inference. The key architectural change extracts KV cache logic from the model into a dedicated KVSlotsManager abstraction that handles both slot allocation and cache policy management.

Key Changes

  • Introduced new KV cache infrastructure with KVSlotsManager, SlotAllocator, and policy-based cache management (ContiguousKVCachePolicy, PagedKVCachePolicy)
  • Refactored C++ KVCacheManager to use a cleaner deque-based API with methods for allocate/extend/free sequences
  • Reorganized attention backend implementations into yalis/attention/backend_impl/ subdirectory
  • Updated model forward signature to accept block_table and token_counter as parameters instead of managing them internally

Reviewed changes

Copilot reviewed 12 out of 15 changed files in this pull request and generated 20 comments.

Show a summary per file
File Description
yalis/external/model.py Removed internal KV cache management logic; model now accepts block_table and token_counter as parameters
yalis/engine.py Added KVSlotsManager initialization and integration; manages slot allocation and cache updates during prefill/decode
yalis/attention/kv_cache/slot_allocator.py New class for allocating stable row IDs with smallest-available policy
yalis/attention/kv_cache/kv_slots_manager.py New unified manager combining slot allocation with cache policy abstraction
yalis/attention/kv_cache/kv_cache_policy.py New policy protocol and implementations for contiguous and paged KV cache strategies
yalis/attention/paged_kv_cache.cpp Refactored C++ implementation to use std::deque and added allocate/extend/free sequence methods
yalis/attention/utils/flex_utils.py New utility file for flex attention mask creation
yalis/attention/utils/flash_utils.py Added actual_seqlens parameter to paged KV cache update kernel
yalis/attention/backend_impl/flash.py Updated to use new utility imports and added actual_seqlens support
yalis/attention/backend_impl/sdpa_and_flex.py Updated registry import path
yalis/attention/backends.py Updated to import from backend_impl subdirectory
yalis/attention/init.py Added actual_seqlens parameter to attention_wrapper
Comments suppressed due to low confidence (1)

yalis/attention/backend_impl/flash.py:161

  • The comment has a grammatical error. It should read "This is a clever way" instead of "This is clever way".

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread yalis/attention/kv_cache/kv_cache_policy.py Outdated
Comment thread yalis/attention/kv_cache/slot_allocator.py Outdated
Comment thread yalis/attention/paged_kv_cache.cpp Outdated
Comment on lines +47 to +56
def update(self, slot_ids: List[int], n_new_tokens: Union[int, List[int]]) -> Optional[torch.Tensor]:
if len(slot_ids) == 0:
return None
index = torch.tensor(slot_ids, dtype=torch.int64, device=self._seq_lens.device)
if isinstance(n_new_tokens, int):
self._seq_lens.index_add_(0, index, torch.full((len(slot_ids),), int(n_new_tokens), dtype=torch.int32, device=self._seq_lens.device))
else:
dt = torch.tensor(n_new_tokens, dtype=torch.int32, device=self._seq_lens.device)
self._seq_lens.index_add_(0, index, dt)
return None

Copilot AI Dec 14, 2025

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inconsistent return type with Protocol definition. The update method in the Protocol (line 10) specifies -> None as the return type, but the implementation returns Optional[torch.Tensor]. Either update the Protocol or change the implementation to maintain consistency.

Copilot uses AI. Check for mistakes.
Comment thread yalis/attention/kv_cache/kv_slots_manager.py Outdated
Comment thread yalis/attention/kv_cache/kv_slots_manager.py Outdated
Comment thread yalis/external/model.py Outdated
Comment on lines +140 to +141
# TODO: Remove this once we have a way support
# this for speculative decoding without token counter

Copilot AI Dec 14, 2025

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment has a grammatical error and missing article. It should read "This is a clever way for now to avoid having to pad" instead of "This is clever way for now to not have to pad".

Suggested change
# TODO: Remove this once we have a way support
# this for speculative decoding without token counter
# TODO: Remove this once we have a way to support this for speculative decoding without token counter

Copilot uses AI. Check for mistakes.
Comment thread yalis/engine.py Outdated
Comment thread yalis/attention/backends.py Outdated
# These imports trigger @register_attention decorators
from . import sdpa_and_flex # noqa: F401
from . import flash # noqa: F401
from yalis.attention.backend_impl.sdpa_and_flex import sdpa_attention # noqa: F401

Copilot AI Dec 14, 2025

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'sdpa_attention' is not used.

Copilot uses AI. Check for mistakes.
Comment thread yalis/attention/backends.py
@prajwal1210 prajwal1210 requested a review from aKssup January 30, 2026 15:46
@prajwal1210

Copy link
Copy Markdown
Collaborator Author

There is a 5% TBT performance regression:
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================== short test summary info ============================
FAILED tests/performance/test_perf_regression.py::test_perf_regression[32-128-8]
FAILED tests/performance/test_perf_regression.py::test_perf_regression[128-128-8]
============= 2 failed, 6 passed, 2 warnings in 276.37s (0:04:36) ==============

This is due to moving from CPU managed KV-cache but we need that

@prajwal1210

Copy link
Copy Markdown
Collaborator Author

Code review

Found 1 issue:

  1. Active printf debug statement left in production code inside a loop in free_sequence(). All surrounding debug statements are properly commented out, but this one is active and will spam stdout every time a sequence is freed -- once per page in the sequence.

for (int64_t i = 0; i < page_count; ++i) {
printf("free_sequence: i=%ld\n", i);
int32_t page = acc[i];
if (page != -1) {
free_pages_.push_back(page);

🤖 Generated with Claude Code

- If this code review was useful, please react with 👍. Otherwise, react with 👎.

@prajwal1210

Copy link
Copy Markdown
Collaborator Author

Code review (additional lower-confidence findings)

Found 2 additional issues (confidence 75/100 each):

  1. PR reverts the graph break fix from PR Paged kvcache graph break fix #77 ("Paged kvcache graph break fix"). free_pages is changed back from torch::Tensor to std::deque<int32_t>, the torch.ops.yalis library registration is removed, and the tensor-based update_block_table_impl is deleted. This will reintroduce torch.compile graph breaks during paged KV cache management.

block_table_ = torch::full(
{batch_size_, max_num_blocks_per_seq_}, -1,
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));

  1. Two backup files (paged_kv_cache_old.cpp and paged_kv_cache_latest_old.cpp) are added to version control. These are complete copies of previous versions of the C++ file (~600 lines total) with no functional purpose. (CLAUDE.md says "Touch only what you must. Clean up only your own mess." and "Every changed line should trace directly to the user's request.")

#include <tuple>
#include <deque>
#include <iostream>
#include <stdexcept>
#include <torch/extension.h>

#include <deque>
#include <vector>
#include <iostream>
#include <stdexcept>
#include <torch/extension.h>

🤖 Generated with Claude Code

- If this code review was useful, please react with 👍. Otherwise, react with 👎.

@prajwal1210 prajwal1210 merged commit c82fcee into develop Feb 14, 2026
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants