Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/basic_correctness/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,13 @@ def yalis_engine(model_id, dtype, attn_backend):
"""Create a standard Yalis LLMEngine."""
model_config = ModelConfig(model_name=model_id, precision=dtype.yalis)
inference_config = InferenceConfig(
max_batch_size=4,
max_batch_size=8,
max_length_of_generated_sequences=2048,
top_p=0.0,
temperature=0.0,
tp_dims=None,
attention_backend=attn_backend.yalis,
use_paged_kv_caching=False,
use_paged_kv_caching=(attn_backend.yalis == "flash"),
)
return LLMEngine(
model_config=model_config, inference_config=inference_config
Expand Down
3 changes: 2 additions & 1 deletion yalis/attention/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from yalis.constants import EnginePhase
from .registry import get_attention
from .backends import AttentionBackend
from .masking import create_block_mask # noqa: F401


def attention_wrapper(
Expand All @@ -15,6 +14,7 @@ def attention_wrapper(
k_cache: Optional[torch.Tensor] = None,
v_cache: Optional[torch.Tensor] = None,
cache_seqlens: Optional[torch.Tensor] = None,
actual_seqlens: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
rotary_cos: Optional[torch.Tensor] = None,
rotary_sin: Optional[torch.Tensor] = None,
Expand All @@ -32,6 +32,7 @@ def attention_wrapper(
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens,
actual_seqlens=actual_seqlens,
block_table=block_table,
rotary_cos=rotary_cos,
rotary_sin=rotary_sin,
Expand Down
Empty file.
18 changes: 14 additions & 4 deletions yalis/attention/flash.py → yalis/attention/backend_impl/flash.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from flash_attn import flash_attn_with_kvcache
import torch
from typing import Sequence, Optional
from .registry import register_attention
from .update_kv_cache import update_paged_kv_cache
import torch

from flash_attn import flash_attn_with_kvcache
from flash_attn.ops.triton.rotary import apply_rotary
from yalis.attention.utils.flash_utils import update_paged_kv_cache
from yalis.attention.registry import register_attention
from yalis.constants import EnginePhase


Expand Down Expand Up @@ -93,6 +94,7 @@ def flash_attention(
k_cache: Optional[torch.Tensor] = None,
v_cache: Optional[torch.Tensor] = None,
cache_seqlens: Optional[torch.Tensor] = None,
actual_seqlens: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
rotary_cos: Optional[torch.Tensor] = None,
rotary_sin: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -145,6 +147,7 @@ def flash_attention(
v=v,
block_table=block_table,
cache_seq_len=cache_seqlens,
actual_seqlens=actual_seqlens,
k_cache=k_cache,
v_cache=v_cache,
)
Expand All @@ -153,6 +156,13 @@ def flash_attention(
# subsequent layers to update their kv-caches.
cache_seqlens = cache_seqlens + T

if phase == EnginePhase.PREFILL:
# This is a clever way, for now, to avoid padding the actual KV-cache:
# in prefill we just use the k and v tensors directly.
k_cache = k
v_cache = v
block_table = None

k, v = None, None

return torch_compile_compatible_flash_attention(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from axonn import axonn as ax
from axonn.intra_layer.communication import Drop, Gather

from .registry import register_attention
from yalis.attention.registry import register_attention
from yalis.constants import EnginePhase


Expand Down
6 changes: 4 additions & 2 deletions yalis/attention/backends.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# These imports trigger @register_attention decorators
from . import sdpa_and_flex # noqa: F401
from . import flash # noqa: F401
from yalis.attention.backend_impl.sdpa_and_flex import ( # noqa: F401
sdpa_attention,
)
from yalis.attention.backend_impl.flash import flash_attention # noqa: F401
Comment thread
prajwal1210 marked this conversation as resolved.

from enum import Enum

Expand Down
15 changes: 15 additions & 0 deletions yalis/attention/kv_cache/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from .kv_cache_policy import (
KVCachePolicy,
ContiguousKVCachePolicy,
PagedKVCachePolicy,
)
from .kv_slots_manager import KVSlotsManager
from .slot_allocator import SlotAllocator

__all__ = [
"KVCachePolicy",
"ContiguousKVCachePolicy",
"PagedKVCachePolicy",
"KVSlotsManager",
"SlotAllocator",
]
180 changes: 180 additions & 0 deletions yalis/attention/kv_cache/kv_cache_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from typing import Optional, Protocol, Union, List

import torch


class KVCachePolicy(Protocol):
    """
    Structural interface for per-slot KV-cache bookkeeping.

    Any class providing these methods satisfies the protocol; in this
    module that is ContiguousKVCachePolicy and PagedKVCachePolicy. Every
    method body here is a placeholder that returns None.
    """

    def allocate(
        self, slot_ids: List[int], prompt_lengths: torch.Tensor
    ) -> None:
        """Register ``slot_ids`` with one prompt length per slot."""

    def update(
        self, slot_ids: List[int], n_new_tokens: Union[int, List[int]]
    ) -> None:
        """
        Grow ``slot_ids`` by ``n_new_tokens``: either a single int applied
        to every slot, or a list with one entry per slot.
        """

    def release(self, slot_ids: List[int]) -> None:
        """Give back whatever the policy holds for ``slot_ids``."""

    def view(self, slot_ids: List[int]) -> Optional[torch.Tensor]:
        """
        Return the policy's table rows for ``slot_ids``, or None when the
        policy keeps no such table.
        """

    def reset(self) -> None:
        """Drop all state held by the policy."""

    def lengths(self, slot_ids: List[int]) -> Optional[torch.Tensor]:
        """
        Return per-row sequence lengths for provided slot ids.
        Returns None if the policy does not track lengths internally.
        """


class ContiguousKVCachePolicy:
    """
    No-op manager for contiguous KV-cache layout.
    Engine can call allocate/update/view to keep a uniform API.
    Tracks per-slot sequence lengths locally in an int32 tensor with one
    entry per slot.
    """

    def __init__(
        self, capacity: int, device: Optional[torch.device] = None
    ) -> None:
        """
        Args:
            capacity: Number of slots to track (length of the length tensor).
            device: Device that holds the length tensor. When None, falls
                back to "cuda", which was the previous hard-coded behaviour.
        """
        # The `device` argument used to be silently ignored in favour of a
        # hard-coded "cuda"; honour it so the policy also works on CPU or
        # on an explicitly chosen GPU.
        target_device = device if device is not None else "cuda"
        self._seq_lens = torch.zeros(
            capacity, dtype=torch.int32, device=target_device
        )

    def _slot_index(self, slot_ids: List[int]) -> torch.Tensor:
        """Build an int64 index tensor on the same device as the lengths."""
        return torch.tensor(
            slot_ids, dtype=torch.int64, device=self._seq_lens.device
        )

    def allocate(
        self, slot_ids: List[int], prompt_lengths: torch.Tensor
    ) -> None:
        """
        Overwrite the stored length of each slot with its prompt length.

        Does nothing when either ``slot_ids`` or ``prompt_lengths`` is empty.
        """
        if prompt_lengths.numel() == 0 or len(slot_ids) == 0:
            return None
        self._seq_lens.index_copy_(
            0,
            self._slot_index(slot_ids),
            prompt_lengths.to(dtype=torch.int32, device=self._seq_lens.device),
        )
        return None

    def update(
        self, slot_ids: List[int], n_new_tokens: Union[int, List[int]]
    ) -> None:
        """
        Add newly generated tokens to the stored length of each slot.

        Args:
            slot_ids: Slots to grow; nothing happens when empty.
            n_new_tokens: A single int applied to every slot, or a sequence
                with one increment per slot.
        """
        if len(slot_ids) == 0:
            return None
        if isinstance(n_new_tokens, int):
            increments = torch.full(
                (len(slot_ids),),
                int(n_new_tokens),
                dtype=torch.int32,
                device=self._seq_lens.device,
            )
        else:
            increments = torch.tensor(
                n_new_tokens, dtype=torch.int32, device=self._seq_lens.device
            )
        self._seq_lens.index_add_(0, self._slot_index(slot_ids), increments)
        return None

    def release(self, slot_ids: List[int]) -> None:
        """Zero the stored length of each slot in ``slot_ids``."""
        if len(slot_ids) == 0:
            return None
        self._seq_lens.index_fill_(0, self._slot_index(slot_ids), 0)
        return None

    def view(self, slot_ids: List[int]) -> Optional[torch.Tensor]:
        """Always None: this policy keeps no table to expose."""
        return None

    def reset(self) -> None:
        """Zero the stored length of every slot."""
        self._seq_lens.zero_()
        return None

    def lengths(self, slot_ids: List[int]) -> Optional[torch.Tensor]:
        """Return the stored int32 lengths for ``slot_ids``, in that order."""
        return self._seq_lens.index_select(0, self._slot_index(slot_ids))


class PagedKVCachePolicy:
    """
    Python facade over the paged KV-cache allocator provided by the
    ``kvcache_manager`` extension module (a C++ implementation).
    Owned by the Engine; offers a stable API for page/block table
    management by forwarding every call to the wrapped manager.
    """

    def __init__(
        self,
        batch_size: int,
        max_num_blocks_per_seq: int,
        num_blocks: int,
        page_block_size: int,
        verbose: bool = False,
    ) -> None:
        """Create the underlying manager; sizing arguments pass through."""
        # Imported here so the extension is only required when a paged
        # policy is actually constructed.
        from kvcache_manager import KVCacheManager as _CppKVCacheManager

        self._verbose = verbose
        self._impl = _CppKVCacheManager(
            batch_size,
            max_num_blocks_per_seq,
            num_blocks,
            page_block_size,
        )

    def block_table(self) -> torch.Tensor:
        """Return the manager's full block table."""
        return self._impl.block_table()

    def allocate(
        self, slot_ids: List[int], prompt_lengths: torch.Tensor
    ) -> None:
        """Allocate one sequence per slot, sized by the matching prompt."""
        for position, slot_id in enumerate(slot_ids):
            prompt_len = prompt_lengths[position].item()
            self._impl.allocate_sequence(int(slot_id), int(prompt_len))

    def update(
        self, slot_ids: List[int], n_new_tokens: Union[int, List[int]]
    ) -> None:
        """
        Extend each slot's sequence by ``n_new_tokens``.

        Raises:
            TypeError: ``n_new_tokens`` is neither an int nor a list.
            ValueError: ``n_new_tokens`` is a list whose length differs
                from ``slot_ids``.
        """
        if not isinstance(n_new_tokens, (int, list)):
            raise TypeError("n_new_tokens must be int or list[int]")

        if isinstance(n_new_tokens, int):
            # One shared increment for every slot.
            step = int(n_new_tokens)
            for slot_id in slot_ids:
                self._impl.extend_sequence(int(slot_id), step)
            return

        if len(slot_ids) != len(n_new_tokens):
            raise ValueError("n_new_tokens list must match slot_ids length")
        for slot_id, step in zip(slot_ids, n_new_tokens):
            self._impl.extend_sequence(int(slot_id), int(step))

    def release(self, slot_ids: List[int]) -> None:
        """Free the sequence held by each slot."""
        for slot_id in slot_ids:
            self._impl.free_sequence(int(slot_id))

    def view(self, slot_ids: List[int]) -> torch.Tensor:
        """Return the block-table rows for ``slot_ids``, in that order."""
        table = self._impl.block_table()
        rows = torch.tensor(slot_ids, dtype=torch.int64, device=table.device)
        return table.index_select(0, rows)

    def reset(self) -> None:
        """Reset the underlying manager."""
        self._impl.reset()

    def lengths(self, slot_ids: List[int]) -> torch.Tensor:
        """
        Return the manager's current per-row token counts, indexed by the
        provided slot ids. The original note describes these counts as a
        CPU int64 tensor; that is set by the manager, not by this wrapper.
        """
        token_counts = self._impl.tokens_assigned()
        rows = torch.tensor(
            slot_ids, dtype=torch.int32, device=token_counts.device
        )
        return token_counts.index_select(0, rows)
Loading