Refactored attention module and moved KV caching outside model #85
Pull request overview
This pull request refactors the attention module to move KV cache management outside the model layer for better separation of concerns and to support online inference. The key architectural change extracts KV cache logic from the model into a dedicated KVSlotsManager abstraction that handles both slot allocation and cache policy management.
Key Changes
- Introduced new KV cache infrastructure with KVSlotsManager, SlotAllocator, and policy-based cache management (ContiguousKVCachePolicy, PagedKVCachePolicy)
- Refactored C++ KVCacheManager to use a cleaner deque-based API with methods for allocate/extend/free sequences
- Reorganized attention backend implementations into the yalis/attention/backend_impl/ subdirectory
- Updated model forward signature to accept block_table and token_counter as parameters instead of managing them internally
Reviewed changes
Copilot reviewed 12 out of 15 changed files in this pull request and generated 20 comments.
| File | Description |
|---|---|
| yalis/external/model.py | Removed internal KV cache management logic; model now accepts block_table and token_counter as parameters |
| yalis/engine.py | Added KVSlotsManager initialization and integration; manages slot allocation and cache updates during prefill/decode |
| yalis/attention/kv_cache/slot_allocator.py | New class for allocating stable row IDs with smallest-available policy |
| yalis/attention/kv_cache/kv_slots_manager.py | New unified manager combining slot allocation with cache policy abstraction |
| yalis/attention/kv_cache/kv_cache_policy.py | New policy protocol and implementations for contiguous and paged KV cache strategies |
| yalis/attention/paged_kv_cache.cpp | Refactored C++ implementation to use std::deque and added allocate/extend/free sequence methods |
| yalis/attention/utils/flex_utils.py | New utility file for flex attention mask creation |
| yalis/attention/utils/flash_utils.py | Added actual_seqlens parameter to paged KV cache update kernel |
| yalis/attention/backend_impl/flash.py | Updated to use new utility imports and added actual_seqlens support |
| yalis/attention/backend_impl/sdpa_and_flex.py | Updated registry import path |
| yalis/attention/backends.py | Updated to import from backend_impl subdirectory |
| `yalis/attention/__init__.py` | Added actual_seqlens parameter to attention_wrapper |
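
The "smallest-available policy" described for slot_allocator.py can be illustrated with a min-heap of returned IDs. This is only a minimal sketch under stated assumptions, not the code from this PR: the class name `MinSlotAllocator` and its `allocate`/`free` methods are hypothetical.

```python
import heapq


class MinSlotAllocator:
    """Hypothetical sketch: always hands out the smallest free row ID."""

    def __init__(self, capacity: int) -> None:
        self._capacity = capacity
        self._next_fresh = 0   # smallest ID that has never been handed out
        self._freed = []       # min-heap of IDs that were handed out and returned
        self._in_use = set()   # IDs currently allocated

    def allocate(self) -> int:
        # Every freed ID is smaller than _next_fresh, so the heap goes first.
        if self._freed:
            slot = heapq.heappop(self._freed)
        elif self._next_fresh < self._capacity:
            slot = self._next_fresh
            self._next_fresh += 1
        else:
            raise RuntimeError("no free slots")
        self._in_use.add(slot)
        return slot

    def free(self, slot: int) -> None:
        if slot not in self._in_use:
            raise ValueError(f"slot {slot} is not allocated")
        self._in_use.remove(slot)
        heapq.heappush(self._freed, slot)
```

A row ID stays stable for the lifetime of a sequence because it only returns to the heap on `free`; reusing the smallest ID first keeps live rows packed toward the start of the cache.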
Comments suppressed due to low confidence (1)
yalis/attention/backend_impl/flash.py:161
- The comment has a grammatical error. It should read "This is a clever way" instead of "This is clever way".
```python
def update(self, slot_ids: List[int], n_new_tokens: Union[int, List[int]]) -> Optional[torch.Tensor]:
    if len(slot_ids) == 0:
        return None
    index = torch.tensor(slot_ids, dtype=torch.int64, device=self._seq_lens.device)
    if isinstance(n_new_tokens, int):
        self._seq_lens.index_add_(0, index, torch.full((len(slot_ids),), int(n_new_tokens), dtype=torch.int32, device=self._seq_lens.device))
    else:
        dt = torch.tensor(n_new_tokens, dtype=torch.int32, device=self._seq_lens.device)
        self._seq_lens.index_add_(0, index, dt)
    return None
```
Inconsistent return type with Protocol definition. The update method in the Protocol (line 10) specifies -> None as the return type, but the implementation returns Optional[torch.Tensor]. Either update the Protocol or change the implementation to maintain consistency.
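
One way to resolve this, sketched below, is to annotate the implementation with `-> None` so it matches the Protocol, since every path in the snippet returns `None` anyway. The names `CachePolicySketch` and `ListBackedPolicy` are hypothetical, and a plain Python list stands in for the `_seq_lens` tensor so the example runs without torch.

```python
from typing import List, Protocol, Union


class CachePolicySketch(Protocol):
    """Hypothetical protocol; only the update() return type mirrors the review comment."""

    def update(self, slot_ids: List[int], n_new_tokens: Union[int, List[int]]) -> None:
        ...


class ListBackedPolicy:
    """Toy implementation that tracks sequence lengths in a list (assumption: no torch)."""

    def __init__(self, num_slots: int) -> None:
        self.seq_lens = [0] * num_slots

    def update(self, slot_ids: List[int], n_new_tokens: Union[int, List[int]]) -> None:
        # Broadcast a scalar count to every slot, like the torch.full branch above.
        if isinstance(n_new_tokens, int):
            n_new_tokens = [n_new_tokens] * len(slot_ids)
        if len(n_new_tokens) != len(slot_ids):
            raise ValueError("slot_ids and n_new_tokens must have the same length")
        for slot, n in zip(slot_ids, n_new_tokens):
            self.seq_lens[slot] += n


def advance(policy: CachePolicySketch, slot_ids: List[int], n: int) -> None:
    # A static type checker accepts ListBackedPolicy here because the signatures agree.
    policy.update(slot_ids, n)
```

If callers ever need the updated lengths, the alternative is to widen the Protocol's return type instead; either way the two declarations should agree.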
```python
# TODO: Remove this once we have a way support
# this for speculative decoding without token counter
```
The comment has a grammatical error: "a way support" is missing "to". It should read "a way to support".
```diff
- # TODO: Remove this once we have a way support
- # this for speculative decoding without token counter
+ # TODO: Remove this once we have a way to support this for speculative decoding without token counter
```
```python
# These imports trigger @register_attention decorators
from . import sdpa_and_flex  # noqa: F401
from . import flash  # noqa: F401
from yalis.attention.backend_impl.sdpa_and_flex import sdpa_attention  # noqa: F401
```
Import of 'sdpa_attention' is not used.
There is a 5% TBT (time between tokens) performance regression. This is due to moving away from the CPU-managed KV cache, but we need that change.
Code review
Found 1 issue:
yalis/yalis/attention/paged_kv_cache.cpp Lines 288 to 292 in 17ee838
Code review (additional lower-confidence findings)
Found 2 additional issues (confidence 75/100 each):
yalis/yalis/attention/paged_kv_cache.cpp Lines 30 to 34 in 17ee838
yalis/yalis/attention/paged_kv_cache_old.cpp Lines 1 to 5 in 17ee838
yalis/yalis/attention/paged_kv_cache_latest_old.cpp Lines 1 to 5 in 17ee838
To support online inference and for better separation of concerns, the model should not be managing the KV cache.
Top Row - Current Branch

Bottom Row - pre GPU KV-caching branch