Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/tokenspeed/runtime/configs/deepseek_v4_cache_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,43 +78,51 @@ def build_v4_cache_specs(
unique_compress_ratios = sorted({int(r) for r in layer_ratio if int(r) > 1})

specs: List[PagedCacheGroupSpec] = [
# SWA kv: trailing window only -> State family.
PagedCacheGroupSpec(
group_id=V4_SWA_KV_GROUP_ID,
retention="sliding_window",
rows_per_page=V4_KERNEL_BLOCK_ROWS,
entry_stride_tokens=1,
sliding_window_tokens=swa_window,
family="state",
),
]
for ratio in unique_compress_ratios:
if ratio not in _COMPRESSOR_STATE_WINDOW_TOKENS:
raise ValueError(f"unsupported DeepSeek V4 compress_ratio={ratio}")
# Compressor state: tail buffer -> State family.
specs.append(
PagedCacheGroupSpec(
group_id=v4_compressor_state_group_id(ratio),
retention="sliding_window",
rows_per_page=_COMPRESSOR_STATE_ROWS_PER_PAGE[ratio],
entry_stride_tokens=1,
sliding_window_tokens=_COMPRESSOR_STATE_WINDOW_TOKENS[ratio],
family="state",
)
)
# Compressed kv: full-history chain (indexer K shares this group).
specs.append(
PagedCacheGroupSpec(
group_id=v4_compressed_kv_group_id(ratio),
retention="full_history",
rows_per_page=_compressed_kernel_block_size(ratio),
entry_stride_tokens=ratio,
sliding_window_tokens=None,
family="history",
)
)
if 4 in unique_compress_ratios:
# Indexer compressor state: tail buffer -> State family.
specs.append(
PagedCacheGroupSpec(
group_id=V4_INDEXER_COMPRESSOR_STATE_GROUP_ID,
retention="sliding_window",
rows_per_page=_COMPRESSOR_STATE_ROWS_PER_PAGE[4],
entry_stride_tokens=1,
sliding_window_tokens=_COMPRESSOR_STATE_WINDOW_TOKENS[4],
family="state",
)
)
return specs
Expand Down
3 changes: 3 additions & 0 deletions python/tokenspeed/runtime/configs/paged_cache_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tokenspeed.runtime.utils.common import ceil_div

Retention = Literal["full_history", "sliding_window"]
Family = Literal["history", "state"]


@dataclass(frozen=True)
Expand All @@ -31,6 +32,8 @@ class PagedCacheGroupSpec:
rows_per_page: int
entry_stride_tokens: int
sliding_window_tokens: Optional[int]
# History groups form a chain; State groups only need the trailing window.
family: Family = "history"


_PAGED_CACHE_GROUP_DUMMY_PAGES = 1
Expand Down
10 changes: 9 additions & 1 deletion python/tokenspeed/runtime/engine/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
cache_sync_debug_enabled,
make_config,
pool_to_paged_cache_groups,
pool_to_prefix_cache_adjunct_spec,
pop_common_cache_event_payloads,
)
from tokenspeed.runtime.execution.distributed_initializer import (
Expand Down Expand Up @@ -275,6 +276,12 @@ def __init__(
enable_mixed_prefill_decode = (
server_args.enable_mixed_batch and server_args.speculative_algorithm is None
)
# Adjunct enabled only when pool opts in AND prefix-caching switch is on.
paged_cache_groups = pool_to_paged_cache_groups(token_to_kv_pool)
prefix_cache_adjunct = None
required_groups = token_to_kv_pool.prefix_cache_required_group_ids
if required_groups is not None and server_args.enable_prefix_caching:
prefix_cache_adjunct = pool_to_prefix_cache_adjunct_spec(required_groups)
scheduler_cfg = make_config(
num_device_pages=self.max_total_num_tokens // server_args.block_size,
max_scheduled_tokens=server_args.chunked_prefill_size,
Expand All @@ -295,8 +302,9 @@ def __init__(
enable_mamba=has_mamba,
mamba_cache_chunk_size=server_args.mamba_cache_chunk_size,
mamba_pool_total_chunks=mamba_pool_total_chunks,
paged_cache_groups=pool_to_paged_cache_groups(token_to_kv_pool),
paged_cache_groups=paged_cache_groups,
enable_mixed_prefill_decode=enable_mixed_prefill_decode,
prefix_cache_adjunct=prefix_cache_adjunct,
)
logger.info(
"Scheduler config: page_size=%s num_device_pages=%s "
Expand Down
32 changes: 31 additions & 1 deletion python/tokenspeed/runtime/engine/scheduler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
ExecutionEvent,
ForwardEvent,
PagedCacheGroupConfig,
PagedCacheGroupFamily,
PagedCacheRetention,
PrefixCacheAdjunctSpec,
RequestSpec,
SchedulerConfig,
)
Expand Down Expand Up @@ -67,6 +69,7 @@ def make_config(
mamba_pool_total_chunks: int = 0,
paged_cache_groups: Sequence["PagedCacheGroupConfig"] | None = None,
enable_mixed_prefill_decode: bool = False,
prefix_cache_adjunct: "PrefixCacheAdjunctSpec | None" = None,
) -> SchedulerConfig:
cfg = SchedulerConfig()
cfg.num_device_pages = num_device_pages
Expand Down Expand Up @@ -96,12 +99,15 @@ def make_config(
cfg.enable_mixed_prefill_decode = enable_mixed_prefill_decode
if paged_cache_groups:
cfg.paged_cache_groups = list(paged_cache_groups)
# Opt-in; unset means paged-cache groups are transport-only.
if prefix_cache_adjunct is not None:
cfg.prefix_cache_adjunct = prefix_cache_adjunct
return cfg


def pool_to_paged_cache_groups(pool: Any) -> list:
"""Convert a KV pool's paged_cache_group_specs to scheduler configs."""
specs = getattr(pool, "paged_cache_group_specs", ())
specs = pool.paged_cache_group_specs
if not specs:
return []
counts = pool.paged_cache_group_page_counts
Expand All @@ -116,19 +122,43 @@ def pool_to_paged_cache_groups(pool: Any) -> list:
f"pool_to_paged_cache_groups: unsupported retention "
f"{spec.retention!r} for group {spec.group_id!r}"
)
family_str = getattr(spec, "family", "history")
if family_str == "history":
family = PagedCacheGroupFamily.History
elif family_str == "state":
family = PagedCacheGroupFamily.State
else:
raise ValueError(
f"pool_to_paged_cache_groups: unsupported family "
f"{family_str!r} for group {spec.group_id!r}"
)
kwargs = dict(
group_id=spec.group_id,
rows_per_page=int(spec.rows_per_page),
entry_stride_tokens=int(spec.entry_stride_tokens),
total_pages=int(counts[spec.group_id]),
retention=retention,
family=family,
)
if spec.retention == "sliding_window":
kwargs["sliding_window_tokens"] = int(spec.sliding_window_tokens)
out.append(PagedCacheGroupConfig(**kwargs))
return out


def pool_to_prefix_cache_adjunct_spec(
required_group_ids: Sequence[str],
) -> "PrefixCacheAdjunctSpec":
"""Build a PrefixCacheAdjunctSpec from a non-empty required-group-id list."""
if not required_group_ids:
raise ValueError(
"pool_to_prefix_cache_adjunct_spec: required_group_ids must be non-empty"
)
spec = PrefixCacheAdjunctSpec()
spec.required_groups = [str(gid) for gid in required_group_ids]
return spec


def make_extend_result_event(request_id: str, tokens: list[int] = ()) -> None:
fe = ForwardEvent.ExtendResult()
fe.request_id = request_id
Expand Down
30 changes: 13 additions & 17 deletions python/tokenspeed/runtime/execution/cuda_graph_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,7 @@ def __init__(
self.world_size = config.world_size
# Backends alias their cache_seqlens buffer. Draft backend aliases
# the drafter-owned draft_seq_lens to keep InputBuffers read-only.
paged_cache_group_specs = tuple(
getattr(token_to_kv_pool, "paged_cache_group_specs", ()) or ()
)
paged_cache_group_specs = tuple(token_to_kv_pool.paged_cache_group_specs)
try:
attn_backend.init_cuda_graph_state(
self.max_bs,
Expand All @@ -241,7 +239,7 @@ def __init__(
)
if draft_attn_backend is not None:
draft_paged_cache_group_specs = tuple(
getattr(draft_token_to_kv_pool, "paged_cache_group_specs", ()) or ()
draft_token_to_kv_pool.paged_cache_group_specs
)
try:
draft_attn_backend.init_cuda_graph_state(
Expand Down Expand Up @@ -416,7 +414,7 @@ def run_once():
return graph, out

def _capture_paged_cache_block_tables(self, bs: int, pool) -> dict | None:
specs = tuple(getattr(pool, "paged_cache_group_specs", ()) or ())
specs = tuple(pool.paged_cache_group_specs)
if not specs:
return None
out = {}
Expand Down Expand Up @@ -447,10 +445,9 @@ def _init_capture_metadata(self, bs: int):
bs,
self.token_to_kv_pool,
)
if paged_cache_block_tables is not None and getattr(
self.attn_backend,
"uses_paged_cache_groups",
False,
if (
paged_cache_block_tables is not None
and self.attn_backend.uses_paged_cache_groups
):
capture_kwargs["paged_cache_block_tables"] = paged_cache_block_tables
self.attn_backend.init_forward_metadata_capture_cuda_graph(
Expand All @@ -468,10 +465,9 @@ def _init_capture_metadata(self, bs: int):
bs,
self.draft_token_to_kv_pool,
)
if draft_paged_cache_block_tables is not None and getattr(
self.draft_attn_backend,
"uses_paged_cache_groups",
False,
if (
draft_paged_cache_block_tables is not None
and self.draft_attn_backend.uses_paged_cache_groups
):
draft_kwargs["paged_cache_block_tables"] = (
draft_paged_cache_block_tables
Expand Down Expand Up @@ -581,7 +577,7 @@ def _init_replay_metadata(
kwargs["paged_cache_block_table_base_offsets"] = (
paged_cache_block_table_base_offsets
)
if getattr(self.attn_backend, "uses_padded_decode_token_mask", False):
if self.attn_backend.uses_padded_decode_token_mask:
kwargs["actual_bs"] = actual_bs
self.attn_backend.init_forward_metadata_replay_cuda_graph(
padded_bs,
Expand Down Expand Up @@ -790,7 +786,7 @@ def __call__(
if (
bs == 0
and paged_cache_block_tables is None
and getattr(self.attn_backend, "uses_paged_cache_groups", False)
and self.attn_backend.uses_paged_cache_groups
):
paged_cache_block_tables = self._capture_paged_cache_block_tables(
padded_bs,
Expand Down Expand Up @@ -855,12 +851,12 @@ def __call__(
spec_info=spec_info,
paged_cache_block_tables=(
paged_cache_block_tables
if getattr(self.attn_backend, "uses_paged_cache_groups", False)
if self.attn_backend.uses_paged_cache_groups
else None
),
paged_cache_block_table_base_offsets=(
paged_cache_block_table_base_offsets
if getattr(self.attn_backend, "uses_paged_cache_groups", False)
if self.attn_backend.uses_paged_cache_groups
else None
),
**mamba_kwargs,
Expand Down
3 changes: 3 additions & 0 deletions python/tokenspeed/runtime/layers/attention/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
class AttentionBackend(ABC):
"""The base class of attention backends"""

uses_paged_cache_groups: bool = False
uses_padded_decode_token_mask: bool = False

def __init__(self, config: BaseAttnConfig) -> None:
self.device = config.device
self.num_qo_heads = config.num_attention_heads // config.attn_tp_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1419,8 +1419,8 @@ def _refresh_cuda_graph_base_offsets(
) -> dict[str, torch.Tensor]:
"""Refresh persistent base-offset buffers from per-step input.

Sliding groups whose key is missing fall back to 0 (legacy
absolute scheduler binding). Returns the [:bs] views keyed by gid.
Sliding groups whose key is missing fall back to 0. Returns the [:bs]
views keyed by gid.
"""
out: dict[str, torch.Tensor] = {}
for gid, buf in self._cuda_graph_paged_cache_base_offsets.items():
Expand Down
9 changes: 9 additions & 0 deletions python/tokenspeed/runtime/layers/attention/kv_cache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import torch

from tokenspeed.runtime.configs.paged_cache_spec import PagedCacheGroupSpec
from tokenspeed.runtime.layers.paged_attention import PagedAttention
from tokenspeed.runtime.utils import get_colorful_logger

Expand All @@ -36,6 +37,9 @@
class BaseTokenToKVPool:
"""A memory pool that maps a token location to its kv cache data."""

paged_cache_group_specs: tuple[PagedCacheGroupSpec, ...] = ()
paged_cache_group_page_counts: dict[str, int] = {}

def __init__(
self,
size: int,
Expand Down Expand Up @@ -106,6 +110,11 @@ def load_cpu_copy(
) -> None:
raise NotImplementedError()

@property
def prefix_cache_required_group_ids(self) -> tuple[str, ...] | None:
"""None means adjunct disabled; subclasses return required group ids."""
return None

# Buffer metadata used by prefill/decode disaggregation.
def get_contiguous_buf_infos(self):
raise NotImplementedError()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,7 @@ def _split_paged_cache_block_tables_into_v4_metadata(

Returns (swa, {ratio: compressor_state}, indexer_state, swa_base,
{ratio: compressor_state_base}, indexer_state_base). Unknown group ids
are ignored. Base offsets are None / missing when the input lacks them
(legacy scheduler binding).
are ignored. Base offsets are None / missing when the input lacks them.
"""
offsets = paged_cache_block_table_base_offsets or {}
swa = paged_cache_block_tables.get(V4_SWA_KV_GROUP_ID)
Expand Down Expand Up @@ -593,11 +592,12 @@ def deepseek_v4_cache_layout_from_config(
class DeepseekV4TokenToKVPool(BaseTokenToKVPool):
"""DeepSeek V4 fp8_ds_mla cache pool.

TokenSpeed keeps the SWA, compressed, compressor-state, and CSA indexer
caches in one V4-only pool so ordinary MLA models keep their existing cache
contract untouched. Compressed caches currently reuse the request page table;
this is correctness-first and leaves ratio-specific allocation for the
optimized follow-up.
TokenSpeed keeps SWA, compressed, compressor-state, and CSA indexer caches
in dedicated per-group paged pools (see PagedCacheGroup* on the scheduler
side and ``build_v4_cache_specs`` here), keeping ordinary MLA models on
their existing single-pool contract. The ``indexer_kv_buffer`` shares its
page table and page-count budget with the ``v4.c{ratio}a.compressed_kv``
group rather than owning a separate group of its own.
"""

def __init__(
Expand Down Expand Up @@ -790,6 +790,11 @@ def _group_pages(group_id: str, default: int) -> int:
self.compressed_block_sizes,
)

@property
def prefix_cache_required_group_ids(self) -> tuple[str, ...]:
"""All V4 paged-cache groups must be present for a snapshot to be complete."""
return tuple(str(spec.group_id) for spec in self.paged_cache_group_specs)

def _require(
self, buffers: list[torch.Tensor | None], layer_id: int, name: str
) -> torch.Tensor:
Expand Down
Loading
Loading