Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/serving/deepseek-v4.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 tokenspeed serve deepseek-ai/DeepSeek-V4-Flash \
| `--kv-cache-dtype fp8_e4m3` | V4 SWA cache rows are uint8-packed FP8 NoPE + BF16 RoPE + UE8M0 scale; FP8 e4m3 is the only supported KV dtype. |
| `--moe-backend mega_moe` | Activates the DeepGEMM `fp8_fp4_mega_moe` fused experts. Requires `tokenspeed-deepgemm>=2.5.0.post20260424`. |
| `--attention-use-fp4-indexer-cache` | Stores indexer keys as MXFP4 (`[values \| ue8m0 scales]`); the FP8 fallback path is reference-only. |
| `--enable-mixed-batch` | Enables mixed prefill/decode scheduling for V4 sparse attention. It is off by default globally because other backend paths do not all support mixed batches yet. |
| `--enable-mixed-batch` | Lets the scheduler issue prefill and decode requests in the same iteration. Off by default globally; opt in per workload. |
| `--trust-remote-code` | The HF config uses model-class architectures registered via remote code. |

## Parser defaults
Expand Down
5 changes: 1 addition & 4 deletions python/tokenspeed/runtime/engine/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,6 @@ def __init__(
f"(ratio={server_args.mamba_full_memory_ratio})."
)

enable_mixed_prefill_decode = (
server_args.enable_mixed_batch and server_args.speculative_algorithm is None
)
# Adjunct enabled only when pool opts in AND prefix-caching switch is on.
paged_cache_groups = pool_to_paged_cache_groups(token_to_kv_pool)
prefix_cache_adjunct = None
Expand Down Expand Up @@ -303,7 +300,7 @@ def __init__(
mamba_cache_chunk_size=server_args.mamba_cache_chunk_size,
mamba_pool_total_chunks=mamba_pool_total_chunks,
paged_cache_groups=paged_cache_groups,
enable_mixed_prefill_decode=enable_mixed_prefill_decode,
enable_mixed_prefill_decode=server_args.enable_mixed_batch,
prefix_cache_adjunct=prefix_cache_adjunct,
)
logger.info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,6 @@ def post_process_forward_op(
forward_op.extend_prefix_lens,
)
num_extends = forward_op.num_extends()
is_decode_op = num_extends <= 0

request_changes = []
stream_out_rids = []
Expand All @@ -506,7 +505,7 @@ def post_process_forward_op(
else None
)
is_decode_slot = i >= num_extends
if self.spec_num_tokens is not None and is_decode_op:
if self.spec_num_tokens is not None and is_decode_slot:
pt += self.spec_num_tokens
else:
pt += output_length
Expand Down
4 changes: 1 addition & 3 deletions python/tokenspeed/runtime/execution/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,11 @@ class ForwardContext:
forward_mode: ForwardMode | None
req_to_page: torch.Tensor | None = None
capture_hidden_mode: CaptureHiddenMode | None = CaptureHiddenMode.NULL
padded_static_len: int = -1

# --- dp attention ---
global_num_tokens: list[int] | None = None
global_bs: list[int] | None = None
all_decode_or_idle: bool = False

# --- logits processor ---
keep_full_logits: bool = False
last_index_offsets: torch.Tensor | None = None
gather_ids: torch.Tensor | None = None
65 changes: 21 additions & 44 deletions python/tokenspeed/runtime/execution/cuda_graph_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ def _capture_one(self, bs: int):
if self.drafter is not None
else CaptureHiddenMode.NULL
),
keep_full_logits=True,
)

# For DP mode, global_num_tokens must be set so that the MoE
Expand Down Expand Up @@ -452,7 +451,6 @@ def _init_capture_metadata(self, bs: int):
capture_kwargs["paged_cache_block_tables"] = paged_cache_block_tables
self.attn_backend.init_forward_metadata_capture_cuda_graph(
bs,
bs * self.max_tokens_per_req,
self.input_buffers.req_pool_indices_buf[:bs],
self.input_buffers.seq_lens_buf[:bs],
ForwardMode.DECODE,
Expand All @@ -475,7 +473,6 @@ def _init_capture_metadata(self, bs: int):
# Drafter mutates seq_lens_buf in place per step; backends alias.
self.draft_attn_backend.init_forward_metadata_capture_cuda_graph(
bs,
bs * self.max_tokens_per_req,
self.input_buffers.req_pool_indices_buf[:bs],
self.input_buffers.seq_lens_buf[:bs],
ForwardMode.DECODE,
Expand Down Expand Up @@ -581,7 +578,6 @@ def _init_replay_metadata(
kwargs["actual_bs"] = actual_bs
self.attn_backend.init_forward_metadata_replay_cuda_graph(
padded_bs,
padded_bs * self.max_tokens_per_req,
req_pool_indices,
seq_lens,
req_to_page=req_to_page,
Expand All @@ -591,7 +587,6 @@ def _init_replay_metadata(
if self.draft_attn_backend is not None:
self.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
padded_bs,
padded_bs * self.max_tokens_per_req,
req_pool_indices,
seq_lens,
req_to_page=self.drafter.req_to_page,
Expand All @@ -603,6 +598,7 @@ def _init_replay_metadata(
def _init_forward_metadata(
self,
padded_bs: int,
num_extends: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
req_to_page: torch.Tensor,
Expand All @@ -611,10 +607,10 @@ def _init_forward_metadata(
):
"""Eager path — allocate/refresh metadata for the upcoming forward."""
self.attn_backend.init_forward_metadata(
padded_bs,
padded_bs * self.max_tokens_per_req,
req_pool_indices,
seq_lens,
bs=padded_bs,
num_extends=num_extends,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
req_to_page=req_to_page,
forward_mode=forward_mode,
**kwargs,
Expand All @@ -633,46 +629,28 @@ def _init_forward_metadata(
# ``seq_lens_k=seq_lens[:bs]``).
# So the kernel sees the value the drafter just wrote, without
# rebuilding metadata per step.
# The EXTEND-mode prefill init still uses the controller's
# ``seq_lens`` because that path computes ``cu_seqlens_k`` from it
# eagerly (cumsum at init time), not via aliasing.
# Pre-write the buffer with the controller's seq_lens so the
# prefill-side eager work (cumsum at init) and the live-aliased
# decode side both see correct values from step 0 onward. Each
# is_draft backend fills both prefill+decode metadata in this
# one call.
#
# TODO: relying on aliasing for correctness is fragile — a stray
# copy or a misrouted buffer silently produces wrong outputs.
# Move to an explicit per-call ``seq_lens`` contract
# so each kernel invocation carries its own value rather than
# reading through a tensor registered at init.
draft_seq_lens = self.drafter.draft_seq_lens_buf[:padded_bs]
if forward_mode.is_extend_or_mixed():
# Step 0 uses the caller's prefix kwargs; subsequent decode
# steps use one-token-per-request metadata. Populate each
# slot with its own init call.
self.draft_attn_backend.init_forward_metadata(
padded_bs,
padded_bs * self.max_tokens_per_req,
req_pool_indices,
seq_lens,
req_to_page=self.drafter.req_to_page,
forward_mode=forward_mode,
**kwargs,
)
self.draft_attn_backend.init_forward_metadata(
padded_bs,
padded_bs,
req_pool_indices,
draft_seq_lens,
req_to_page=self.drafter.req_to_page,
forward_mode=ForwardMode.DECODE,
)
else:
self.draft_attn_backend.init_forward_metadata(
padded_bs,
padded_bs * self.max_tokens_per_req,
req_pool_indices,
draft_seq_lens,
req_to_page=self.drafter.req_to_page,
forward_mode=ForwardMode.DECODE,
)
draft_seq_lens.copy_(seq_lens)
self.draft_attn_backend.init_forward_metadata(
bs=padded_bs,
num_extends=num_extends,
req_pool_indices=req_pool_indices,
seq_lens=draft_seq_lens,
Comment on lines +645 to +649
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

**P1:** Reintroduce decode-metadata init for Triton draft backends

When `_init_forward_metadata` switched from two draft-backend init calls to a single call, EXTEND/MIXED draft batches stopped forcing a DECODE-shaped metadata refresh before Eagle step 2+. `TritonAttnBackend` stores only one `forward_metadata`, so after the single call it can still hold prefill-style `qo_indptr`/`kv_indptr`. `_run_multi_step_decode` then invokes decode kernels against that stale layout, which can misindex KV ranges or fail on shape assumptions for mixed prefill+decode speculative batches.

Useful? React with 👍 / 👎.

req_to_page=self.drafter.req_to_page,
forward_mode=forward_mode,
**kwargs,
)

def _global_graph_bs(self, ctx: ForwardContext) -> int | None:
if self.dp_size <= 1 or ctx.global_num_tokens is None:
Expand Down Expand Up @@ -832,6 +810,7 @@ def __call__(
else:
self._init_forward_metadata(
padded_bs,
ctx.num_extends,
req_pool_indices,
seq_lens,
req_to_page=req_to_page,
Expand All @@ -841,13 +820,11 @@ def __call__(
extend_prefix_lens_cpu=extend_prefix_lens_cpu,
extend_seq_lens=extend_seq_lens,
extend_seq_lens_cpu=extend_seq_lens_cpu,
num_extends=ctx.num_extends,
positions=positions,
out_cache_loc=out_cache_loc,
global_num_tokens=ctx.global_num_tokens,
all_decode_or_idle=ctx.all_decode_or_idle,
capture_hidden_mode=ctx.capture_hidden_mode,
padded_static_len=ctx.padded_static_len,
spec_info=spec_info,
paged_cache_block_tables=(
paged_cache_block_tables
Expand Down
Loading
Loading