diff --git a/tests/basic_correctness/conftest.py b/tests/basic_correctness/conftest.py index db7c8efb..6302fbe8 100644 --- a/tests/basic_correctness/conftest.py +++ b/tests/basic_correctness/conftest.py @@ -115,13 +115,13 @@ def yalis_engine(model_id, dtype, attn_backend): """Create a standard Yalis LLMEngine.""" model_config = ModelConfig(model_name=model_id, precision=dtype.yalis) inference_config = InferenceConfig( - max_batch_size=4, + max_batch_size=8, max_length_of_generated_sequences=2048, top_p=0.0, temperature=0.0, tp_dims=None, attention_backend=attn_backend.yalis, - use_paged_kv_caching=False, + use_paged_kv_caching=(attn_backend.yalis == "flash"), ) return LLMEngine( model_config=model_config, inference_config=inference_config diff --git a/yalis/attention/__init__.py b/yalis/attention/__init__.py index c2ad8462..890244f1 100644 --- a/yalis/attention/__init__.py +++ b/yalis/attention/__init__.py @@ -4,7 +4,6 @@ from yalis.constants import EnginePhase from .registry import get_attention from .backends import AttentionBackend -from .masking import create_block_mask # noqa: F401 def attention_wrapper( @@ -15,6 +14,7 @@ def attention_wrapper( k_cache: Optional[torch.Tensor] = None, v_cache: Optional[torch.Tensor] = None, cache_seqlens: Optional[torch.Tensor] = None, + actual_seqlens: Optional[torch.Tensor] = None, block_table: Optional[torch.Tensor] = None, rotary_cos: Optional[torch.Tensor] = None, rotary_sin: Optional[torch.Tensor] = None, @@ -32,6 +32,7 @@ def attention_wrapper( k_cache=k_cache, v_cache=v_cache, cache_seqlens=cache_seqlens, + actual_seqlens=actual_seqlens, block_table=block_table, rotary_cos=rotary_cos, rotary_sin=rotary_sin, diff --git a/yalis/attention/backend_impl/__init__.py b/yalis/attention/backend_impl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/yalis/attention/flash.py b/yalis/attention/backend_impl/flash.py similarity index 92% rename from yalis/attention/flash.py rename to yalis/attention/backend_impl/flash.py index 45a2ee7c..6aa77e59 100644 --- a/yalis/attention/flash.py +++ b/yalis/attention/backend_impl/flash.py @@ -1,9 +1,10 @@ -from flash_attn import flash_attn_with_kvcache -import torch from typing import Sequence, Optional -from .registry import register_attention -from .update_kv_cache import update_paged_kv_cache +import torch + +from flash_attn import flash_attn_with_kvcache from flash_attn.ops.triton.rotary import apply_rotary +from yalis.attention.utils.flash_utils import update_paged_kv_cache +from yalis.attention.registry import register_attention from yalis.constants import EnginePhase @@ -93,6 +94,7 @@ def flash_attention( k_cache: Optional[torch.Tensor] = None, v_cache: Optional[torch.Tensor] = None, cache_seqlens: Optional[torch.Tensor] = None, + actual_seqlens: Optional[torch.Tensor] = None, block_table: Optional[torch.Tensor] = None, rotary_cos: Optional[torch.Tensor] = None, rotary_sin: Optional[torch.Tensor] = None, @@ -145,6 +147,7 @@ def flash_attention( v=v, block_table=block_table, cache_seq_len=cache_seqlens, + actual_seqlens=actual_seqlens, k_cache=k_cache, v_cache=v_cache, ) @@ -153,6 +156,13 @@ def flash_attention( # subsequent layers to update their kv-caches. cache_seqlens = cache_seqlens + T + if phase == EnginePhase.PREFILL: + # This is clever way for now to not have to pad the actual KV-Cache + # and just use the k and v tensors directly in prefill + k_cache = k + v_cache = v + block_table = None + k, v = None, None return torch_compile_compatible_flash_attention( diff --git a/yalis/attention/sdpa_and_flex.py b/yalis/attention/backend_impl/sdpa_and_flex.py similarity index 99% rename from yalis/attention/sdpa_and_flex.py rename to yalis/attention/backend_impl/sdpa_and_flex.py index b03de30f..6cc9f032 100644 --- a/yalis/attention/sdpa_and_flex.py +++ b/yalis/attention/backend_impl/sdpa_and_flex.py @@ -8,7 +8,7 @@ from axonn import axonn as ax from axonn.intra_layer.communication import Drop, Gather -from .registry import register_attention +from yalis.attention.registry import register_attention from yalis.constants import EnginePhase diff --git a/yalis/attention/backends.py b/yalis/attention/backends.py index a825ed33..1960c47e 100644 --- a/yalis/attention/backends.py +++ b/yalis/attention/backends.py @@ -1,6 +1,8 @@ # These imports trigger @register_attention decorators -from . import sdpa_and_flex # noqa: F401 -from . import flash # noqa: F401 +from yalis.attention.backend_impl.sdpa_and_flex import ( # noqa: F401 + sdpa_attention, +) +from yalis.attention.backend_impl.flash import flash_attention # noqa: F401 from enum import Enum diff --git a/yalis/attention/kv_cache/__init__.py b/yalis/attention/kv_cache/__init__.py new file mode 100644 index 00000000..4d391ecc --- /dev/null +++ b/yalis/attention/kv_cache/__init__.py @@ -0,0 +1,15 @@ +from .kv_cache_policy import ( + KVCachePolicy, + ContiguousKVCachePolicy, + PagedKVCachePolicy, +) +from .kv_slots_manager import KVSlotsManager +from .slot_allocator import SlotAllocator + +__all__ = [ + "KVCachePolicy", + "ContiguousKVCachePolicy", + "PagedKVCachePolicy", + "KVSlotsManager", + "SlotAllocator", +] diff --git a/yalis/attention/kv_cache/kv_cache_policy.py b/yalis/attention/kv_cache/kv_cache_policy.py new file mode 100644 index 00000000..319a8527 --- /dev/null +++ b/yalis/attention/kv_cache/kv_cache_policy.py @@ -0,0 +1,180 @@ +from typing import Optional, Protocol, Union, List + +import torch + + +class KVCachePolicy(Protocol): + def allocate( + self, slot_ids: List[int], prompt_lengths: torch.Tensor + ) -> None: ... + + def update( + self, slot_ids: List[int], n_new_tokens: Union[int, List[int]] + ) -> None: ... + + def release(self, slot_ids: List[int]) -> None: ... + + def view(self, slot_ids: List[int]) -> Optional[torch.Tensor]: ... + + def reset(self) -> None: ... + + def lengths(self, slot_ids: List[int]) -> Optional[torch.Tensor]: + """ + Return per-row sequence lengths for provided slot ids. + Returns None if the policy does not track lengths internally. + """ + ... + + +class ContiguousKVCachePolicy: + """ + No-op manager for contiguous KV-cache layout. + Engine can call allocate/update/view to keep a uniform API. + Tracks per-slot sequence lengths locally. + """ + + def __init__( + self, capacity: int, device: Optional[torch.device] = None + ) -> None: + self._seq_lens = torch.zeros( + capacity, dtype=torch.int32, device="cuda" + ) + + def allocate( + self, slot_ids: List[int], prompt_lengths: torch.Tensor + ) -> Optional[torch.Tensor]: + if prompt_lengths.numel() == 0 or len(slot_ids) == 0: + return None + index = torch.tensor( + slot_ids, dtype=torch.int64, device=self._seq_lens.device + ) + self._seq_lens.index_copy_( + 0, + index, + prompt_lengths.to(dtype=torch.int32, device=self._seq_lens.device), + ) + return None + + def update( + self, slot_ids: List[int], n_new_tokens: Union[int, List[int]] + ) -> Optional[torch.Tensor]: + if len(slot_ids) == 0: + return None + index = torch.tensor( + slot_ids, dtype=torch.int64, device=self._seq_lens.device + ) + if isinstance(n_new_tokens, int): + self._seq_lens.index_add_( + 0, + index, + torch.full( + (len(slot_ids),), + int(n_new_tokens), + dtype=torch.int32, + device=self._seq_lens.device, + ), + ) + else: + dt = torch.tensor( + n_new_tokens, dtype=torch.int32, device=self._seq_lens.device + ) + self._seq_lens.index_add_(0, index, dt) + return None + + def release(self, slot_ids: List[int]) -> None: + if len(slot_ids) == 0: + return None + index = torch.tensor( + slot_ids, dtype=torch.int64, device=self._seq_lens.device + ) + self._seq_lens.index_fill_(0, index, 0) + return None + + def view(self, slot_ids: List[int]) -> Optional[torch.Tensor]: + return None + + def reset(self) -> None: + self._seq_lens.zero_() + return None + + def lengths(self, slot_ids: List[int]) -> Optional[torch.Tensor]: + index = torch.tensor( + slot_ids, dtype=torch.int64, device=self._seq_lens.device + ) + return self._seq_lens.index_select(0, index) + + +class PagedKVCachePolicy: + """ + Thin Python wrapper over the C++ paged KV cache allocator (kvcache_manager) + Owned by the Engine. Provides a stable API for page/block table management + """ + + def __init__( + self, + batch_size: int, + max_num_blocks_per_seq: int, + num_blocks: int, + page_block_size: int, + verbose: bool = False, + ) -> None: + # Lazy import + from kvcache_manager import KVCacheManager as _CppKVCacheManager + + self._impl = _CppKVCacheManager( + batch_size, + max_num_blocks_per_seq, + num_blocks, + page_block_size, + ) + self._verbose = verbose + + def block_table(self) -> torch.Tensor: + return self._impl.block_table() + + def allocate( + self, slot_ids: List[int], prompt_lengths: torch.Tensor + ) -> None: + for i, slot in enumerate(slot_ids): + self._impl.allocate_sequence( + int(slot), int(prompt_lengths[i].item()) + ) + + def update( + self, slot_ids: List[int], n_new_tokens: Union[int, List[int]] + ) -> None: + if isinstance(n_new_tokens, int): + for slot in slot_ids: + self._impl.extend_sequence(int(slot), int(n_new_tokens)) + elif isinstance(n_new_tokens, list): + if len(n_new_tokens) != len(slot_ids): + raise ValueError( + "n_new_tokens list must match slot_ids length" + ) + for slot, dt in zip(slot_ids, n_new_tokens): + self._impl.extend_sequence(int(slot), int(dt)) + else: + raise TypeError("n_new_tokens must be int or list[int]") + + def release(self, slot_ids: List[int]) -> None: + for slot in slot_ids: + self._impl.free_sequence(int(slot)) + + def view(self, slot_ids: List[int]) -> torch.Tensor: + bt = self._impl.block_table() + index = torch.tensor(slot_ids, dtype=torch.int64, device=bt.device) + return bt.index_select(0, index) + + def reset(self) -> None: + self._impl.reset() + + def lengths(self, slot_ids: List[int]) -> torch.Tensor: + """ + Return current per-row token counts from the C++ manager (CPU int64), + indexed by the provided slot ids. + """ + all_counts = self._impl.tokens_assigned() + index = torch.tensor( + slot_ids, dtype=torch.int32, device=all_counts.device + ) + return all_counts.index_select(0, index) diff --git a/yalis/attention/kv_cache/kv_slots_manager.py b/yalis/attention/kv_cache/kv_slots_manager.py new file mode 100644 index 00000000..1f28bb5e --- /dev/null +++ b/yalis/attention/kv_cache/kv_slots_manager.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +from typing import List, Optional, Tuple, Union + +import torch + +from yalis.attention.kv_cache.kv_cache_policy import ( + ContiguousKVCachePolicy, + PagedKVCachePolicy, +) +from yalis.attention.kv_cache.slot_allocator import SlotAllocator + + +class KVSlotsManager: + """ + Unified row/slot + KV-cache manager. + - Owns row_id assignment (via SlotAllocator) + - Hides paged vs contiguous cache differences + - Minimal Engine-facing API: + * allocate(req_id, prompt_lengths) -> rows, block_table_rows|None + * update(delta_tokens, rows) -> block_table_rows|None + * free(req_id) -> freed_rows + * view(rows) -> block_table_rows|None + * reset() + """ + + def __init__( + self, + capacity: int, + paged: bool, + max_num_blocks_per_seq: Optional[int] = None, + num_blocks: Optional[int] = None, + page_block_size: Optional[int] = None, + verbose: bool = False, + ) -> None: + self.slot_allocator = SlotAllocator(capacity) + if paged: + assert ( + max_num_blocks_per_seq is not None + and num_blocks is not None + and page_block_size is not None + ), "paged requires max_num_blocks_per_seq, num_blocks, page_block_size" # noqa: E501 + self.cache_policy = PagedKVCachePolicy( + batch_size=capacity, + max_num_blocks_per_seq=max_num_blocks_per_seq, + num_blocks=num_blocks, + page_block_size=page_block_size, + verbose=verbose, + ) + self.paged = True + else: + self.cache_policy = ContiguousKVCachePolicy(capacity=capacity) + self.paged = False + + def allocate( + self, + req_ids: List[str], + prompt_lengths: torch.Tensor, + ) -> Tuple[List[int], torch.Tensor]: + """ + Allocate one row per sequence in prompt_lengths for request req_ids. + If there are no free slots, return an empty list. + Returns list of slot_ids for the allocated rows. + """ + if prompt_lengths.dim() != 1: + raise ValueError("prompt_lengths must be 1D [B]") + B = int(prompt_lengths.size(0)) + assert ( + len(req_ids) == B + ), "req_ids and prompt_lengths must have the same length" + n = min(B, self.slot_allocator.free_count()) + rows: List[int] = [] + for i in range(n): + slot_id = self.slot_allocator.allocate(req_ids[i]) + rows.append(slot_id) + if len(rows) > 0: + self.cache_policy.allocate(rows, prompt_lengths[: len(rows)]) + return rows + + def update( + self, + rows: Union[List[int], torch.Tensor], # [B] slot_ids + delta_tokens: Union[int, torch.Tensor], + ) -> None: + """ + Update KV-cache page allocations for the provided rows by delta_tokens. + Returns block_table slice for rows when paged; None for contiguous. + """ + if isinstance(rows, torch.Tensor): + rows = rows.tolist() + if isinstance(delta_tokens, int): + self.cache_policy.update(rows, delta_tokens) + elif isinstance(delta_tokens, torch.Tensor): + self.cache_policy.update(rows, delta_tokens.tolist()) + else: + raise TypeError("delta_tokens must be int or 1D torch.Tensor") + + def allocate_for_rows( + self, + rows: Union[List[int], torch.Tensor], + prompt_lengths: Union[List[int], torch.Tensor], + ) -> None: + """ + Allocate KV-cache pages for already-assigned rows with given prompt + lengths. Does not change row ownership; only manages page assignment. + Used for testing. + """ + if isinstance(rows, torch.Tensor): + slot_ids = rows.tolist() + else: + slot_ids = rows # type: ignore[assignment] + if isinstance(prompt_lengths, list): + lengths_t = torch.tensor(prompt_lengths, dtype=torch.long) + elif isinstance(prompt_lengths, torch.Tensor): + lengths_t = prompt_lengths + else: + raise TypeError("prompt_lengths must be List[int] or torch.Tensor") + self.cache_policy.allocate(slot_ids, lengths_t) + + def free(self, req_id: str) -> int | None: + """ + Free all rows owned by req_id and release paged allocations if any. + Returns list of freed row ids. + """ + slot_id = self.slot_allocator.free(req_id) + if slot_id is not None: + self.cache_policy.release([slot_id]) + return slot_id + + def view( + self, ids: List[int] | List[str], is_request_ids: bool = False + ) -> Optional[torch.Tensor]: + """ + Return block_table slice for rows when paged; None for contiguous. + """ + if is_request_ids: + slot_ids = [ + self.slot_allocator.get_slot_id(req_id) for req_id in ids + ] + else: + slot_ids = ids + return self.cache_policy.view(slot_ids) + + def reset(self) -> None: + self.slot_allocator.reset() + self.cache_policy.reset() + + def lengths(self, rows: Union[List[int], torch.Tensor]) -> torch.Tensor: + """ + Return current per-row sequence lengths. + Delegates to policy. + """ + if isinstance(rows, list): + rows_t = torch.tensor(rows, dtype=torch.int64, device="cuda") + else: + rows_t = rows.to(torch.int64, device="cuda") + return self.cache_policy.lengths(rows_t.tolist()) diff --git a/yalis/attention/kv_cache/slot_allocator.py b/yalis/attention/kv_cache/slot_allocator.py new file mode 100644 index 00000000..14b15e4a --- /dev/null +++ b/yalis/attention/kv_cache/slot_allocator.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from typing import Dict, Set + +from yalis.serving.logger import get_logger + +logger = get_logger("slot_allocator") + + +class SlotAllocator: + """ + Sequence row allocator for online serving. + - Hands out stable row ids in [0..capacity-1] + - Smallest-available policy keeps active rows dense + """ + + def __init__(self, capacity: int) -> None: + if capacity <= 0: + raise ValueError("capacity must be > 0") + + logger.info(f"SlotAllocator capacity: {capacity}") + self._capacity: int = capacity + self._free: Set[int] = set(range(capacity)) + self._req_to_slot_id: Dict[str, int] = {} + + @property + def capacity(self) -> int: + return self._capacity + + def free_count(self) -> int: + return len(self._free) + + def allocate(self, req_id: str) -> int: + if self.free_count() == 0: + raise RuntimeError( + "insufficient free slots. Call free_count() first" + "to check if there are any free slots." + ) + logger.debug( + f"SlotAllocator allocate req_id: {req_id} " + f"free_count: {self.free_count()}, _free: {self._free}" + ) + slot_id = min(self._free) + self._free.remove(slot_id) + self._req_to_slot_id[req_id] = slot_id + return slot_id + + def free(self, req_id: str) -> int | None: + slot_id = self._req_to_slot_id.pop(req_id, None) + if slot_id is not None: + self._free.add(slot_id) + return slot_id + + def get_slot_id(self, req_id: str) -> int: + return self._req_to_slot_id[req_id] + + def reset(self) -> None: + self._free = set(range(self._capacity)) + self._req_to_slot_id.clear() diff --git a/yalis/attention/paged_kv_cache.cpp b/yalis/attention/paged_kv_cache.cpp index 562f8b42..ddfb39d4 100644 --- a/yalis/attention/paged_kv_cache.cpp +++ b/yalis/attention/paged_kv_cache.cpp @@ -1,18 +1,14 @@ -#include #include -#include +#include #include #include #include -#include -#include -#include -#include // KVCacheManager manages the assignment of global pages (blocks) for each // sequence in the batch. It maintains a block table (a tensor of shape // [batch_size, max_num_blocks_per_seq]) and a tensor tokens_assigned_ to track -// how many tokens have been pushed into the KV cache for each sequence. +// how many tokens have been pushed into the KV cache for each sequence. A FIFO +// queue of free pages is maintained via a std::deque. class KVCacheManager { public: // Constructor. @@ -30,6 +26,7 @@ class KVCacheManager { // Initialize the block table tensor of shape [batch_size, // max_num_blocks_per_seq] with all elements set to -1 (indicating "no page // assigned"), as int32. + block_table_ = torch::full( {batch_size_, max_num_blocks_per_seq_}, -1, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); @@ -37,87 +34,134 @@ class KVCacheManager { // Initialize tokens_assigned_ which tracks how many tokens are currently // in the KV cache for each sequence. Initially, every sequence has 0 // tokens. - tokens_assigned_ = torch::zeros({batch_size_}, torch::TensorOptions() - .dtype(torch::kInt64).device(torch::kCUDA)); + // TODO: This should be in pinned memory + tokens_assigned_ = torch::zeros({batch_size_}, torch::TensorOptions().dtype( + torch::kInt32).device(torch::kCUDA)); // Initialize the FIFO queue of free pages with indices from 0 to num_blocks // - 1. - this->free_pages = torch::arange( - num_blocks_, - torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA) - ); - // FIFO counter of free pages - this->next_free = torch::zeros( - {1}, torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA) - ); + for (int32_t i = 0; i < static_cast(num_blocks_); i++) { + free_pages_.push_back(i); + } + } + + // reset: + // Resets the block table and tokens_assigned_ and refills the free_pages_ + // FIFO queue. + void reset() { + block_table_.fill_(-1); + tokens_assigned_.fill_(0); + free_pages_.clear(); + for (int32_t i = 0; i < static_cast(num_blocks_); i++) { + free_pages_.push_back(i); + } } - // accessors to expose to python for avoiding graph breaks - torch::Tensor tokens_assigned_tensor() const { return tokens_assigned_; } - torch::Tensor free_pages_tensor() const { return free_pages; } - torch::Tensor next_page_tensor() const { return next_free; } + // Accessor for the block table (for debugging or introspection). torch::Tensor block_table() const { return block_table_; } - void force_update_tokens_assigned(const torch::Tensor &new_token_counts) { - // Ensure new_token_counts tensor has the correct batch size. - if (new_token_counts.sizes()[0] != batch_size_) { - throw std::runtime_error("Input tensor size does not match batch_size."); + // Accessor for tokens_assigned (current per-row token counts). + torch::Tensor tokens_assigned() const { return tokens_assigned_; } + + // allocate_sequence: + // Reserve pages for a sequence row based on initial_tokens and populate the + // corresponding row of block_table_. Returns the assigned page ids. + std::vector allocate_sequence(int64_t row_id, int64_t initial_tokens) { + if (row_id < 0 || row_id >= batch_size_) { + throw std::runtime_error("allocate_sequence: row_id out of range"); + } + if (initial_tokens < 0) { + throw std::runtime_error("allocate_sequence: initial_tokens must be >= 0"); + } + // If already has tokens, refuse (caller should use extend). + int64_t current_tokens = tokens_assigned_[row_id].item(); + if (current_tokens != 0) { + throw std::runtime_error("allocate_sequence: row already initialized; use extend_sequence"); + } + int64_t page_count = + (initial_tokens + page_block_size_ - 1) / page_block_size_; + if (page_count > max_num_blocks_per_seq_) { + throw std::runtime_error("allocate_sequence: exceeds max_num_blocks_per_seq"); + } + if (static_cast(free_pages_.size()) < page_count) { + throw std::runtime_error("allocate_sequence: insufficient free pages"); + } + std::vector assigned; + assigned.reserve(static_cast(page_count)); + for (int64_t i = 0; i < page_count; ++i) { + int32_t p = free_pages_.front(); + free_pages_.pop_front(); + block_table_.index_put_({row_id, i}, p); + assigned.push_back(p); } - // Copy the input tensor into tokens_assigned_. - tokens_assigned_.copy_(new_token_counts); + tokens_assigned_.index_put_({row_id}, initial_tokens); + return assigned; } - // get_pages_for_sequence: - // Returns a tensor containing the pages assigned to the sequence with index - // seq_idx. The valid page count is computed as ceil(tokens_assigned_[seq_idx] - // / page_block_size_). - torch::Tensor get_pages_for_sequence(int64_t seq_idx) { - int64_t token_count = tokens_assigned_[seq_idx].item(); - int64_t page_count = - (token_count + page_block_size_ - 1) / page_block_size_; - // First select the row for seq_idx, then narrow that row to the first - // page_count entries. - auto pages = block_table_.select(0, seq_idx).narrow(0, 0, page_count); - return pages; + // extend_sequence: + // Add n_new_tokens for a sequence row; assigns additional pages if crossing + // page boundaries. Returns only the newly assigned page ids (if any). + std::vector extend_sequence(int64_t row_id, int64_t n_new_tokens) { + if (row_id < 0 || row_id >= batch_size_) { + throw std::runtime_error("extend_sequence: row_id out of range"); + } + if (n_new_tokens <= 0) { + return {}; // nothing to do + } + int64_t current_tokens = tokens_assigned_[row_id].item(); + int64_t new_total_tokens = current_tokens + n_new_tokens; + int64_t old_pages = + (current_tokens + page_block_size_ - 1) / page_block_size_; + int64_t new_pages = + (new_total_tokens + page_block_size_ - 1) / page_block_size_; + if (new_pages > max_num_blocks_per_seq_) { + throw std::runtime_error("extend_sequence: exceeds max_num_blocks_per_seq"); + } + int64_t delta = new_pages - old_pages; + std::vector newly_assigned; + if (delta > 0) { + if (static_cast(free_pages_.size()) < delta) { + throw std::runtime_error("extend_sequence: insufficient free pages"); + } + newly_assigned.reserve(static_cast(delta)); + for (int64_t i = old_pages; i < new_pages; ++i) { + int32_t p = free_pages_.front(); + free_pages_.pop_front(); + block_table_.index_put_({row_id, i}, p); + newly_assigned.push_back(p); + } + } + tokens_assigned_.index_put_({row_id}, new_total_tokens); + return newly_assigned; } - // release_sequence_pages: - // Releases (frees) all pages assigned to the sequence at index seq_idx. - // The freed pages are pushed back into the free_pages_ FIFO queue. - // The corresponding row in the block table is reset to -1, and - // tokens_assigned_ is reset to 0. - /* - void release_sequence_pages(int64_t seq_idx) { - int64_t token_count = tokens_assigned_[seq_idx].item(); + // free_sequence: + // Frees all pages assigned to row_id and clears row state. Returns the freed + // page ids (in FIFO order of the row, not the global queue). + std::vector free_sequence(int64_t row_id) { + if (row_id < 0 || row_id >= batch_size_) { + throw std::runtime_error("free_sequence: row_id out of range"); + } + int64_t token_count = tokens_assigned_[row_id].item(); + if (token_count == 0) { + return {}; + } int64_t page_count = (token_count + page_block_size_ - 1) / page_block_size_; - // First select the row corresponding to seq_idx, then narrow to the first - // page_count elements. - auto row = block_table_.select(0, seq_idx).narrow(0, 0, page_count); - auto row_accessor = row.accessor(); - for (int64_t i = 0; i < page_count; i++) { - if (row_accessor[i] != -1) { - free_pages_.push_back(row_accessor[i]); + std::vector freed; + freed.reserve(static_cast(page_count)); + auto row = block_table_.select(0, row_id).narrow(0, 0, page_count).to(torch::kCPU); + auto acc = row.accessor(); + for (int64_t i = 0; i < page_count; ++i) { + int32_t page = acc[i]; + if (page != -1) { + free_pages_.push_back(page); + freed.push_back(page); } } - // Reset the row for the sequence by filling it with -1. - block_table_.select(0, seq_idx).fill_(-1); - tokens_assigned_.index_put_({seq_idx}, 0); - } - */ - - // reset: - // Resets the block table and tokens_assigned_ and refills the free_pages - // FIFO queue. - void reset() { - c10::cuda::CUDAGuard guard(block_table_.device()); - // same storage, just reinitialize contents - block_table_.fill_(-1); // [B, M] int32 - tokens_assigned_.zero_(); // [B] int64 - next_free.zero_(); // [1] int64 - // refill free_pages as [0, 1, ..., num_blocks_-1] - auto opts = free_pages.options().dtype(torch::kInt32); - free_pages.copy_(torch::arange(num_blocks_, opts)); + block_table_.select(0, row_id).fill_(-1); + tokens_assigned_.index_put_({row_id}, 0); + return freed; } private: @@ -135,163 +179,17 @@ class KVCacheManager { torch::Tensor tokens_assigned_; // FIFO queue of free page indices (as int32). - // std::deque free_pages_; - torch::Tensor free_pages; - torch::Tensor next_free; + std::deque free_pages_; }; -static void force_update_tokens_assigned_impl( - torch::Tensor tokens_assigned, const torch::Tensor &new_counts -) { - TORCH_CHECK(tokens_assigned.size(0) == new_counts.size(0), "batch mismatch"); - tokens_assigned.copy_(new_counts); -} - -static inline at::Tensor ceil_div_tensor(const at::Tensor &x, int64_t d) { - return at::floor_divide(x + (d - 1), d); -} - -// update_block_table_impl: -// seq_lengths: A tensor of shape [batch_size] containing the number of new -// tokens -// to be pushed into the KV cache for each sequence. -// For each sequence, this method adds the incoming tokens to the current -// token count, computes the new required page count, and if new pages are -// needed, assigns additional pages from the free_pages FIFO queue. -static void update_block_table_impl( - const at::Tensor &block_table, // int32/int64, [B, M], contiguous - const at::Tensor &tokens_assigned, // int64, [B] - const at::Tensor &next_page, // int64, [1] - const at::Tensor &free_pages, // int32, [N_pages] - const at::Tensor &seq_lengths, // int32/int64, [B] - int64_t page_block_size, - int64_t max_blocks_per_seq -) { - c10::cuda::CUDAGuard guard(block_table.device()); - - TORCH_CHECK(block_table.dim() == 2, "block_table must be [B, M]"); - TORCH_CHECK(tokens_assigned.dim() == 1, "tokens_assigned must be [B]"); - TORCH_CHECK(seq_lengths.dim() == 1, "seq_lengths must be [B]"); - TORCH_CHECK(block_table.is_contiguous(), "block_table must be contiguous"); - TORCH_CHECK(tokens_assigned.scalar_type() == at::kLong, "tokens_assigned int64"); - TORCH_CHECK(block_table.scalar_type() == at::kInt || block_table.scalar_type()==at::kLong, "block_table must be int32 or int64"); - - const int64_t B = block_table.size(0); - const int64_t M = block_table.size(1); - const int64_t N_pages = free_pages.size(0); - - auto dev = block_table.device(); - auto long_opts = at::TensorOptions().device(dev).dtype(at::kLong); - auto table_dtype = block_table.scalar_type(); - - const int64_t K = next_page.numel(); - TORCH_CHECK(K == 1 && next_page.numel() == 1, "next_page must have numel 1; got ", K); - auto next_per = next_page.view({1}).expand({B}); - - // same as original impl except vectorized - auto inc_tokens = seq_lengths.to(at::kLong); // [B] - auto old_tokens = tokens_assigned; // [B] - auto new_tokens = old_tokens + inc_tokens; // [B] - auto old_pages = ceil_div_tensor(old_tokens, page_block_size); // [B] - auto new_pages = ceil_div_tensor(new_tokens, page_block_size); // [B] - // how many additional pages to assign - auto delta = (new_pages - old_pages).clamp_min(0); - -#ifndef NDEBUG // .item calls will break graphs, so not in release - TORCH_CHECK( - (new_pages <= at::full({B}, (long)M, long_opts)).all().item(), - "Exceeded maximum number of blocks per sequence." - ); - // just for safety check - auto need_total = delta.sum(); - const long next_host = next_page.item(); - TORCH_CHECK( - next_host + need_total.item() <= N_pages, - "No free pages available in the global KV cache." - ); -#endif - - // offsets into the FIFO window, which pages each seq should take - auto csum = at::cumsum(delta, 0); // [B] - auto start_excl = csum - delta; // [B] - auto base_per = next_per + start_excl; // [B] - - // Column mask over [0..M-1] - auto cols = at::arange(M, long_opts); // [M] - auto cols2d = cols.view({1, M}).expand({B, M}); // [B, M] - auto old2d = old_pages.view({B, 1}).expand({B, M}); // [B, M] - auto del2d = delta.view({B, 1}).expand({B, M}); // [B, M] - auto mask = (cols2d >= old2d) & (cols2d < (old2d + del2d)); // [B, M], bool - auto mask_i = mask.to(at::kLong); // [B, M] - - // Row-local 0 .. (delta-1) counters - auto k_in_row = (at::cumsum(mask_i, 1) - 1) * mask_i; // [B, M] - - // Compute page indices, clamp to stay in-bounds (avoid OOB during capture) - auto base2d = base_per.view({B,1}).expand({B,M}); // [B, M] - auto page_idx = (base2d + k_in_row).reshape({B * M}).to(at::kLong); - - // Gather and blend (no variable-length selects) - auto gathered = free_pages.index_select(0, page_idx).to(table_dtype) - .view({B, M}); // [B, M] - auto blended = at::where(mask, gathered, block_table); // [B, M] - block_table.copy_(blended); - - // Update counters in-place - tokens_assigned.copy_(new_tokens); // [B] - // advance the single global head by total granted pages - next_page.add_(delta.sum()); -} - -static void force_update_tokens_assigned_meta( - torch::Tensor tokens_assigned, - const torch::Tensor &token_counter -) {} - -static void update_block_table_meta( - const torch::Tensor &block_table, - const torch::Tensor &tokens_assigned, - const torch::Tensor &next_page, - const torch::Tensor &free_pages, - const torch::Tensor &seq_lengths, - int64_t /*page_block_size*/, - int64_t /*max_blocks_per_seq*/ -) { - return; -} - -TORCH_LIBRARY(yalis, m) { - m.def("force_update_tokens_assigned_(Tensor(a!) tokens_assigned, Tensor new_counts) -> ()"); - m.def( - "update_block_table_(Tensor(a!) block_table, " - "Tensor(b!) tokens_assigned, " - "Tensor(c!) next_page, " - "Tensor free_pages, " - "Tensor seq_lengths, " - "int page_block_size, " - "int max_blocks_per_seq) " - "-> ()" - ); -} - -TORCH_LIBRARY_IMPL(yalis, CUDA, m) { - m.impl("force_update_tokens_assigned_", force_update_tokens_assigned_impl); - m.impl("update_block_table_", update_block_table_impl); -} - -TORCH_LIBRARY_IMPL(yalis, Meta, m) { - m.impl("force_update_tokens_assigned_", force_update_tokens_assigned_meta); - m.impl("update_block_table_", update_block_table_meta); -} - // Expose the KVCacheManager as a custom class via PyBind11. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::class_(m, "KVCacheManager") .def(py::init()) - .def("get_pages_for_sequence", &KVCacheManager::get_pages_for_sequence) + .def("allocate_sequence", &KVCacheManager::allocate_sequence) + .def("extend_sequence", &KVCacheManager::extend_sequence) + .def("free_sequence", &KVCacheManager::free_sequence) .def("reset", &KVCacheManager::reset) .def("block_table", &KVCacheManager::block_table) - .def("tokens_assigned_tensor", &KVCacheManager::tokens_assigned_tensor) - .def("free_pages_tensor", &KVCacheManager::free_pages_tensor) - .def("next_page_tensor", &KVCacheManager::next_page_tensor); + .def("tokens_assigned", &KVCacheManager::tokens_assigned); } diff --git a/yalis/attention/utils/__init__.py b/yalis/attention/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/yalis/attention/update_kv_cache.py b/yalis/attention/utils/flash_utils.py similarity index 94% rename from yalis/attention/update_kv_cache.py rename to yalis/attention/utils/flash_utils.py index 6b917775..0abe640a 100644 --- a/yalis/attention/update_kv_cache.py +++ b/yalis/attention/utils/flash_utils.py @@ -9,6 +9,7 @@ def update_paged_kv_cache_kernel( v_ptr, block_table_ptr, cache_seq_len_ptr, + actual_seqlens_ptr, cache_k_ptr, cache_v_ptr, B, @@ -32,6 +33,11 @@ def update_paged_kv_cache_kernel( # Offset in kv cache cache_offset = tl.load(cache_seq_len_ptr + b) token_offset = cache_offset + s + if actual_seqlens_ptr is not None: + actual_seqlen = tl.load(actual_seqlens_ptr + b) + if token_offset >= actual_seqlen: + return + page_id = token_offset // page_block_size offset_in_block = token_offset % page_block_size block_id = tl.load(block_table_ptr + b * max_pages_per_seq + page_id) @@ -114,10 +120,10 @@ def update_paged_kv_cache( v: torch.Tensor, block_table: torch.Tensor, cache_seq_len: torch.Tensor, + actual_seqlens: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, ): - B, S, H, D = k.shape BLOCK_D = D BLOCK_H = min(1024 // BLOCK_D, H) @@ -132,6 +138,7 @@ def update_paged_kv_cache( v, block_table, cache_seq_len, + actual_seqlens, k_cache, v_cache, B, diff --git a/yalis/attention/masking.py b/yalis/attention/utils/flex_utils.py similarity index 100% rename from yalis/attention/masking.py rename to yalis/attention/utils/flex_utils.py diff --git a/yalis/engine.py b/yalis/engine.py index 97201cbc..4e64a291 100644 --- a/yalis/engine.py +++ b/yalis/engine.py @@ -14,12 +14,16 @@ from transformers import AutoTokenizer from torch.nn.attention import SDPBackend, sdpa_kernel from .constants import EnginePhase +from yalis.attention.kv_cache import KVSlotsManager import time import gc from .timers import Timers import os +# TODO: these should be dynamically set during engine initialization +NUM_BLOCKS, PAGE_BLOCK_SIZE = 512, 256 + # These flags are taken from the following URL - # https://github.com/pytorch/pytorch/blob/347f96061f1cff603983b9be19ec92b374329a5b/benchmarks/gpt_fast/generate.py#L19 @@ -67,6 +71,8 @@ def prefill( top_p=1.0, get_logits=False, phase: EnginePhase = EnginePhase.PREFILL, + block_table: torch.Tensor = None, + token_counter: torch.Tensor = None, ): """ Prefill function for generating the first token. @@ -81,9 +87,13 @@ def prefill( logits: (Optional) The raw logits from the model. """ - logits = model(tokens, phase, unpadded_prompt_lengths)["logits"].to( - torch.float32 - ) + logits = model( + tokens, + phase, + unpadded_prompt_lengths, + block_table=block_table, + token_counter=token_counter, + )["logits"].to(torch.float32) logits = logits[torch.arange(logits.size(0)), unpadded_prompt_lengths - 1] token_id = sample( logits=logits, temperature=temperature, top_k=top_k, top_p=top_p @@ -105,6 +115,8 @@ def generate( top_p=1.0, get_logits=False, phase: EnginePhase = EnginePhase.DECODE_SINGLE, + block_table: torch.Tensor = None, + token_counter: torch.Tensor = None, ): """ Generate function for producing the next token(s). @@ -119,7 +131,9 @@ def generate( token_id: The next predicted token. logits: (Optional) The raw logits from the model. """ - logits = model(tokens, phase)["logits"].to(torch.float32) + logits = model( + tokens, phase, block_table=block_table, token_counter=token_counter + )["logits"].to(torch.float32) token_id = sample( logits=logits[:, -1], temperature=temperature, top_k=top_k, top_p=top_p ) @@ -182,6 +196,16 @@ def __init__( self.model_config = model_config self.inference_config = inference_config + # TODO: Move to a separate Python class + # with better memory management and better API + self.kv_slots_manager = KVSlotsManager( + inference_config.max_batch_size, + inference_config.use_paged_kv_caching, + 16384 // PAGE_BLOCK_SIZE, # ToDo: set these dynamically + NUM_BLOCKS, + PAGE_BLOCK_SIZE, + ) + # return extra memory to CUDA. Can prevent NCCL init OOMs torch.cuda.empty_cache() gc.collect() @@ -268,6 +292,8 @@ def _reset_kv_cache(self, model, max_batch_size): device=self.device, dtype=self.dtype, ) + self.kv_slots_manager.reset() + if self.inference_config.symmetric_allreduce_strategy is not None: model.create_symmetric_memory_pool( max_batch_size=max_batch_size, @@ -314,10 +340,16 @@ def _tokenize_prompts(self, prompts): ) return prompt_tokens, prompt_sequence_lengths - def _validate_sequence_lengths( + def _validate_generation_inputs( self, prompt_sequence_lengths, tokens_to_generate ): - """Validate and adjust sequence lengths if necessary.""" + """Validate batch/sequence limits and adjust decode length.""" + batch_size = int(prompt_sequence_lengths.size(0)) + if batch_size > self.inference_config.max_batch_size: + raise ValueError( + f"Batch size ({batch_size}) exceeds configured max_batch_size " + f"({self.inference_config.max_batch_size})." + ) if prompt_sequence_lengths.max() > self.model.max_seq_length: raise ValueError( f"Prompt sequence length ({prompt_sequence_lengths.max()})" @@ -380,7 +412,7 @@ def generate( prompt_tokens, prompt_sequence_lengths = self._tokenize_prompts( prompts ) - tokens_to_generate = self._validate_sequence_lengths( + tokens_to_generate = self._validate_generation_inputs( prompt_sequence_lengths, tokens_to_generate ) timers.stop("tokenize") @@ -401,8 +433,8 @@ def generate( # Start timing the operations timers.start("generate") self.model.token_counter.zero_() - if self.inference_config.use_paged_kv_caching: - self.model.kv_cache_manager.reset() + self.kv_slots_manager.reset() + with torch.inference_mode(), torch.autocast( self.device, dtype=self.dtype, cache_enabled=False ): @@ -410,13 +442,29 @@ def generate( self.device ) # Move prompt tokens to the device - prompt_sequence_lengths = prompt_sequence_lengths.to(self.device) + prompt_sequence_lengths = prompt_sequence_lengths.to( + self.device + ).to(torch.int32) + B = current_input_to_model.shape[0] + req_ids = [f"req_{i}" for i in range(B)] + token_counter = torch.zeros( + B, dtype=torch.int32, device=self.device + ) + slot_ids = self.kv_slots_manager.allocate( + req_ids, prompt_sequence_lengths + ) + if len(slot_ids) != B: + raise RuntimeError( + f"Allocated {len(slot_ids)} KV slots for batch size {B}. " + "Expected full-batch allocation." + ) for step in range(tokens_to_generate): timer_key = None if step == 0: # Prefill step timer_key = "prefill" timers.start(timer_key) nvtx_range_push("Prefill") + next_token, logits = prefill( self.model, current_input_to_model, @@ -425,6 +473,8 @@ def generate( top_k=self.inference_config.top_k, top_p=self.inference_config.top_p, get_logits=get_logits, + block_table=self.kv_slots_manager.view(slot_ids), + token_counter=token_counter, ) # Call prefill function current_input_to_model = next_token.clone() @@ -433,6 +483,10 @@ def generate( timer_key = "decode" timers.start(timer_key) nvtx_range_push("Decode") + + token_counter = self.kv_slots_manager.lengths(slot_ids) + self.kv_slots_manager.update(slot_ids, 1) + with sdpa_kernel(SDPBackend.MATH): next_token, logits = generate( self.model, @@ -441,6 +495,8 @@ def generate( top_k=self.inference_config.top_k, top_p=self.inference_config.top_p, get_logits=get_logits, + block_table=self.kv_slots_manager.view(slot_ids), + token_counter=token_counter, ) # Call generate function current_input_to_model.copy_( @@ -589,7 +645,7 @@ def generate_speculative( prompt_tokens, prompt_sequence_lengths = self._tokenize_prompts( input_tokens ) - tokens_to_generate = self._validate_sequence_lengths( + tokens_to_generate = self._validate_generation_inputs( prompt_sequence_lengths, tokens_to_generate ) timers.stop("tokenize") diff --git a/yalis/external/model.py b/yalis/external/model.py index c9c78cc2..bed2d7a6 100644 --- a/yalis/external/model.py +++ b/yalis/external/model.py @@ -22,10 +22,15 @@ from yalis.external.config import Config from yalis.tensor_parallel import TPLinear, TPMoE from yalis.constants import EnginePhase -from kvcache_manager import KVCacheManager -from yalis.attention.flash import flash_apply_rotary as apply_rotary + +# TODO: These need to be abstracted away from the model from yalis.attention.backends import AttentionBackend -from yalis.attention.masking import create_causal_block_mask_for_flex_attention +from yalis.attention.backend_impl.flash import ( + flash_apply_rotary as apply_rotary, +) +from yalis.attention.utils.flex_utils import ( + create_causal_block_mask_for_flex_attention, +) # TODO: these should be dynamically set during engine initialization NUM_BLOCKS, PAGE_BLOCK_SIZE = 512, 256 @@ -119,6 +124,8 @@ def forward( input_ids: torch.Tensor, phase: EnginePhase, actual_sequence_lengths: torch.Tensor = None, + block_table: torch.Tensor = None, + token_counter: torch.Tensor = None, ) -> torch.Tensor: idx = input_ids T = idx.size(1) @@ -127,32 +134,6 @@ def forward( f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}." # noqa: E501 ) - # Update block table - # assign new pages to each sequence if needed to store new keys/values - # actual storage will be done by the flash attention kernel. - # this is just assigning pages to each sequence - if self.config.use_paged_kv_caching: - # create pages for T new tokens if needed. - # Note that T includes padding tokens in prefill. - # we will readjust the token counters of the block table - # at the end to exclude padded tokens. - B = input_ids.shape[0] - seq_lengths = torch.full( - (B,), - T, - dtype=torch.int64, - device=self.kvcache_block_table.device, - ) - torch.ops.yalis.update_block_table_( - self.kvcache_block_table[:B], - self.tokens_assigned[:B], - self.kvcache_next_page, - self.kvcache_free_pages, - seq_lengths, - PAGE_BLOCK_SIZE, - 16384 // PAGE_BLOCK_SIZE, - ) - x = self.transformer.wte( idx ) # token embeddings of shape (b, t, n_embd) @@ -160,6 +141,12 @@ def forward( x = x * torch.tensor(self.config.n_embd**0.5, dtype=x.dtype) if self.config.tensor_parallel: x = Drop.apply(x, ax.comm_handle.inner_intra_layer_parallel_group) + # TODO: Remove this once we have a way support + # this for speculative decoding without token counter + update_token_counter = False + if token_counter is None: + update_token_counter = True + token_counter = self.token_counter # flash attention wants the rope cache to be # in the same dtype as the query @@ -168,19 +155,11 @@ def forward( self.cos = self.cos.to(x.dtype) self.sin = self.sin.to(x.dtype) - # Block table is not sliced and expected that - # the attention backend will handle the slicing. - block_table = ( - self.kvcache_block_table - if self.config.use_paged_kv_caching - else None - ) - B = x.size(0) flex_attention_block_mask = ( create_causal_block_mask_for_flex_attention( - self.token_counter, self.kv_length, B + token_counter, self.kv_length, B ) if self.config.attention_backend == AttentionBackend.FLEX else None @@ -192,8 +171,9 @@ def forward( self.cos, self.sin, phase, - self.token_counter, + token_counter, block_table, + actual_sequence_lengths, flex_attention_block_mask, ) if self.config.tensor_parallel: @@ -207,15 +187,12 @@ def forward( torch.tanh(x / self.config.final_logit_softcapping) * self.config.final_logit_softcapping ) - self.token_counter[:B].add_( - T if actual_sequence_lengths is None else actual_sequence_lengths - ) - if self.config.use_paged_kv_caching: - # NOTE: Paged KV: readjusting the token counters of the block table - # to exclude padded tokens. - # we can exclude this for generation - torch.ops.yalis.force_update_tokens_assigned_( - self.tokens_assigned[:B], self.token_counter[:B] + + if update_token_counter: + self.token_counter[:B].add_( + T + if actual_sequence_lengths is None + else actual_sequence_lengths ) return {"logits": x} @@ -305,16 +282,6 @@ def set_kv_cache( self.kv_length = max_seq_length self.max_batch_size = max_batch_size - max_tokens = max_seq_length * max_batch_size - - # TODO (Prajwal): This is a hack to not over allocated - # KV-cache by default.Fix with dynamic page calculation logic - global NUM_BLOCKS - if self.config.use_paged_kv_caching: - if max_tokens > PAGE_BLOCK_SIZE * NUM_BLOCKS: - print("Increasing NUM_BLOCKS to 1024") - NUM_BLOCKS = 1024 - # initialize the kv cache for all blocks for block in self.transformer.h: block.attn.kv_cache = block.attn.build_kv_cache( @@ -324,20 +291,6 @@ def set_kv_cache( device, dtype, ) - if self.config.use_paged_kv_caching: - self.kv_cache_manager = KVCacheManager( - max_batch_size, - 16384 // PAGE_BLOCK_SIZE, # ToDo: set these dynamically - NUM_BLOCKS, - PAGE_BLOCK_SIZE, - ) - # TODO: move to separate Python class - self.tokens_assigned = ( - self.kv_cache_manager.tokens_assigned_tensor() - ) - self.kvcache_block_table = self.kv_cache_manager.block_table() - self.kvcache_free_pages = self.kv_cache_manager.free_pages_tensor() - self.kvcache_next_page = self.kv_cache_manager.next_page_tensor() self.token_counter = torch.zeros( max_batch_size, device=device, dtype=torch.int32 @@ -434,6 +387,7 @@ def forward( phase: EnginePhase, token_counter: Optional[torch.Tensor] = None, block_table: Optional[torch.Tensor] = None, + actual_sequence_lengths: Optional[torch.Tensor] = None, flex_attention_block_mask=None, ) -> torch.Tensor: """ @@ -465,6 +419,7 @@ def forward( phase, token_counter, block_table, + actual_sequence_lengths, flex_attention_block_mask, ) attention_output = self.post_attention_norm(attention_output) @@ -575,6 +530,7 @@ def forward( phase: EnginePhase, token_counter: torch.Tensor, block_table: torch.Tensor = None, + actual_sequence_lengths: torch.Tensor = None, flex_attention_block_mask=None, ) -> torch.Tensor: B, T, C = ( @@ -631,6 +587,7 @@ def forward( v=v, phase=phase, cache_seqlens=token_counter, + actual_seqlens=actual_sequence_lengths, block_table=block_table, rotary_cos=cos, rotary_sin=sin,