Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/serving/deepseek-v4.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 tokenspeed serve deepseek-ai/DeepSeek-V4-Flash \
--max-model-len 4096 \
--max-total-tokens 16384 \
--chunked-prefill-size 8192 \
--enable-mixed-batch \
--gpu-memory-utilization 0.9 \
--disable-kvstore
```
Expand All @@ -29,6 +30,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 tokenspeed serve deepseek-ai/DeepSeek-V4-Flash \
| `--kv-cache-dtype fp8_e4m3` | V4 SWA cache rows are uint8-packed FP8 NoPE + BF16 RoPE + UE8M0 scale; FP8 e4m3 is the only supported KV dtype. |
| `--moe-backend mega_moe` | Activates the DeepGEMM `fp8_fp4_mega_moe` fused experts. Requires `tokenspeed-deepgemm>=2.5.0.post20260424`. |
| `--attention-use-fp4-indexer-cache` | Stores indexer keys as MXFP4 (`[values \| ue8m0 scales]`); the FP8 fallback path is reference-only. |
| `--enable-mixed-batch` | Enables mixed prefill/decode scheduling for V4 sparse attention. It is off by default globally because not every other backend path supports mixed batches yet. The scheduler ignores the flag when a speculative algorithm is configured. |
| `--trust-remote-code` | The HF config uses model-class architectures registered via remote code. |

## Block size
Expand Down
20 changes: 13 additions & 7 deletions python/tokenspeed/runtime/engine/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,9 @@ def __init__(
f"(ratio={server_args.mamba_full_memory_ratio})."
)

enable_mixed_prefill_decode = (
server_args.enable_mixed_batch and server_args.speculative_algorithm is None
)
scheduler_cfg = make_config(
num_device_pages=self.max_total_num_tokens // server_args.block_size,
max_scheduled_tokens=server_args.chunked_prefill_size,
Expand All @@ -293,6 +296,7 @@ def __init__(
mamba_cache_chunk_size=server_args.mamba_cache_chunk_size,
mamba_pool_total_chunks=mamba_pool_total_chunks,
paged_cache_groups=pool_to_paged_cache_groups(token_to_kv_pool),
enable_mixed_prefill_decode=enable_mixed_prefill_decode,
)
logger.info(
"Scheduler config: page_size=%s num_device_pages=%s "
Expand Down Expand Up @@ -785,8 +789,10 @@ def _commit_forward_results(
on_first_token=None,
):
self.request_handler.forward_ct += 1
forward_mode = (
ForwardMode.EXTEND if forward_op.num_extends() > 0 else ForwardMode.DECODE
forward_mode = ForwardMode.from_num_extends(
forward_op.num_extends(),
len(forward_op.request_ids),
has_drafter=self.server_args.speculative_algorithm is not None,
)
self.request_handler._profile_batch_predicate(forward_mode)

Expand Down Expand Up @@ -859,12 +865,12 @@ def _dp_sync_and_check(self, forward_op) -> DpForwardMetadata:
batch_size = len(forward_op.request_ids) if forward_op is not None else 0
if forward_op is None:
forward_mode = ForwardMode.IDLE
elif forward_op.num_extends() > 0:
forward_mode = ForwardMode.EXTEND
elif self.server_args.speculative_algorithm is not None:
forward_mode = ForwardMode.TARGET_VERIFY
else:
forward_mode = ForwardMode.DECODE
forward_mode = ForwardMode.from_num_extends(
forward_op.num_extends(),
batch_size,
has_drafter=self.server_args.speculative_algorithm is not None,
)

self._dp_local_info[0, 0] = num_tokens
self._dp_local_info[0, 1] = batch_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,8 @@ def post_process_forward_op(
forward_op.input_lengths,
forward_op.extend_prefix_lens,
)
is_decode_op = forward_op.num_extends() <= 0
num_extends = forward_op.num_extends()
is_decode_op = num_extends <= 0

request_changes = []
stream_out_rids = []
Expand All @@ -504,6 +505,7 @@ def post_process_forward_op(
if output_logprobs_list is not None
else None
)
is_decode_slot = i >= num_extends
if self.spec_num_tokens is not None and is_decode_op:
pt += self.spec_num_tokens
else:
Expand All @@ -524,7 +526,7 @@ def post_process_forward_op(
if on_first_token is not None and model_output_ids:
on_first_token(forward_op.request_pool_indices[i], model_output_ids[0])

if is_decode_op and self.spec_algorithm is not None:
if is_decode_slot and self.spec_algorithm is not None:
request_state.spec_verify_ct += 1

# With the capturable grammar pipeline the matcher is
Expand Down Expand Up @@ -597,7 +599,7 @@ def post_process_forward_op(
else:
stream_out_rids.append(rid)
stream_out_states.append(request_state)
if is_decode_op:
if is_decode_slot:
request_changes.append(
make_update_reserve_tokens_event(rid, output_length)
)
Expand Down
9 changes: 9 additions & 0 deletions python/tokenspeed/runtime/engine/output_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,15 @@ def handle_batch_output(
if recv_obj.output_multi_ids is not None:
output_multi_ids = recv_obj.output_multi_ids[i]

if len(recv_obj.batch_accept_draft_tokens) > 0:
meta_info.update(
{
"accept_draft_tokens": recv_obj.batch_accept_draft_tokens[
i
]
}
)

out_dict = {
"text": state.text,
"output_ids": output_token_ids,
Expand Down
2 changes: 2 additions & 0 deletions python/tokenspeed/runtime/engine/scheduler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def make_config(
mamba_cache_chunk_size: int = 64,
mamba_pool_total_chunks: int = 0,
paged_cache_groups: Sequence["PagedCacheGroupConfig"] | None = None,
enable_mixed_prefill_decode: bool = False,
) -> SchedulerConfig:
cfg = SchedulerConfig()
cfg.num_device_pages = num_device_pages
Expand All @@ -92,6 +93,7 @@ def make_config(
cfg.enable_mamba = enable_mamba
cfg.mamba_cache_chunk_size = mamba_cache_chunk_size
cfg.mamba_pool_total_chunks = mamba_pool_total_chunks
cfg.enable_mixed_prefill_decode = enable_mixed_prefill_decode
if paged_cache_groups:
cfg.paged_cache_groups = list(paged_cache_groups)
return cfg
Expand Down
9 changes: 7 additions & 2 deletions python/tokenspeed/runtime/execution/cuda_graph_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,7 @@ def _pad_offsets_to_padded_bs(
def _init_replay_metadata(
self,
padded_bs: int,
actual_bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
req_to_page: torch.Tensor,
Expand All @@ -562,7 +563,7 @@ def _init_replay_metadata(
"uses_paged_cache_groups",
False,
):
actual_bs = next(
table_bs = next(
(
int(table.shape[0])
for table in paged_cache_block_tables.values()
Expand All @@ -572,7 +573,7 @@ def _init_replay_metadata(
)
paged_cache_block_tables = self._pad_block_tables_to_padded_bs(
paged_cache_block_tables,
actual_bs=actual_bs,
actual_bs=table_bs,
padded_bs=padded_bs,
)
kwargs["paged_cache_block_tables"] = paged_cache_block_tables
Expand All @@ -585,6 +586,8 @@ def _init_replay_metadata(
kwargs["paged_cache_block_table_base_offsets"] = (
paged_cache_block_table_base_offsets
)
if getattr(self.attn_backend, "uses_padded_decode_token_mask", False):
kwargs["actual_bs"] = actual_bs
self.attn_backend.init_forward_metadata_replay_cuda_graph(
padded_bs,
req_pool_indices,
Expand Down Expand Up @@ -785,6 +788,7 @@ def __call__(
)
self._init_replay_metadata(
padded_bs,
bs,
req_pool_indices,
seq_lens,
req_to_page=req_to_page,
Expand Down Expand Up @@ -831,6 +835,7 @@ def __call__(
extend_prefix_lens_cpu=extend_prefix_lens_cpu,
extend_seq_lens=extend_seq_lens,
extend_seq_lens_cpu=extend_seq_lens_cpu,
num_extends=ctx.num_extends,
positions=positions,
out_cache_loc=out_cache_loc,
global_num_tokens=ctx.global_num_tokens,
Expand Down
16 changes: 16 additions & 0 deletions python/tokenspeed/runtime/execution/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def is_extend(self):
def is_decode(self):
    """Return True when this forward mode is ``ForwardMode.DECODE``."""
    return self == ForwardMode.DECODE

def is_mixed(self):
    """Return True when this forward mode is ``ForwardMode.MIXED``."""
    return self == ForwardMode.MIXED

def is_idle(self):
    """Return True when this forward mode is ``ForwardMode.IDLE``."""
    return self == ForwardMode.IDLE

Expand All @@ -67,6 +70,19 @@ def is_draft_extend(self):
def is_decode_or_idle(self):
    """Return True when this forward mode is either ``DECODE`` or ``IDLE``."""
    return self in (ForwardMode.DECODE, ForwardMode.IDLE)

@staticmethod
def from_num_extends(
    num_extends: int,
    batch_size: int,
    *,
    has_drafter: bool = False,
) -> "ForwardMode":
    """Pick the forward mode for a batch from its extend/decode composition.

    Args:
        num_extends: Number of extend (prefill) requests in the batch.
        batch_size: Total number of requests in the batch.
        has_drafter: Whether a drafter is present; only consulted when the
            batch contains no extend requests.

    Returns:
        ``IDLE`` when ``batch_size`` is not positive; ``MIXED`` when some but
        not all requests are extends; ``EXTEND`` when ``num_extends`` is
        positive and not smaller than ``batch_size``; otherwise
        ``TARGET_VERIFY`` if ``has_drafter`` is set, else ``DECODE``.
    """
    # An empty batch never runs a real forward pass.
    if batch_size <= 0:
        return ForwardMode.IDLE

    # No extend rows: the batch is pure decode (or verify when drafting).
    if num_extends <= 0:
        if has_drafter:
            return ForwardMode.TARGET_VERIFY
        return ForwardMode.DECODE

    # At least one extend row; the batch is mixed only if decode rows remain.
    if num_extends < batch_size:
        return ForwardMode.MIXED
    return ForwardMode.EXTEND


class CaptureHiddenMode(IntEnum):
NULL = auto()
Expand Down
59 changes: 49 additions & 10 deletions python/tokenspeed/runtime/execution/input_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,15 +175,19 @@ def fill_input_buffers(
page_size=self.page_size,
)

valid_cache_lengths = runtime_states.valid_cache_lengths[
cached_prefix_lens = runtime_states.valid_cache_lengths[
self.req_pool_indices_buf[:batch_size]
]
# Compute positions
prefix_lens = (
self.extend_prefix_lens_buf[:num_extends]
if num_extends > 0
else valid_cache_lengths
)
# Compute positions. In mixed batches, prefill rows use their extend
# prefix lengths while decode rows use the current valid cache lengths.
prefill_prefix_lens = self.extend_prefix_lens_buf[:num_extends]
if num_extends == 0:
prefix_lens = cached_prefix_lens
elif num_extends == batch_size:
prefix_lens = prefill_prefix_lens
else:
prefix_lens = cached_prefix_lens.clone()
prefix_lens[:num_extends].copy_(prefill_prefix_lens)
positions, _ = compute_position_triton(
extend_prefix_lens=prefix_lens,
extend_seq_lens=input_lengths_device,
Expand All @@ -193,20 +197,55 @@ def fill_input_buffers(

# Determine input_ids and forward_mode
if num_extends > 0:
prefill_token_count = sum(forward_op.input_lengths[:num_extends])
input_ids_cpu = torch.tensor(
forward_op.input_ids, device="cpu", pin_memory=True
)
self.input_ids_buf[:total_tokens].copy_(
self.input_ids_buf[:prefill_token_count].copy_(
input_ids_cpu,
non_blocking=True,
)
shifted_ids_cpu = torch.tensor(
forward_op.shifted_input_ids, device="cpu", pin_memory=True
)
self.shifted_prefill_ids_buf[:total_tokens].copy_(
self.shifted_prefill_ids_buf[:prefill_token_count].copy_(
shifted_ids_cpu,
non_blocking=True,
)
if num_extends < batch_size:
decode_req_pool_indices = req_pool_indices_device[
num_extends:batch_size
]
if forward_op.decode_input_ids is not None:
decode_count = batch_size - num_extends
if len(forward_op.decode_input_ids) != decode_count:
raise RuntimeError(
"mixed forward decode_input_ids length mismatch: "
f"got {len(forward_op.decode_input_ids)}, "
f"expected {decode_count}"
)
decode_input_ids_tensor = torch.tensor(
forward_op.decode_input_ids,
dtype=torch.int32,
device="cpu",
pin_memory=True,
).to(req_pool_indices_device.device, non_blocking=True)
mask = (decode_input_ids_tensor != -1).unsqueeze(1)
slot = runtime_states.future_input_map[decode_req_pool_indices, :1]
runtime_states.future_input_map[decode_req_pool_indices, :1] = (
torch.where(mask, decode_input_ids_tensor.unsqueeze(1), slot)
)
decode_ids = runtime_states.future_input_map[
decode_req_pool_indices, :1
].flatten()
self.input_ids_buf[prefill_token_count:total_tokens].copy_(
decode_ids,
non_blocking=True,
)
self.shifted_prefill_ids_buf[prefill_token_count:total_tokens].copy_(
decode_ids,
non_blocking=True,
)
else:
# If the scheduler provides explicit decode input ids (!= -1), write
# them into future_input_map before reading, so that they take effect
Expand All @@ -230,7 +269,7 @@ def fill_input_buffers(
non_blocking=True,
)

self.seq_lens_buf[:batch_size].copy_(input_lengths_device + valid_cache_lengths)
self.seq_lens_buf[:batch_size].copy_(input_lengths_device + cached_prefix_lens)

# Reset positions beyond total_tokens to the dummy KV slot so that any
# CUDA graph replay with a larger (padded) batch size writes padding
Expand Down
12 changes: 5 additions & 7 deletions python/tokenspeed/runtime/execution/model_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,14 +829,12 @@ def execute_forward_op(
total_tokens=total_tokens,
)

if num_extends > 0:
forward_mode = ForwardMode.EXTEND
elif self.drafter is not None:
forward_mode = ForwardMode.TARGET_VERIFY
else:
forward_mode = ForwardMode.DECODE

bs = len(forward_op.request_ids)
forward_mode = ForwardMode.from_num_extends(
num_extends,
bs,
has_drafter=self.drafter is not None,
)

if self.runtime_states.mamba_pool is not None and (
num_extends > 0 or has_retract
Expand Down
Loading
Loading