diff --git a/docs/serving/deepseek-v4.md b/docs/serving/deepseek-v4.md index bb71199f8..cd3bae91c 100644 --- a/docs/serving/deepseek-v4.md +++ b/docs/serving/deepseek-v4.md @@ -30,7 +30,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 tokenspeed serve deepseek-ai/DeepSeek-V4-Flash \ | `--kv-cache-dtype fp8_e4m3` | V4 SWA cache rows are uint8-packed FP8 NoPE + BF16 RoPE + UE8M0 scale; FP8 e4m3 is the only supported KV dtype. | | `--moe-backend mega_moe` | Activates the DeepGEMM `fp8_fp4_mega_moe` fused experts. Requires `tokenspeed-deepgemm>=2.5.0.post20260424`. | | `--attention-use-fp4-indexer-cache` | Stores indexer keys as MXFP4 (`[values \| ue8m0 scales]`); the FP8 fallback path is reference-only. | -| `--enable-mixed-batch` | Enables mixed prefill/decode scheduling for V4 sparse attention. It is off by default globally because other backend paths do not all support mixed batches yet. | +| `--enable-mixed-batch` | Lets the scheduler issue prefill and decode requests in the same iteration. Off by default globally; opt in per workload. | | `--trust-remote-code` | The HF config uses model-class architectures registered via remote code. | ## Parser defaults diff --git a/python/tokenspeed/runtime/engine/event_loop.py b/python/tokenspeed/runtime/engine/event_loop.py index ca17e801e..e3ccfc069 100644 --- a/python/tokenspeed/runtime/engine/event_loop.py +++ b/python/tokenspeed/runtime/engine/event_loop.py @@ -273,9 +273,6 @@ def __init__( f"(ratio={server_args.mamba_full_memory_ratio})." ) - enable_mixed_prefill_decode = ( - server_args.enable_mixed_batch and server_args.speculative_algorithm is None - ) # Adjunct enabled only when pool opts in AND prefix-caching switch is on. paged_cache_groups = pool_to_paged_cache_groups(token_to_kv_pool) prefix_cache_adjunct = None @@ -303,7 +300,7 @@ def __init__( mamba_cache_chunk_size=server_args.mamba_cache_chunk_size, mamba_pool_total_chunks=mamba_pool_total_chunks, paged_cache_groups=paged_cache_groups, - enable_mixed_prefill_decode=enable_mixed_prefill_decode, + enable_mixed_prefill_decode=server_args.enable_mixed_batch, prefix_cache_adjunct=prefix_cache_adjunct, ) logger.info( diff --git a/python/tokenspeed/runtime/engine/generation_output_processor.py b/python/tokenspeed/runtime/engine/generation_output_processor.py index 2bcb1891c..39ae86389 100644 --- a/python/tokenspeed/runtime/engine/generation_output_processor.py +++ b/python/tokenspeed/runtime/engine/generation_output_processor.py @@ -484,7 +484,6 @@ def post_process_forward_op( forward_op.extend_prefix_lens, ) num_extends = forward_op.num_extends() - is_decode_op = num_extends <= 0 request_changes = [] stream_out_rids = [] @@ -506,7 +505,7 @@ def post_process_forward_op( else None ) is_decode_slot = i >= num_extends - if self.spec_num_tokens is not None and is_decode_op: + if self.spec_num_tokens is not None and is_decode_slot: pt += self.spec_num_tokens else: pt += output_length diff --git a/python/tokenspeed/runtime/execution/context.py b/python/tokenspeed/runtime/execution/context.py index e4fe3f922..e5cb59f39 100644 --- a/python/tokenspeed/runtime/execution/context.py +++ b/python/tokenspeed/runtime/execution/context.py @@ -50,7 +50,6 @@ class ForwardContext: forward_mode: ForwardMode | None req_to_page: torch.Tensor | None = None capture_hidden_mode: CaptureHiddenMode | None = CaptureHiddenMode.NULL - padded_static_len: int = -1 # --- dp attention --- global_num_tokens: list[int] | None = None @@ -58,5 +57,4 @@ class ForwardContext: all_decode_or_idle: bool = False # --- logits processor --- - keep_full_logits: bool = False - last_index_offsets: torch.Tensor | None = None + gather_ids: torch.Tensor | None = None diff --git a/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py b/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py index 647664a09..9fa151db1 100644 --- a/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py +++ b/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py @@ -306,7 +306,6 @@ def _capture_one(self, bs: int): if self.drafter is not None else CaptureHiddenMode.NULL ), - keep_full_logits=True, ) # For DP mode, global_num_tokens must be set so that the MoE @@ -452,7 +451,6 @@ def _init_capture_metadata(self, bs: int): capture_kwargs["paged_cache_block_tables"] = paged_cache_block_tables self.attn_backend.init_forward_metadata_capture_cuda_graph( bs, - bs * self.max_tokens_per_req, self.input_buffers.req_pool_indices_buf[:bs], self.input_buffers.seq_lens_buf[:bs], ForwardMode.DECODE, @@ -475,7 +473,6 @@ def _init_capture_metadata(self, bs: int): # Drafter mutates seq_lens_buf in place per step; backends alias. self.draft_attn_backend.init_forward_metadata_capture_cuda_graph( bs, - bs * self.max_tokens_per_req, self.input_buffers.req_pool_indices_buf[:bs], self.input_buffers.seq_lens_buf[:bs], ForwardMode.DECODE, @@ -581,7 +578,6 @@ def _init_replay_metadata( kwargs["actual_bs"] = actual_bs self.attn_backend.init_forward_metadata_replay_cuda_graph( padded_bs, - padded_bs * self.max_tokens_per_req, req_pool_indices, seq_lens, req_to_page=req_to_page, @@ -591,7 +587,6 @@ def _init_replay_metadata( if self.draft_attn_backend is not None: self.draft_attn_backend.init_forward_metadata_replay_cuda_graph( padded_bs, - padded_bs * self.max_tokens_per_req, req_pool_indices, seq_lens, req_to_page=self.drafter.req_to_page, @@ -603,6 +598,7 @@ def _init_replay_metadata( def _init_forward_metadata( self, padded_bs: int, + num_extends: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, req_to_page: torch.Tensor, @@ -611,10 +607,10 @@ def _init_forward_metadata( ): """Eager path — allocate/refresh metadata for the upcoming forward.""" self.attn_backend.init_forward_metadata( - padded_bs, - padded_bs * self.max_tokens_per_req, - req_pool_indices, - seq_lens, + bs=padded_bs, + num_extends=num_extends, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, req_to_page=req_to_page, forward_mode=forward_mode, **kwargs, @@ -633,9 +629,11 @@ def _init_forward_metadata( # ``seq_lens_k=seq_lens[:bs]``). # So the kernel sees the value the drafter just wrote, without # rebuilding metadata per step. - # The EXTEND-mode prefill init still uses the controller's - # ``seq_lens`` because that path computes ``cu_seqlens_k`` from it - # eagerly (cumsum at init time), not via aliasing. + # Pre-write the buffer with the controller's seq_lens so the + # prefill-side eager work (cumsum at init) and the live-aliased + # decode side both see correct values from step 0 onward. Each + # is_draft backend fills both prefill+decode metadata in this + # one call. # # TODO: relying on aliasing for correctness is fragile — a stray # copy or a misrouted buffer silently produces wrong outputs. @@ -643,36 +641,16 @@ def _init_forward_metadata( # so each kernel invocation carries its own value rather than # reading through a tensor registered at init. draft_seq_lens = self.drafter.draft_seq_lens_buf[:padded_bs] - if forward_mode.is_extend_or_mixed(): - # Step 0 uses the caller's prefix kwargs; subsequent decode - # steps use one-token-per-request metadata. Populate each - # slot with its own init call. - self.draft_attn_backend.init_forward_metadata( - padded_bs, - padded_bs * self.max_tokens_per_req, - req_pool_indices, - seq_lens, - req_to_page=self.drafter.req_to_page, - forward_mode=forward_mode, - **kwargs, - ) - self.draft_attn_backend.init_forward_metadata( - padded_bs, - padded_bs, - req_pool_indices, - draft_seq_lens, - req_to_page=self.drafter.req_to_page, - forward_mode=ForwardMode.DECODE, - ) - else: - self.draft_attn_backend.init_forward_metadata( - padded_bs, - padded_bs * self.max_tokens_per_req, - req_pool_indices, - draft_seq_lens, - req_to_page=self.drafter.req_to_page, - forward_mode=ForwardMode.DECODE, - ) + draft_seq_lens.copy_(seq_lens) + self.draft_attn_backend.init_forward_metadata( + bs=padded_bs, + num_extends=num_extends, + req_pool_indices=req_pool_indices, + seq_lens=draft_seq_lens, + req_to_page=self.drafter.req_to_page, + forward_mode=forward_mode, + **kwargs, + ) def _global_graph_bs(self, ctx: ForwardContext) -> int | None: if self.dp_size <= 1 or ctx.global_num_tokens is None: @@ -832,6 +810,7 @@ def __call__( else: self._init_forward_metadata( padded_bs, + ctx.num_extends, req_pool_indices, seq_lens, req_to_page=req_to_page, @@ -841,13 +820,11 @@ def __call__( extend_prefix_lens_cpu=extend_prefix_lens_cpu, extend_seq_lens=extend_seq_lens, extend_seq_lens_cpu=extend_seq_lens_cpu, - num_extends=ctx.num_extends, positions=positions, out_cache_loc=out_cache_loc, global_num_tokens=ctx.global_num_tokens, all_decode_or_idle=ctx.all_decode_or_idle, capture_hidden_mode=ctx.capture_hidden_mode, - padded_static_len=ctx.padded_static_len, spec_info=spec_info, paged_cache_block_tables=( paged_cache_block_tables diff --git a/python/tokenspeed/runtime/execution/drafter/eagle.py b/python/tokenspeed/runtime/execution/drafter/eagle.py index a1989ea0a..1e15cb581 100644 --- a/python/tokenspeed/runtime/execution/drafter/eagle.py +++ b/python/tokenspeed/runtime/execution/drafter/eagle.py @@ -53,6 +53,7 @@ @dataclass class EagleDraftInput: input_num_tokens: int + num_extends: int forward_mode: ForwardMode base_model_output: torch.Tensor # [bs] accept_lengths: torch.Tensor # [bs] @@ -124,10 +125,9 @@ def __init__( device=self.device, ) - # Precomputed `arange(max_bs) * spec_num_tokens - 1`, sliced and passed - # via ForwardContext for the padded-static-len last-token selection in - # LogitsProcessor. - self.last_index_offsets_buf = ( + # Precomputed `arange(max_bs) * spec_num_tokens - 1` + # gather_ids = gather_ids_offsets + accept_lengths + self.padded_gather_ids_offsets_buf = ( torch.arange( self.input_buffers.max_bs, dtype=torch.int64, device=self.device ) @@ -145,30 +145,58 @@ def _map_hot(self, ids: torch.Tensor) -> torch.Tensor: def _get_first_step_input( self, - forward_mode: ForwardMode, draft_input: EagleDraftInput, bs: int, input_num_tokens: int, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Returns (input_ids, unpadded_input_lengths) for the first draft step.""" - if forward_mode.is_extend(): - input_ids = self.input_buffers.shifted_prefill_ids_buf[ - :input_num_tokens - ].clone() - + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Returns (input_ids, unpadded_input_lengths, gather_ids) for the first draft step. + + The first-step input shape matches the base model's: ragged + ``[prefill_part || decode_part]`` under MIXED, full prefill chunks + under EXTEND, ``base_model_output`` directly under DECODE. + """ + num_extends = draft_input.num_extends + num_decodes = bs - num_extends + if num_extends > 0: + num_decode_tokens = num_decodes * self.spec_num_tokens + num_prefill_tokens = input_num_tokens - num_decode_tokens + + input_ids = self.input_buffers.shifted_prefill_ids_buf[:input_num_tokens] unpadded_input_lengths = self.input_buffers.input_lengths_buf[:bs] - req_boundaries = unpadded_input_lengths.cumsum(0) - 1 # [bs] - boundary_ids = input_ids[req_boundaries] - needs_fill = boundary_ids == -1 # [bs] - input_ids[req_boundaries] = torch.where( - needs_fill, draft_input.base_model_output[:bs], boundary_ids + if num_decodes > 0: + input_ids[num_prefill_tokens:].copy_( + draft_input.base_model_output[num_extends:] + ) + unpadded_input_lengths[num_extends:].copy_( + draft_input.accept_lengths[num_extends:] + ) + + last_indices = unpadded_input_lengths[:num_extends].cumsum(0) - 1 + last_input_ids = input_ids[last_indices] + input_ids[last_indices] = torch.where( + last_input_ids == -1, + draft_input.base_model_output[:num_extends], + last_input_ids, ) + gather_ids = last_indices + if num_decodes > 0: + gather_ids = torch.cat( + [ + gather_ids, + self.padded_gather_ids_offsets_buf[:num_decodes] + + draft_input.accept_lengths[num_extends:] + + num_prefill_tokens, + ] + ) else: input_ids = draft_input.base_model_output unpadded_input_lengths = draft_input.accept_lengths + gather_ids = ( + self.padded_gather_ids_offsets_buf[:bs] + draft_input.accept_lengths + ) - return input_ids, unpadded_input_lengths + return input_ids, unpadded_input_lengths, gather_ids @nvtx_range("draft_first_step", color="purple") def _run_first_step( @@ -180,34 +208,27 @@ def _run_first_step( buffers = self.input_buffers forward_mode = draft_input.forward_mode - input_ids, unpadded_input_lengths = self._get_first_step_input( - forward_mode, draft_input, bs, draft_input.input_num_tokens + input_ids, unpadded_input_lengths, gather_ids = self._get_first_step_input( + draft_input, bs, draft_input.input_num_tokens ) - padded_static_len, last_index_offsets = -1, None - if forward_mode.is_decode(): - padded_static_len = self.spec_num_tokens - last_index_offsets = self.last_index_offsets_buf[:bs] - # make a ctx every time model runner forward - first_step_ctx = ForwardContext( + ctx = ForwardContext( attn_backend=self.attn_backend, token_to_kv_pool=self.token_to_kv_pool, req_to_page=self.req_to_page, bs=bs, - num_extends=0, + num_extends=draft_input.num_extends, input_num_tokens=draft_input.input_num_tokens, forward_mode=forward_mode, capture_hidden_mode=CaptureHiddenMode.LAST, - padded_static_len=padded_static_len, - last_index_offsets=last_index_offsets, - keep_full_logits=False, + gather_ids=gather_ids, global_num_tokens=draft_input.global_num_tokens, global_bs=draft_input.global_bs, all_decode_or_idle=draft_input.all_decode_or_idle, ) return self.draft_model_runner.forward( - ctx=first_step_ctx, + ctx=ctx, input_ids=input_ids, positions=buffers.positions_buf[: draft_input.input_num_tokens], out_cache_loc=buffers.out_cache_loc_buf[: draft_input.input_num_tokens], @@ -224,19 +245,17 @@ def _run_multi_step_decode( logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput, ) -> None: - + num_extends = draft_input.num_extends + num_decodes = bs - num_extends req_pool_indices = self.input_buffers.req_pool_indices_buf[:bs] - # Step 1's write position uses vc+accept_length under DECODE so the - # rotary advance doesn't shift past the rejected tail. - if draft_input.forward_mode.is_decode(): - cache_start = ( + cache_start = self.input_buffers.seq_lens_buf[:bs] + if num_decodes > 0: + cache_start[num_extends:] = ( self.runtime_states.valid_cache_lengths.index_select( - 0, req_pool_indices + 0, req_pool_indices[num_extends:] ) - + draft_input.accept_lengths + + draft_input.accept_lengths[num_extends:] ) - else: - cache_start = self.input_buffers.seq_lens_buf[:bs].clone() # Write cache slots for steps 1..N-1. cache_locs = self.draft_out_cache_loc_buf[: bs * (self.spec_num_steps - 1)] @@ -280,7 +299,6 @@ def _run_multi_step_decode( input_num_tokens=bs, forward_mode=ForwardMode.DECODE, capture_hidden_mode=CaptureHiddenMode.LAST, - keep_full_logits=True, global_num_tokens=global_num_tokens, global_bs=draft_input.global_bs, all_decode_or_idle=draft_input.all_decode_or_idle, @@ -313,13 +331,16 @@ def get_candidates( self, base_ctx: ForwardContext, ) -> torch.Tensor | None: - - if not base_ctx.forward_mode.is_decode(): + num_extends = base_ctx.num_extends + num_decodes = base_ctx.bs - num_extends + if num_decodes == 0: return None - return self.input_buffers.input_ids_buf[: base_ctx.input_num_tokens].reshape( - base_ctx.bs, self.spec_num_tokens - ) + num_decode_tokens = num_decodes * self.spec_num_tokens + num_prefill_tokens = base_ctx.input_num_tokens - num_decode_tokens + return self.input_buffers.input_ids_buf[ + num_prefill_tokens : base_ctx.input_num_tokens + ].reshape(num_decodes, self.spec_num_tokens) @override def draft( @@ -338,27 +359,38 @@ def draft( ) # Last verified id per request → next_tokens[:, 0]. - if draft_input.forward_mode.is_extend(): - next_tokens[:, 0] = draft_input.base_model_output[:bs] - else: - indices = self.last_index_offsets_buf[:bs] + draft_input.accept_lengths + num_extends = draft_input.num_extends + num_decodes = bs - num_extends + if num_extends > 0: + next_tokens[:num_extends, 0] = draft_input.base_model_output[:num_extends] + if num_decodes > 0: + indices = ( + self.padded_gather_ids_offsets_buf[:num_decodes] + + draft_input.accept_lengths[num_extends:] + ) + if num_extends > 0: + indices.add_(num_extends) torch.index_select( - draft_input.base_model_output, 0, indices, out=next_tokens[:, 0] + draft_input.base_model_output, + 0, + indices, + out=next_tokens[num_extends:, 0], ) # Seed the draft attn backend's aliased seq_lens for the first step. self.draft_seq_lens_buf[:bs].copy_(self.input_buffers.seq_lens_buf[:bs]) - # First draft step. LogitsProcessor prunes `[bs * spec_num_tokens, ...]` - # down to `[bs, ...]` via padded_static_len, so logits/hidden_states - # arrive here already aligned to one row per request. + # First draft step. LogitsProcessor prunes `[num_prefill_tokens + num_decodes * spec_num_tokens, ...]` + # down to `[bs, ...]`, so logits/hidden_states arrive here already aligned to one row per request. logits_output = self._run_first_step(bs, draft_input) draft_ids = cute_argmax(logits_output.next_token_logits) next_tokens[:, 1] = self._map_hot(draft_ids) - # Draft step 2+ (multi-step decode). - if self.spec_num_steps > 1: + if self.spec_num_steps <= 1: + return next_tokens + + if self.input_buffers.all_extends_mid_chunk and self.dp_size == 1: # Skip multi-step when the whole batch is mid-chunk EXTEND: no # request transitions to target_verify after this forward, so # any speculative tokens we draft would be discarded. @@ -366,12 +398,17 @@ def draft( # In DP we still run, because peer ranks may have completing # extends or decodes; diverging here would desync the drafter's # dense-TP / MoE-EP collectives (NCCL hang or RSAG mismatch). - skip = self.dp_size == 1 and self.input_buffers.all_extends_mid_chunk - if not skip: - self._run_multi_step_decode( - bs, draft_ids, next_tokens, logits_output, draft_input - ) + return next_tokens + # Draft step 2+ (multi-step decode). + # Multi-step decode operates on full bs; drop the [num_extends:] + # slice that step 0 may have set up for MIXED target. No-op on + # backends that fill separate prefill/decode metadata at init + # time. + with self.attn_backend.override_num_extends(0): + self._run_multi_step_decode( + bs, draft_ids, next_tokens, logits_output, draft_input + ) return next_tokens @override @@ -386,6 +423,7 @@ def run( draft_input = EagleDraftInput( input_num_tokens=base_ctx.input_num_tokens, + num_extends=base_ctx.num_extends, forward_mode=base_ctx.forward_mode, base_model_output=output_tokens, accept_lengths=accept_lengths, diff --git a/python/tokenspeed/runtime/execution/input_buffer.py b/python/tokenspeed/runtime/execution/input_buffer.py index 2bae604ec..27c9450d3 100644 --- a/python/tokenspeed/runtime/execution/input_buffer.py +++ b/python/tokenspeed/runtime/execution/input_buffer.py @@ -237,7 +237,7 @@ def fill_input_buffers( torch.where(mask, decode_input_ids_tensor.unsqueeze(1), slot) ) decode_ids = runtime_states.future_input_map[ - decode_req_pool_indices, :1 + decode_req_pool_indices ].flatten() self.input_ids_buf[prefill_token_count:total_tokens].copy_( decode_ids, diff --git a/python/tokenspeed/runtime/execution/model_executor.py b/python/tokenspeed/runtime/execution/model_executor.py index 0f6896beb..7cbb9fac1 100644 --- a/python/tokenspeed/runtime/execution/model_executor.py +++ b/python/tokenspeed/runtime/execution/model_executor.py @@ -45,6 +45,7 @@ from tokenspeed.runtime.execution.runtime_states import RuntimeStates from tokenspeed.runtime.execution.types import ModelExecutionResult from tokenspeed.runtime.grammar.capturable_grammar import setup_grammar_step +from tokenspeed.runtime.layers.logits_processor import LogitsProcessorOutput from tokenspeed.runtime.sampling.backends.base import SamplingBackend from tokenspeed.runtime.sampling.sampling_batch_info import SamplingBatchInfo from tokenspeed.runtime.utils import get_colorful_logger, set_random_seed @@ -351,16 +352,45 @@ def _run_target_forward(self, bs: int, ctx: ForwardContext, req_pool_indices): @nvtx_range("sampling", color="yellow") def _run_sampling( self, - logits_output, + logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo, ctx: ForwardContext, - candidates, + candidates: torch.Tensor | None = None, ): - if self.drafter is not None and ctx.forward_mode.is_decode(): + if self.drafter is None: + return self.sampling_backend.sample(logits_output, sampling_info) + + num_extends = ctx.num_extends + num_decodes = ctx.bs - num_extends + + if num_decodes == 0: + return self.sampling_backend.sample(logits_output, sampling_info) + + if num_extends == 0: return self.sampling_backend.verify( logits_output, sampling_info, candidates ) - return self.sampling_backend.sample(logits_output, sampling_info) + + logits = logits_output.next_token_logits + prefill_out = LogitsProcessorOutput(next_token_logits=logits[:num_extends]) + prefill_tokens, prefill_accept = self.sampling_backend.sample( + prefill_out, sampling_info[:num_extends] + ) + decode_out = LogitsProcessorOutput(next_token_logits=logits[num_extends:]) + decode_tokens, decode_accept = self.sampling_backend.verify( + decode_out, sampling_info[num_extends:], candidates + ) + if ( + prefill_out.next_token_logprobs is not None + and decode_out.next_token_logprobs is not None + ): + logits_output.next_token_logprobs = torch.cat( + [prefill_out.next_token_logprobs, decode_out.next_token_logprobs] + ) + return ( + torch.cat([prefill_tokens, decode_tokens]), + torch.cat([prefill_accept, decode_accept]), + ) @maybe_inference_mode() def _forward_step( @@ -423,7 +453,7 @@ def _update_runtime_state( output_tokens: torch.Tensor, accept_lengths: torch.Tensor, input_lengths: torch.Tensor, - is_extend: bool, + num_extends: int, ): """Write output tokens to future_input_map and update cache lengths. @@ -435,21 +465,24 @@ def _update_runtime_state( # Without drafter, store output tokens for next round. # With drafter, _forward_step already wrote the drafter's # next-round input (verified + draft tokens) to future_input_map. - tokens_per_req = self.config.output_length if not is_extend else 1 + tokens_per_req = self.config.output_length if num_extends == 0 else 1 next_round_input_ids = output_tokens.to(torch.int32).reshape( -1, tokens_per_req ) self.runtime_states.future_input_map[req_pool_indices, :tokens_per_req] = ( next_round_input_ids ) - if is_extend: - self.runtime_states.update_valid_cache_length( - req_pool_indices, input_lengths - ) + + bs = req_pool_indices.shape[0] + if num_extends == 0: + deltas = accept_lengths + elif num_extends == bs: + deltas = input_lengths else: - self.runtime_states.update_valid_cache_length( - req_pool_indices, accept_lengths + deltas = torch.cat( + [input_lengths[:num_extends], accept_lengths[num_extends:]] ) + self.runtime_states.update_valid_cache_length(req_pool_indices, deltas) def _build_sampling_info( self, @@ -956,6 +989,42 @@ def execute_forward_op( output_lengths = torch.zeros(bs, dtype=torch.int32, device=self.device) output_logprobs = None else: + gather_ids = None + if num_extends > 0: + num_decodes = bs - num_extends + if self.drafter is not None and num_decodes > 0: + # MIXED + spec: prefill rows pruned to last token, + # decode block kept full at verify width. + num_decode_tokens = num_decodes * self.config.spec_num_tokens + num_prefill_tokens = total_tokens - num_decode_tokens + gather_ids = torch.empty( + num_extends + num_decode_tokens, + dtype=torch.int64, + device=self.device, + ) + gather_ids[:num_extends] = ( + torch.cumsum( + self.input_buffers.input_lengths_buf[:num_extends], + dim=0, + ) + - 1 + ) + gather_ids[num_extends:] = torch.arange( + num_prefill_tokens, + total_tokens, + device=self.device, + dtype=torch.int64, + ) + else: + # EXTEND, MIXED non-spec, or EXTEND + spec: last token + # per request via cumsum. + gather_ids = ( + torch.cumsum( + self.input_buffers.input_lengths_buf[:bs], dim=0 + ) + - 1 + ) + ctx = ForwardContext( attn_backend=self.attn_backend, token_to_kv_pool=self.token_to_kv_pool, @@ -969,8 +1038,7 @@ def execute_forward_op( if self.drafter is not None else CaptureHiddenMode.NULL ), - padded_static_len=-1, - keep_full_logits=forward_mode.is_decode_or_idle(), + gather_ids=gather_ids, ) if self.config.data_parallel_size > 1: if dp_global_num_tokens is None: @@ -987,7 +1055,7 @@ def execute_forward_op( grammar_completion = setup_grammar_step( sampling_info=sampling_info, bs=bs, - is_spec_decode=self.drafter is not None and num_extends <= 0, + is_spec_decode=self.drafter is not None and num_extends < bs, spec_num_tokens=self.config.spec_num_tokens or 1, grammar_inputs=grammar_inputs, grammar_runtime=self.grammar_runtime, @@ -1075,7 +1143,7 @@ def execute_forward_op( output_tokens=output_tokens, accept_lengths=output_lengths, input_lengths=self.input_buffers.input_lengths_buf[:bs], - is_extend=num_extends > 0, + num_extends=num_extends, ) self._snapshot_mamba_checkpoints( output_lengths, diff --git a/python/tokenspeed/runtime/layers/attention/backends/base.py b/python/tokenspeed/runtime/layers/attention/backends/base.py index d73c666e6..827169b92 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/base.py +++ b/python/tokenspeed/runtime/layers/attention/backends/base.py @@ -21,6 +21,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from contextlib import contextmanager from typing import TYPE_CHECKING import torch @@ -46,6 +47,18 @@ def __init__(self, config: BaseAttnConfig) -> None: self.dtype = config.dtype self.head_dim = config.head_dim self.is_draft = config.is_draft + self.spec_num_tokens = config.speculative_num_draft_tokens + + @contextmanager + def override_num_extends(self, num_extends: int): + """Temporarily override the decode-metadata slice discriminator for the + wrapped block. Used by MLA backends to flip between drafter step 0 + (slice = [num_extends:]) and step 1+ (slice = [0:]). + + Default no-op for backends that fill separate prefill/decode metadata + at init time. + """ + yield @property def support_kv_cache_prewrite(self) -> bool: @@ -73,7 +86,6 @@ def init_cuda_graph_state(self, max_bs: int, seq_lens_buf: torch.Tensor): def init_forward_metadata_capture_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode, @@ -84,7 +96,6 @@ def init_forward_metadata_capture_cuda_graph( def init_forward_metadata_replay_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode = None, diff --git a/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py b/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py index cffac137a..671001574 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py +++ b/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py @@ -13,6 +13,8 @@ from __future__ import annotations +from typing import Optional + import torch from tokenspeed_kernel.ops.attention.flash_mla import ( flash_mla_sparse_fwd, @@ -358,7 +360,6 @@ def _query_lens_cpu( def init_forward_metadata( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode = None, @@ -374,7 +375,7 @@ def init_forward_metadata( ) num_extends_arg = kwargs.pop("num_extends", None) num_extends = bs if num_extends_arg is None else int(num_extends_arg) - del num_tokens, kwargs + del kwargs device = seq_lens.device req_pool_indices = req_pool_indices[:bs] seq_lens = seq_lens[:bs].to(torch.int32) @@ -1435,7 +1436,6 @@ def _refresh_cuda_graph_base_offsets( def init_forward_metadata_capture_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode, @@ -1445,7 +1445,7 @@ def init_forward_metadata_capture_cuda_graph( paged_cache_block_table_base_offsets = ( kwargs.pop("paged_cache_block_table_base_offsets", None) or {} ) - del num_tokens, kwargs + del kwargs if forward_mode is not None and not forward_mode.is_decode_or_idle(): raise NotImplementedError( f"DeepSeek V4 CUDA graph capture not supported for {forward_mode}" @@ -1513,7 +1513,6 @@ def init_forward_metadata_capture_cuda_graph( def init_forward_metadata_replay_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode = None, diff --git a/python/tokenspeed/runtime/layers/attention/backends/flash_attention.py b/python/tokenspeed/runtime/layers/attention/backends/flash_attention.py index a44cfda89..a9d48a9d0 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/flash_attention.py +++ b/python/tokenspeed/runtime/layers/attention/backends/flash_attention.py @@ -380,7 +380,6 @@ def _update_decode_cuda_graph_scheduler_metadata( def init_forward_metadata( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode = ForwardMode.DECODE, @@ -403,14 +402,15 @@ def init_forward_metadata( assert req_to_page is not None, "req_to_page must be provided" - spec_num_tokens = num_tokens // bs if bs > 0 else 1 is_target_verify = ( forward_mode.is_decode_or_idle() and not self.is_draft - and spec_num_tokens > 1 + and self.spec_num_tokens > 1 ) is_draft_extend = ( - forward_mode.is_decode_or_idle() and self.is_draft and spec_num_tokens > 1 + forward_mode.is_decode_or_idle() + and self.is_draft + and self.spec_num_tokens > 1 ) # Use max_context_len as worst-case max_seq_len_k — avoids GPU sync (.item()). @@ -419,7 +419,7 @@ def init_forward_metadata( req_pool_indices, req_to_page, self.page_size, max_context_len ) - if forward_mode.is_decode_or_idle() and spec_num_tokens == 1: + if forward_mode.is_decode_or_idle() and self.spec_num_tokens == 1: # Draft Decode if spec_info is not None: if self.topk <= 1: @@ -579,21 +579,21 @@ def init_forward_metadata( if extend_with_prefix and extend_prefix_lens is not None: extend_seq_lens = seq_lens - extend_prefix_lens - # The FA3 workspace is sized from max_seq_len_q. For prefix - # chunks, num_tokens is only the sum over the batch and can be - # much larger than any single query sequence. + # The FA3 workspace is sized from max_seq_len_q. The wrapper's + # padded upper bound is bs * spec_num_tokens; tighter is the + # actual extend_seq_lens_cpu.max() when available. extend_seq_lens_cpu = kwargs.get("extend_seq_lens_cpu") if extend_seq_lens_cpu is not None: metadata.max_seq_len_q = int( extend_seq_lens_cpu[:batch_size].max().item() ) else: - metadata.max_seq_len_q = num_tokens + metadata.max_seq_len_q = batch_size * self.spec_num_tokens metadata.cu_seqlens_q = torch.nn.functional.pad( torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0) ) elif is_draft_extend and kwargs.get("extend_seq_lens") is not None: - metadata.max_seq_len_q = num_tokens + metadata.max_seq_len_q = batch_size * self.spec_num_tokens metadata.cu_seqlens_q = torch.nn.functional.pad( torch.cumsum(kwargs["extend_seq_lens"], dim=0, dtype=torch.int32), (1, 0), @@ -619,9 +619,12 @@ def init_forward_metadata( # Route to prefill/decode slot. Drafter's first multi-token step uses # the prefill slot; follow-up single-token steps use the decode slot. - if forward_mode.is_decode_or_idle() and spec_num_tokens == 1: + if forward_mode.is_decode_or_idle() and self.spec_num_tokens == 1: self.forward_decode_metadata = metadata - elif is_draft_extend: + elif is_draft_extend or (self.is_draft and forward_mode.is_extend_or_mixed()): + # Drafter: also fill decode slot so step 1+ multi-step has metadata + # under EXTEND/MIXED target. seqlens_in_batch aliases the drafter's + # live buffer (wrapper pre-writes it). self.forward_prefill_metadata = metadata decode_metadata = FlashAttentionMetadata() decode_metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) @@ -630,6 +633,15 @@ def init_forward_metadata( 0, batch_size + 1, dtype=torch.int32, device=device ) decode_metadata.page_table = page_table + # Match the pre-cleanup "Normal Decode" path which always called + # _init_local_attn_metadata; required for chunked-attention models. + self._init_local_attn_metadata( + decode_metadata, + device, + cu_seqlens_q=decode_metadata.cu_seqlens_q, + cache_seqlens_int32=decode_metadata.cache_seqlens_int32, + page_table=page_table, + ) self.forward_decode_metadata = decode_metadata else: self.forward_prefill_metadata = metadata @@ -676,18 +688,18 @@ def forward_extend( # Use precomputed metadata across all layers metadata = self.forward_prefill_metadata - spec_num_tokens = q.shape[0] // bs if bs > 0 else 1 + q_len_per_req = q.shape[0] // bs if bs > 0 else 1 is_target_verify = ( forward_mode is not None and forward_mode.is_decode_or_idle() and not self.is_draft - and spec_num_tokens > 1 + and q_len_per_req > 1 ) is_draft_extend = ( forward_mode is not None and forward_mode.is_decode_or_idle() and self.is_draft - and spec_num_tokens > 1 + and q_len_per_req > 1 ) # Calculate window size @@ -966,8 +978,8 @@ def forward_decode( ) -> torch.Tensor: # Multi-token decode (target verify or drafter's first post-verify # step) reuses the multi-token prefill path. - spec_num_tokens = q.shape[0] // bs if bs > 0 else 1 - if spec_num_tokens > 1: + q_len_per_req = q.shape[0] // bs if bs > 0 else 1 + if q_len_per_req > 1: return self.forward_extend( q, k, @@ -1434,7 +1446,6 @@ def init_cuda_graph_state(self, max_bs: int, seq_lens_buf: torch.Tensor): def init_forward_metadata_capture_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode, @@ -1447,17 +1458,18 @@ def init_forward_metadata_capture_cuda_graph( metadata_expand = FlashAttentionMetadata() device = seq_lens.device - spec_num_tokens = num_tokens // bs if bs > 0 else 1 is_target_verify = ( forward_mode.is_decode_or_idle() and not self.is_draft - and spec_num_tokens > 1 + and self.spec_num_tokens > 1 ) is_draft_extend = ( - forward_mode.is_decode_or_idle() and self.is_draft and spec_num_tokens > 1 + forward_mode.is_decode_or_idle() + and self.is_draft + and self.spec_num_tokens > 1 ) - if forward_mode.is_decode_or_idle() and spec_num_tokens == 1: + if forward_mode.is_decode_or_idle() and self.spec_num_tokens == 1: if spec_info is not None: # Draft Decode if self.topk <= 1: @@ -1600,7 +1612,7 @@ def init_forward_metadata_capture_cuda_graph( :bs ] - num_tokens_per_bs = num_tokens // bs + num_tokens_per_bs = self.spec_num_tokens metadata.max_seq_len_q = num_tokens_per_bs metadata.max_seq_len_k = self.max_context_len @@ -1631,7 +1643,7 @@ def init_forward_metadata_capture_cuda_graph( self.decode_cuda_graph_metadata[bs] = decode_metadata # Route to prefill/decode slots. Drafter's compound case populates both. - if forward_mode.is_decode_or_idle() and spec_num_tokens == 1: + if forward_mode.is_decode_or_idle() and self.spec_num_tokens == 1: self.forward_decode_metadata = metadata elif is_target_verify: self.forward_prefill_metadata = metadata @@ -1643,7 +1655,6 @@ def init_forward_metadata_capture_cuda_graph( def init_forward_metadata_replay_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode = None, @@ -1662,14 +1673,15 @@ def init_forward_metadata_replay_cuda_graph( assert req_to_page is not None, "req_to_page must be provided" max_context_len = self.max_context_len - spec_num_tokens = num_tokens // bs if bs > 0 else 1 is_target_verify = ( forward_mode.is_decode_or_idle() and not self.is_draft - and spec_num_tokens > 1 + and self.spec_num_tokens > 1 ) is_draft_extend = ( - forward_mode.is_decode_or_idle() and self.is_draft and spec_num_tokens > 1 + forward_mode.is_decode_or_idle() + and self.is_draft + and self.spec_num_tokens > 1 ) if ( diff --git a/python/tokenspeed/runtime/layers/attention/backends/flashmla.py b/python/tokenspeed/runtime/layers/attention/backends/flashmla.py index db2976fb1..a708ff462 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/flashmla.py +++ b/python/tokenspeed/runtime/layers/attention/backends/flashmla.py @@ -20,6 +20,7 @@ from __future__ import annotations +from contextlib import contextmanager from dataclasses import dataclass from typing import TYPE_CHECKING @@ -60,20 +61,11 @@ @dataclass class FlashMLADecodeMetadata: + num_extends: int = 0 flashmla_metadata: tuple | None = None num_splits: torch.Tensor | None = None block_table: torch.Tensor | None = None - def __init__( - self, - flashmla_metadata=None, - num_splits=None, - block_table=None, - ): - self.flashmla_metadata = flashmla_metadata - self.num_splits = num_splits - self.block_table = block_table - @dataclass class _PrefillMetadata: @@ -170,8 +162,11 @@ def __init__(self, config: MLAConfig): ) self.indices_updater_prefill = _PrefillIndicesUpdater(config, self) - # Metadata state - self.forward_metadata: FlashMLADecodeMetadata | _PrefillMetadata | None = None + # Metadata state. Decode and prefill metadata are split so MIXED batches + # can carry both simultaneously (decode-half + prefill-half sub-contexts + # dispatch to their respective metadata). + self.forward_decode_metadata: FlashMLADecodeMetadata | None = None + self.forward_prefill_metadata: _PrefillMetadata | None = None self.chunked_prefill_metadata: _ChunkedPrefillMetadata | None = None self.last_seq_lens_sum: int | None = None @@ -182,7 +177,7 @@ def __init__(self, config: MLAConfig): def init_forward_metadata( self, bs: int, - num_tokens: int, + num_extends: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode, @@ -191,110 +186,148 @@ def init_forward_metadata( extend_prefix_lens: torch.Tensor | None = None, spec_info=None, **kwargs, + ): + if forward_mode.is_extend_or_mixed(): + self._init_prefill_metadata( + req_pool_indices=req_pool_indices[:num_extends], + seq_lens=seq_lens[:num_extends], + req_to_page=req_to_page, + extend_with_prefix=extend_with_prefix, + extend_prefix_lens=extend_prefix_lens, + extend_prefix_lens_cpu=kwargs.pop("extend_prefix_lens_cpu"), + extend_seq_lens=kwargs.pop("extend_seq_lens"), + extend_seq_lens_cpu=kwargs.pop("extend_seq_lens_cpu"), + ) + # Under is_draft, also fill decode_metadata under any forward_mode so + # the drafter's multi-step loop has metadata. Wrapper pre-writes + # draft_seq_lens before calling here, so `seq_lens` aliases the + # drafter's live buffer for step-1+ advances. + if ( + forward_mode.is_decode_or_idle() + or forward_mode.is_mixed() + or (forward_mode.is_extend() and self.is_draft) + ): + self._init_decode_metadata( + bs, num_extends, req_pool_indices, seq_lens, req_to_page + ) + + @contextmanager + def override_num_extends(self, num_extends: int): + assert self.forward_decode_metadata is not None + prev = self.forward_decode_metadata.num_extends + self.forward_decode_metadata.num_extends = num_extends + try: + yield + finally: + self.forward_decode_metadata.num_extends = prev + + def _init_decode_metadata( + self, + bs: int, + num_extends: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + req_to_page: torch.Tensor, ): if req_to_page is not None: block_table = req_to_page[req_pool_indices] else: block_table = None - spec_num_tokens = num_tokens // bs if bs > 0 else 1 - is_target_verify = ( - forward_mode.is_decode_or_idle() - and not self.is_draft - and spec_num_tokens > 1 + # When spec-dec is active (self.spec_num_tokens > 1), advance per-row + # seq_lens by the worst-case verify width so the tile planner covers + # the longest path. + if self.spec_num_tokens > 1: + plan_seq_lens = seq_lens + self.draft_token_num + num_heads_plan = self.draft_token_num * self.num_q_heads + else: + plan_seq_lens = seq_lens + num_heads_plan = self.num_q_heads + + mla_metadata, num_splits = get_mla_metadata( + plan_seq_lens.to(torch.int32), + num_heads_plan, + 1, ) - is_draft_extend = ( - forward_mode.is_decode_or_idle() and self.is_draft and spec_num_tokens > 1 + self.forward_decode_metadata = FlashMLADecodeMetadata( + num_extends=num_extends, + flashmla_metadata=mla_metadata, + num_splits=num_splits, + block_table=block_table, ) - if forward_mode.is_decode_or_idle() and spec_num_tokens == 1: - mla_metadata, num_splits = get_mla_metadata( - seq_lens.to(torch.int32), - self.num_q_heads, - 1, - ) - self.forward_metadata = FlashMLADecodeMetadata( - mla_metadata, - num_splits, - block_table, - ) - elif is_target_verify or is_draft_extend: - seq_lens = seq_lens + self.draft_token_num - mla_metadata, num_splits = get_mla_metadata( - seq_lens.to(torch.int32), - self.draft_token_num * self.num_q_heads, - 1, - ) - self.forward_metadata = FlashMLADecodeMetadata( - mla_metadata, - num_splits, - block_table, + def _init_prefill_metadata( + self, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + req_to_page: torch.Tensor, + extend_with_prefix: bool, + extend_prefix_lens: torch.Tensor | None, + extend_prefix_lens_cpu: torch.Tensor, + extend_seq_lens: torch.Tensor, + extend_seq_lens_cpu: torch.Tensor, + ): + # EXTEND path — flashinfer ragged/paged prefill. + if extend_prefix_lens is None: + raise RuntimeError( + "FlashMLABackend.init_forward_metadata requires " + "extend_prefix_lens in extend mode." ) - else: - # EXTEND path — flashinfer ragged/paged prefill. - if extend_prefix_lens is None: - raise RuntimeError( - "FlashMLABackend.init_forward_metadata requires " - "extend_prefix_lens in extend mode." - ) - seq_lens_cpu = seq_lens.cpu() - seq_lens_sum = seq_lens_cpu.sum().item() - self.last_seq_lens_sum = seq_lens_sum + seq_lens_cpu = seq_lens.cpu() + seq_lens_sum = seq_lens_cpu.sum().item() + self.last_seq_lens_sum = seq_lens_sum - extend_no_prefix = not extend_with_prefix - use_ragged = ( - not global_server_args_dict["mla_disable_ragged"] and extend_no_prefix - ) + extend_no_prefix = not extend_with_prefix + use_ragged = ( + not global_server_args_dict["mla_disable_ragged"] and extend_no_prefix + ) - self.indices_updater_prefill.update( - req_pool_indices, - seq_lens, - seq_lens_sum, - extend_prefix_lens, - req_to_page=req_to_page, - prefill_wrapper_paged=self.prefill_wrapper_paged, - use_ragged=use_ragged, - ) - self.forward_metadata = _PrefillMetadata( - self.prefill_wrapper_paged, use_ragged - ) + self.indices_updater_prefill.update( + req_pool_indices, + seq_lens, + seq_lens_sum, + extend_prefix_lens, + req_to_page=req_to_page, + prefill_wrapper_paged=self.prefill_wrapper_paged, + use_ragged=use_ragged, + ) + self.forward_prefill_metadata = _PrefillMetadata( + self.prefill_wrapper_paged, use_ragged + ) - extend_seq_lens = kwargs.pop("extend_seq_lens") - extend_seq_lens_cpu = kwargs.pop("extend_seq_lens_cpu") - extend_prefix_lens_cpu = kwargs.pop("extend_prefix_lens_cpu") - num_extends = extend_seq_lens.shape[0] - cum_extend_seq_lens = torch.zeros( - num_extends + 1, device=self.device, dtype=torch.int32 - ) - torch.cumsum(extend_seq_lens, dim=0, out=cum_extend_seq_lens[1:]) - max_extend_seq_len = extend_seq_lens_cpu.max().item() - ( - chunked_loop_num, - chunk_kv_indices_list, - chunked_seq_len, - cu_chunked_seq_len, - max_chunk_len_per_loop, - ) = build_chunked_prefill_metadata_arrays( - extend_prefix_lens, - extend_prefix_lens_cpu, - req_to_page, - req_pool_indices, - PAGE_SIZE, - ) - self.chunked_prefill_metadata = _ChunkedPrefillMetadata( - extend_prefix_lens=extend_prefix_lens, - extend_prefix_lens_cpu=extend_prefix_lens_cpu, - extend_seq_lens=extend_seq_lens, - extend_seq_lens_cpu=extend_seq_lens_cpu, - req_pool_indices=req_pool_indices, - cum_extend_seq_lens=cum_extend_seq_lens, - max_extend_seq_len=max_extend_seq_len, - chunked_loop_num=chunked_loop_num, - chunk_kv_indices_list=chunk_kv_indices_list, - chunked_seq_len=chunked_seq_len, - cu_chunked_seq_len=cu_chunked_seq_len, - max_chunk_len_per_loop=max_chunk_len_per_loop, - ) + num_extends = extend_seq_lens.shape[0] + cum_extend_seq_lens = torch.zeros( + num_extends + 1, device=self.device, dtype=torch.int32 + ) + torch.cumsum(extend_seq_lens, dim=0, out=cum_extend_seq_lens[1:]) + max_extend_seq_len = extend_seq_lens_cpu.max().item() + ( + chunked_loop_num, + chunk_kv_indices_list, + chunked_seq_len, + cu_chunked_seq_len, + max_chunk_len_per_loop, + ) = build_chunked_prefill_metadata_arrays( + extend_prefix_lens, + extend_prefix_lens_cpu, + req_to_page, + req_pool_indices, + PAGE_SIZE, + ) + self.chunked_prefill_metadata = _ChunkedPrefillMetadata( + extend_prefix_lens=extend_prefix_lens, + extend_prefix_lens_cpu=extend_prefix_lens_cpu, + extend_seq_lens=extend_seq_lens, + extend_seq_lens_cpu=extend_seq_lens_cpu, + req_pool_indices=req_pool_indices, + cum_extend_seq_lens=cum_extend_seq_lens, + max_extend_seq_len=max_extend_seq_len, + chunked_loop_num=chunked_loop_num, + chunk_kv_indices_list=chunk_kv_indices_list, + chunked_seq_len=chunked_seq_len, + cu_chunked_seq_len=cu_chunked_seq_len, + max_chunk_len_per_loop=max_chunk_len_per_loop, + ) # ------------------------------------------------------------------ # CUDA graph (decode only, any q_len) @@ -338,23 +371,23 @@ def init_cuda_graph_state(self, max_bs: int, seq_lens_buf: torch.Tensor): def init_forward_metadata_capture_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode, ): block_table = self.cuda_graph_kv_indices[:bs] - spec_num_tokens = num_tokens // bs if bs > 0 else 1 is_target_verify = ( forward_mode.is_decode_or_idle() and not self.is_draft - and spec_num_tokens > 1 + and self.spec_num_tokens > 1 ) is_draft_extend = ( - forward_mode.is_decode_or_idle() and self.is_draft and spec_num_tokens > 1 + forward_mode.is_decode_or_idle() + and self.is_draft + and self.spec_num_tokens > 1 ) - if forward_mode.is_decode_or_idle() and spec_num_tokens == 1: + if forward_mode.is_decode_or_idle() and self.spec_num_tokens == 1: mla_metadata, num_splits = get_mla_metadata( seq_lens.to(torch.int32), self.num_q_heads, @@ -363,10 +396,11 @@ def init_forward_metadata_capture_cuda_graph( self.cuda_graph_mla_metadata.copy_(mla_metadata) self.cuda_graph_num_splits[: bs + 1].copy_(num_splits) self.cuda_graph_kv_indices[:bs].copy_(block_table) - self.forward_metadata = FlashMLADecodeMetadata( - self.cuda_graph_mla_metadata, - self.cuda_graph_num_splits[: bs + 1], - self.cuda_graph_kv_indices[:bs, :], + self.forward_decode_metadata = FlashMLADecodeMetadata( + num_extends=0, + flashmla_metadata=self.cuda_graph_mla_metadata, + num_splits=self.cuda_graph_num_splits[: bs + 1], + block_table=self.cuda_graph_kv_indices[:bs, :], ) elif is_target_verify or is_draft_extend: seq_lens = seq_lens + self.draft_token_num @@ -378,10 +412,11 @@ def init_forward_metadata_capture_cuda_graph( self.cuda_graph_mla_metadata.copy_(mla_metadata) self.cuda_graph_num_splits[: bs + 1].copy_(num_splits) self.cuda_graph_kv_indices[:bs].copy_(block_table) - self.forward_metadata = FlashMLADecodeMetadata( - self.cuda_graph_mla_metadata, - self.cuda_graph_num_splits[: bs + 1], - self.cuda_graph_kv_indices[:bs], + self.forward_decode_metadata = FlashMLADecodeMetadata( + num_extends=0, + flashmla_metadata=self.cuda_graph_mla_metadata, + num_splits=self.cuda_graph_num_splits[: bs + 1], + block_table=self.cuda_graph_kv_indices[:bs], ) else: raise RuntimeError(f"Not supported forward mode: {forward_mode}") @@ -389,7 +424,6 @@ def init_forward_metadata_capture_cuda_graph( def init_forward_metadata_replay_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode = None, @@ -406,11 +440,10 @@ def init_forward_metadata_replay_cuda_graph( block_table = self.cuda_graph_kv_indices[:bs] seq_lens = seq_lens[:bs] - spec_num_tokens = num_tokens // bs if bs > 0 else 1 - is_target_verify = not self.is_draft and spec_num_tokens > 1 - is_draft_extend = self.is_draft and spec_num_tokens > 1 + is_target_verify = not self.is_draft and self.spec_num_tokens > 1 + is_draft_extend = self.is_draft and self.spec_num_tokens > 1 - if spec_num_tokens == 1: + if self.spec_num_tokens == 1: mla_metadata, num_splits = get_mla_metadata( seq_lens.to(torch.int32), self.num_q_heads, @@ -429,9 +462,10 @@ def init_forward_metadata_replay_cuda_graph( self.cuda_graph_mla_metadata.copy_(mla_metadata) self.cuda_graph_num_splits[: bs + 1].copy_(num_splits) self.cuda_graph_kv_indices[:bs].copy_(block_table) - self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata - self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1] - self.forward_metadata.block_table = self.cuda_graph_kv_indices[:bs] + self.forward_decode_metadata.num_extends = 0 + self.forward_decode_metadata.flashmla_metadata = self.cuda_graph_mla_metadata + self.forward_decode_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1] + self.forward_decode_metadata.block_table = self.cuda_graph_kv_indices[:bs] def get_cuda_graph_seq_len_fill_value(self): return 1 @@ -454,23 +488,23 @@ def forward_extend( forward_mode: ForwardMode | None = None, **kwargs, ): - spec_num_tokens = q.shape[0] // bs if bs > 0 else 1 + q_len_per_req = q.shape[0] // bs if bs > 0 else 1 is_target_verify = ( forward_mode is not None and forward_mode.is_decode_or_idle() and not self.is_draft - and spec_num_tokens > 1 + and q_len_per_req > 1 ) is_draft_extend = ( forward_mode is not None and forward_mode.is_decode_or_idle() and self.is_draft - and spec_num_tokens > 1 + and q_len_per_req > 1 ) if forward_mode is None or forward_mode.is_extend(): # Prefill: dispatch to ragged (MHA-style) or absorbed (MQA) path. - if self.forward_metadata.use_ragged: + if self.forward_prefill_metadata.use_ragged: return self._forward_normal_extend(q, k, v, layer, save_kv_cache) else: return self._forward_absorbed_extend( @@ -489,10 +523,12 @@ def forward_extend( if save_kv_cache: token_to_kv_pool.set_kv_buffer(layer, out_cache_loc, k, v) + metadata = self.forward_decode_metadata + num_extends = metadata.num_extends bs = ( q.shape[0] if is_draft_extend - else self.forward_metadata.block_table.shape[0] + else metadata.block_table.shape[0] - num_extends ) k_cache = token_to_kv_pool.get_key_buffer(layer.layer_id) @@ -504,11 +540,11 @@ def forward_extend( o, _ = flash_mla_with_kvcache( q=reshape_q, k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim), - block_table=self.forward_metadata.block_table[:bs], + block_table=metadata.block_table[num_extends : num_extends + bs], cache_seqlens=seq_lens.to(torch.int32) + self.draft_token_num, head_dim_v=self.kv_lora_rank, - tile_scheduler_metadata=self.forward_metadata.flashmla_metadata, - num_splits=self.forward_metadata.num_splits, + tile_scheduler_metadata=metadata.flashmla_metadata, + num_splits=metadata.num_splits, softmax_scale=layer.scaling, causal=True, ) @@ -529,12 +565,15 @@ def forward_extend_chunked( seq_lens, batch_size, causal, + out: torch.Tensor | None = None, ): if causal: step_counter = getattr(self, "step_counter", None) if step_counter is not None: step_counter.record_cache() head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + # flash_attn_varlen_func has no `out=` parameter; copy into the + # caller-provided buffer at the end when requested. output, lse, *_ = flash_attn_varlen_func( q=q.view(-1, self.num_local_heads, head_dim), k=k.view(-1, self.num_local_heads, head_dim).to(q.dtype), @@ -547,6 +586,9 @@ def forward_extend_chunked( causal=causal, return_attn_probs=True, ) + if out is not None: + out.copy_(output.view(out.shape)) + output = out # lse must be transposed when using fa3. return output, lse.T.contiguous() @@ -565,8 +607,8 @@ def forward_decode( ) -> torch.Tensor: # Multi-token decode (target verify or drafter compound) reuses # the multi-token kernel path in forward_extend. - spec_num_tokens = q.shape[0] // bs if bs > 0 else 1 - if spec_num_tokens > 1: + q_len_per_req = q.shape[0] // bs if bs > 0 else 1 + if q_len_per_req > 1: return self.forward_extend( q, k, @@ -591,6 +633,8 @@ def forward_decode( v, ) bs = q.shape[0] + metadata = self.forward_decode_metadata + num_extends = metadata.num_extends k_cache = token_to_kv_pool.get_key_buffer(layer.layer_id) assert ( layer.tp_q_head_num == self.num_q_heads @@ -601,11 +645,11 @@ def forward_decode( o, _ = flash_mla_with_kvcache( q=reshape_q, k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim), - block_table=self.forward_metadata.block_table[:bs], + block_table=metadata.block_table[num_extends : num_extends + bs], cache_seqlens=cache_lens.to(torch.int32), head_dim_v=self.kv_lora_rank, - tile_scheduler_metadata=self.forward_metadata.flashmla_metadata, - num_splits=self.forward_metadata.num_splits, + tile_scheduler_metadata=metadata.flashmla_metadata, + num_splits=metadata.num_splits, softmax_scale=layer.scaling, causal=True, ) @@ -665,7 +709,7 @@ def _forward_absorbed_extend( o = q_nope.new_empty(q_nope.shape) k_buf = token_to_kv_pool.get_key_buffer(layer.layer_id).to(q_nope.dtype) - o = self.forward_metadata.prefill_wrapper.run( + o = self.forward_prefill_metadata.prefill_wrapper.run( q_nope, q_pe, k_buf[:, :, : layer.v_head_dim], diff --git a/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py b/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py index a8b5133ac..b603eef1d 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py +++ b/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py @@ -328,7 +328,6 @@ def reset_current_inputs( def init_forward_metadata( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode = ForwardMode.DECODE, @@ -340,14 +339,15 @@ def init_forward_metadata( else: mamba_cache_indices = self.pool.get_mamba_indices(req_pool_indices[:bs]) - spec_num_tokens = num_tokens // bs if bs > 0 else 1 is_target_verify = ( forward_mode.is_decode_or_idle() and not self.is_draft - and spec_num_tokens > 1 + and self.spec_num_tokens > 1 ) is_draft_extend = ( - forward_mode.is_decode_or_idle() and self.is_draft and spec_num_tokens > 1 + forward_mode.is_decode_or_idle() + and self.is_draft + and self.spec_num_tokens > 1 ) mamba_output_indices = None @@ -366,7 +366,7 @@ def init_forward_metadata( ) mamba_cache_indices = mamba_input_indices - if forward_mode.is_decode_or_idle() and spec_num_tokens == 1: + if forward_mode.is_decode_or_idle() and self.spec_num_tokens == 1: query_start_loc = torch.arange( 0, bs + 1, dtype=torch.int32, device=self.device ) @@ -377,7 +377,7 @@ def init_forward_metadata( ) query_start_loc = torch.arange( 0, - num_tokens + 1, + bs * tokens_per_req + 1, step=tokens_per_req, dtype=torch.int32, device=self.device, @@ -578,23 +578,23 @@ def init_cuda_graph_state( def init_forward_metadata_capture_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode, **kwargs, ): - spec_num_tokens = num_tokens // bs if bs > 0 else 1 is_target_verify = ( forward_mode.is_decode_or_idle() and not self.is_draft - and spec_num_tokens > 1 + and self.spec_num_tokens > 1 ) is_draft_extend = ( - forward_mode.is_decode_or_idle() and self.is_draft and spec_num_tokens > 1 + forward_mode.is_decode_or_idle() + and self.is_draft + and self.spec_num_tokens > 1 ) - if forward_mode.is_decode_or_idle() and spec_num_tokens == 1: + if forward_mode.is_decode_or_idle() and self.spec_num_tokens == 1: self.query_start_loc_list[bs - 1].copy_( self.cached_cuda_graph_decode_query_start_loc[: bs + 1] ) @@ -632,7 +632,7 @@ def init_forward_metadata_capture_cuda_graph( ) padded_mamba_indices.copy_(mamba_input_indices) self._qsl_dirty[bs - 1] = False - self._qsl_last_mode[bs - 1] = (forward_mode, spec_num_tokens > 1) + self._qsl_last_mode[bs - 1] = (forward_mode, self.spec_num_tokens > 1) self.forward_metadata = MambaForwardMetadata( query_start_loc=self.query_start_loc_list[bs - 1], mamba_cache_indices=self.state_indices_list[bs - 1], @@ -643,7 +643,6 @@ def init_forward_metadata_capture_cuda_graph( def init_forward_metadata_replay_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode = None, @@ -672,18 +671,17 @@ def init_forward_metadata_replay_cuda_graph( if num_padding > 0: padded_mamba_indices[real_bs:].fill_(-1) - spec_num_tokens = num_tokens // bs if bs > 0 else 1 is_target_verify = ( forward_mode is not None and forward_mode.is_decode_or_idle() and not self.is_draft - and spec_num_tokens > 1 + and self.spec_num_tokens > 1 ) is_draft_extend = ( forward_mode is not None and forward_mode.is_decode_or_idle() and self.is_draft - and spec_num_tokens > 1 + and self.spec_num_tokens > 1 ) mamba_output_indices = None @@ -705,10 +703,10 @@ def init_forward_metadata_replay_cuda_graph( if num_padding == 0: need_copy = self._qsl_dirty[bs - 1] or self._qsl_last_mode[bs - 1] != ( forward_mode, - spec_num_tokens > 1, + self.spec_num_tokens > 1, ) if need_copy: - if forward_mode.is_decode_or_idle() and spec_num_tokens == 1: + if forward_mode.is_decode_or_idle() and self.spec_num_tokens == 1: self.query_start_loc_list[bs - 1].copy_( self.cached_cuda_graph_decode_query_start_loc[: bs + 1] ) @@ -717,9 +715,9 @@ def init_forward_metadata_replay_cuda_graph( self.cached_cuda_graph_verify_query_start_loc[: bs + 1] ) self._qsl_dirty[bs - 1] = False - self._qsl_last_mode[bs - 1] = (forward_mode, spec_num_tokens > 1) + self._qsl_last_mode[bs - 1] = (forward_mode, self.spec_num_tokens > 1) else: - if forward_mode.is_decode_or_idle() and spec_num_tokens == 1: + if forward_mode.is_decode_or_idle() and self.spec_num_tokens == 1: self.query_start_loc_list[bs - 1][:real_bs].copy_( self.cached_cuda_graph_decode_query_start_loc[:real_bs] ) @@ -734,7 +732,7 @@ def init_forward_metadata_replay_cuda_graph( else: raise ValueError(f"Invalid forward mode: {forward_mode=}") self._qsl_dirty[bs - 1] = True - self._qsl_last_mode[bs - 1] = (forward_mode, spec_num_tokens > 1) + self._qsl_last_mode[bs - 1] = (forward_mode, self.spec_num_tokens > 1) self.forward_metadata = MambaForwardMetadata( query_start_loc=self.query_start_loc_list[bs - 1], @@ -763,8 +761,8 @@ def forward_decode( # Multi-token decode (target verify or drafter compound) reuses # the multi-token kernel path in forward_extend. `q` is None for # hybrid linear-attn layers; the token count comes from mixed_qkv. - spec_num_tokens = kwargs["mixed_qkv"].shape[0] // bs if bs > 0 else 1 - if spec_num_tokens > 1: + q_len_per_req = kwargs["mixed_qkv"].shape[0] // bs if bs > 0 else 1 + if q_len_per_req > 1: return self.forward_extend( q, k, @@ -869,12 +867,12 @@ def forward_extend( # `q` is None for hybrid linear-attn layers; the token count comes # from seq_len carried in kwargs. - spec_num_tokens = seq_len // bs if bs > 0 else 1 + q_len_per_req = seq_len // bs if bs > 0 else 1 is_target_verify = ( forward_mode is not None and forward_mode.is_decode_or_idle() and not self.is_draft - and spec_num_tokens > 1 + and q_len_per_req > 1 ) query_start_loc = self.forward_metadata.query_start_loc diff --git a/python/tokenspeed/runtime/layers/attention/backends/mha.py b/python/tokenspeed/runtime/layers/attention/backends/mha.py index 6ff7bc3e3..236009771 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/mha.py +++ b/python/tokenspeed/runtime/layers/attention/backends/mha.py @@ -100,7 +100,6 @@ def __init__(self, config: MHAConfig): def init_forward_metadata( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode, @@ -160,25 +159,36 @@ def init_forward_metadata( max_prefix_seq_len=max_prefix_seq_len, has_prefix=has_prefix, ) + if not self.is_draft: + return + # Drafter: also fill decode_metadata so step 1+ multi-step has + # metadata under EXTEND/MIXED target. seq_lens is the drafter's + # live alias buffer (wrapper pre-writes it before this call). + self.forward_decode_metadata = MHAMetadata( + cache_seqlens_int32=seq_lens, + page_table=page_table, + max_seq_len_k=self.max_context_len, + ) return - spec_num_tokens = num_tokens // bs if bs > 0 else 1 - if spec_num_tokens > 1: + if self.spec_num_tokens > 1: self.forward_prefill_metadata = MHAMetadata( cache_seqlens_int32=seq_lens, cu_seqlens_q=self._make_uniform_cu_seqlens( bs, - spec_num_tokens, + self.spec_num_tokens, seq_lens.device, ), page_table=page_table, - max_seq_len_q=spec_num_tokens, + max_seq_len_q=self.spec_num_tokens, max_seq_len_k=self.max_context_len, ) if self.is_draft: # Drafter follow-up single-token steps after the first. + # cache_seqlens_int32 aliases seq_lens (drafter's live buffer) + # so multi-step in-place advances propagate to the kernel. self.forward_decode_metadata = MHAMetadata( - cache_seqlens_int32=seq_lens.clone(), + cache_seqlens_int32=seq_lens, page_table=page_table, max_seq_len_k=self.max_context_len, ) @@ -235,7 +245,6 @@ def init_cuda_graph_state(self, max_bs: int, seq_lens_buf: torch.Tensor): def init_forward_metadata_capture_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode, @@ -246,17 +255,16 @@ def init_forward_metadata_capture_cuda_graph( ) cache_seqlens = self.cuda_graph_cache_seqlens[:bs] - spec_num_tokens = num_tokens // bs if bs > 0 else 1 - if spec_num_tokens > 1: + if self.spec_num_tokens > 1: metadata = MHAMetadata( cache_seqlens_int32=cache_seqlens, cu_seqlens_q=self._make_uniform_cu_seqlens( bs, - spec_num_tokens, + self.spec_num_tokens, self.device, ), page_table=self.cuda_graph_page_table[:bs, :], - max_seq_len_q=spec_num_tokens, + max_seq_len_q=self.spec_num_tokens, max_seq_len_k=self.max_context_len, ) self.cuda_graph_prefill_metadata[bs] = metadata @@ -281,7 +289,6 @@ def init_forward_metadata_capture_cuda_graph( def init_forward_metadata_replay_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode, @@ -321,8 +328,8 @@ def forward_decode( # Multi-token decode (q_len > 1) reuses the prefill kernel via the # uniform-stride prefill slot; plain decode uses the single-token slot. - spec_num_tokens = q.shape[0] // bs if bs > 0 else 1 - if spec_num_tokens > 1: + q_len_per_req = q.shape[0] // bs if bs > 0 else 1 + if q_len_per_req > 1: return self.forward_extend( q, k, diff --git a/python/tokenspeed/runtime/layers/attention/backends/tokenspeed_mla.py b/python/tokenspeed/runtime/layers/attention/backends/tokenspeed_mla.py index 4df8d8808..aed3c4cd3 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/tokenspeed_mla.py +++ b/python/tokenspeed/runtime/layers/attention/backends/tokenspeed_mla.py @@ -29,6 +29,7 @@ from __future__ import annotations import logging +from contextlib import contextmanager from dataclasses import dataclass from typing import TYPE_CHECKING @@ -100,6 +101,7 @@ class CuteDSLMLAPrefillMetadata: @dataclass class CuteDSLMLADecodeMetadata: + num_extends: int = 0 block_kv_indices: torch.Tensor | None = None max_seq_len_k: int | None = None seq_lens_k: torch.Tensor | None = None @@ -213,7 +215,7 @@ def _create_block_kv_indices( def init_forward_metadata( self, bs: int, - num_tokens: int, + num_extends: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode, @@ -224,41 +226,60 @@ def init_forward_metadata( ): if forward_mode.is_extend_or_mixed(): self._init_prefill_metadata( - seq_lens, - req_pool_indices=req_pool_indices, + seq_lens[:num_extends], + req_pool_indices=req_pool_indices[:num_extends], req_to_page=req_to_page, extend_prefix_lens=kwargs.pop("extend_prefix_lens"), extend_prefix_lens_cpu=kwargs.pop("extend_prefix_lens_cpu"), extend_seq_lens=kwargs.pop("extend_seq_lens"), extend_seq_lens_cpu=kwargs.pop("extend_seq_lens_cpu"), ) - else: + # Drafter steps 1..N are pure DECODE on full bs regardless of target + # mode, so under is_draft we also fill decode_metadata under EXTEND + # so the multi-step loop has metadata. The wrapper pre-writes + # draft_seq_lens before calling here so `seq_lens` aliases the + # drafter's live buffer. + if ( + forward_mode.is_decode() + or forward_mode.is_mixed() + or (forward_mode.is_extend() and self.is_draft) + ): self._init_decode_metadata( - bs, req_pool_indices, seq_lens, forward_mode, req_to_page, spec_info + bs, + num_extends, + req_pool_indices, + seq_lens, + req_to_page, ) + @contextmanager + def override_num_extends(self, num_extends: int): + assert self.forward_decode_metadata is not None + prev = self.forward_decode_metadata.num_extends + self.forward_decode_metadata.num_extends = num_extends + try: + yield + finally: + self.forward_decode_metadata.num_extends = prev + def _init_decode_metadata( self, bs: int, + num_extends: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, - forward_mode: ForwardMode, req_to_page: torch.Tensor, - spec_info=None, ): max_blocks = self._calc_padded_blocks(self.max_context_len) - block_kv_indices = self._create_block_kv_indices( bs, max_blocks, req_pool_indices, seq_lens, req_to_page ) - assert ( - seq_lens.dtype == torch.int32 - ), f"seq_lens must be int32, got {seq_lens.dtype}" self.forward_decode_metadata = CuteDSLMLADecodeMetadata( block_kv_indices=block_kv_indices, max_seq_len_k=self.max_context_len, seq_lens_k=seq_lens, + num_extends=num_extends, ) def _init_prefill_metadata( @@ -344,7 +365,6 @@ def init_cuda_graph_state(self, max_bs: int, seq_lens_buf: torch.Tensor): def init_forward_metadata_capture_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode, @@ -354,22 +374,21 @@ def init_forward_metadata_capture_cuda_graph( f"tokenspeed_mla CUDA graph capture not supported for {forward_mode}" ) - metadata = CuteDSLMLADecodeMetadata() max_blocks = self._calc_padded_blocks(self.max_context_len) block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_blocks] - metadata.block_kv_indices = block_kv_indices - metadata.max_seq_len_k = self.max_context_len - # seq_lens_k aliases seq_lens_buf (set in init_cuda_graph_state). - metadata.seq_lens_k = self.cuda_graph_seq_lens_buf[:bs] - + metadata = CuteDSLMLADecodeMetadata( + block_kv_indices=block_kv_indices, + max_seq_len_k=self.max_context_len, + seq_lens_k=self.cuda_graph_seq_lens_buf[:bs], + num_extends=0, + ) self.decode_cuda_graph_metadata[bs] = metadata self.forward_decode_metadata = metadata def init_forward_metadata_replay_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode = None, @@ -423,8 +442,10 @@ def forward_decode( k[..., self.kv_lora_rank :], ) - spec_num_tokens = q.shape[0] // bs if bs > 0 else 1 - query = q.view(bs, spec_num_tokens, layer.tp_q_head_num, layer.head_dim) + metadata = self.forward_decode_metadata + num_extends = metadata.num_extends + q_len_per_req = q.shape[0] // bs + query = q.view(bs, q_len_per_req, layer.tp_q_head_num, layer.head_dim) softmax_scale = layer.scaling if self.data_type == torch.float8_e4m3fn: @@ -442,8 +463,6 @@ def forward_decode( k_cache = k_cache.to(self.data_type) kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim) - metadata = self.forward_decode_metadata - if not CuteDSLMLABackend._logged_decode: logger.info( "CuteDSL MLA decode kernel invoked (tokenspeed_mla_decode, query_dtype=%s, kv_dtype=%s)", @@ -462,8 +481,8 @@ def forward_decode( workspace_buffer=self.cutedsl_workspace, kv_lora_rank=self.kv_lora_rank, qk_rope_head_dim=self.qk_rope_head_dim, - block_tables=metadata.block_kv_indices, - seq_lens=metadata.seq_lens_k, + block_tables=metadata.block_kv_indices[num_extends:], + seq_lens=metadata.seq_lens_k[num_extends:], max_seq_len=metadata.max_seq_len_k, softmax_scale=softmax_scale, enable_pdl=pdl_enabled(), @@ -488,6 +507,7 @@ def forward_extend_chunked( seq_lens, batch_size, causal, + out: torch.Tensor | None = None, ): if causal: step_counter = getattr(self, "step_counter", None) @@ -523,6 +543,7 @@ def forward_extend_chunked( cum_seq_lens_q=cum_seq_lens_q, max_seq_len_q=max_q_len, enable_pdl=pdl_enabled(), + out=out, ) if isinstance(result, tuple): diff --git a/python/tokenspeed/runtime/layers/attention/backends/triton.py b/python/tokenspeed/runtime/layers/attention/backends/triton.py index 3a3f1d179..d63c8afef 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/triton.py +++ b/python/tokenspeed/runtime/layers/attention/backends/triton.py @@ -178,7 +178,6 @@ def get_num_kv_splits( def init_forward_metadata( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode = None, @@ -200,17 +199,18 @@ def init_forward_metadata( window_kv_indices = None window_num_kv_splits = None - spec_num_tokens = num_tokens // bs if bs > 0 else 1 is_target_verify = ( forward_mode.is_decode_or_idle() and not self.is_draft - and spec_num_tokens > 1 + and self.spec_num_tokens > 1 ) is_draft_extend = ( - forward_mode.is_decode_or_idle() and self.is_draft and spec_num_tokens > 1 + forward_mode.is_decode_or_idle() + and self.is_draft + and self.spec_num_tokens > 1 ) - if forward_mode.is_decode_or_idle() and spec_num_tokens == 1: + if forward_mode.is_decode_or_idle() and self.spec_num_tokens == 1: if spec_info is None: torch.cumsum(seq_lens, dim=0, out=kv_indptr[1 : bs + 1]) kv_indptr = kv_indptr[: bs + 1] @@ -451,7 +451,6 @@ def init_cuda_graph_state( def init_forward_metadata_capture_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode, @@ -462,17 +461,18 @@ def init_forward_metadata_capture_cuda_graph( window_kv_indices = None window_num_kv_splits = None - spec_num_tokens = num_tokens // bs if bs > 0 else 1 is_target_verify = ( forward_mode.is_decode_or_idle() and not self.is_draft - and spec_num_tokens > 1 + and self.spec_num_tokens > 1 ) is_draft_extend = ( - forward_mode.is_decode_or_idle() and self.is_draft and spec_num_tokens > 1 + forward_mode.is_decode_or_idle() + and self.is_draft + and self.spec_num_tokens > 1 ) - if forward_mode.is_decode_or_idle() and spec_num_tokens == 1: + if forward_mode.is_decode_or_idle() and self.spec_num_tokens == 1: if spec_info is None: kv_indptr = self.kv_indptr torch.cumsum(seq_lens, dim=0, out=kv_indptr[1 : bs + 1]) @@ -613,7 +613,6 @@ def init_forward_metadata_capture_cuda_graph( def init_forward_metadata_replay_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode = None, @@ -623,17 +622,18 @@ def init_forward_metadata_replay_cuda_graph( ): _req_to_token = self.req_to_page - spec_num_tokens = num_tokens // bs if bs > 0 else 1 is_target_verify = ( forward_mode.is_decode_or_idle() and not self.is_draft - and spec_num_tokens > 1 + and self.spec_num_tokens > 1 ) is_draft_extend = ( - forward_mode.is_decode_or_idle() and self.is_draft and spec_num_tokens > 1 + forward_mode.is_decode_or_idle() + and self.is_draft + and self.spec_num_tokens > 1 ) - if forward_mode.is_decode_or_idle() and spec_num_tokens == 1: + if forward_mode.is_decode_or_idle() and self.spec_num_tokens == 1: # Update kv_indptr, kv_indices kv_indptr = self.kv_indptr kv_indices = self.cuda_graph_kv_indices @@ -811,8 +811,8 @@ def forward_decode( ): # Multi-token decode (target verify or drafter compound) reuses # the multi-token kernel path in forward_extend. - spec_num_tokens = q.shape[0] // bs if bs > 0 else 1 - if spec_num_tokens > 1: + q_len_per_req = q.shape[0] // bs if bs > 0 else 1 + if q_len_per_req > 1: return self.forward_extend( q, k, diff --git a/python/tokenspeed/runtime/layers/attention/backends/trtllm.py b/python/tokenspeed/runtime/layers/attention/backends/trtllm.py index 523c12a6a..ef4f82f94 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/trtllm.py +++ b/python/tokenspeed/runtime/layers/attention/backends/trtllm.py @@ -265,10 +265,10 @@ def forward_decode( # Multi-token decode (q_len > 1) reads the prefill slot's # uniform-stride metadata; plain decode reads the single-token slot. - spec_num_tokens = q.shape[0] // bs if bs > 0 else 1 + q_len_per_req = q.shape[0] // bs if bs > 0 else 1 metadata = ( self.forward_prefill_metadata - if spec_num_tokens > 1 + if q_len_per_req > 1 else self.forward_decode_metadata ) @@ -337,7 +337,6 @@ def forward_extend( def init_forward_metadata( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode, @@ -361,12 +360,16 @@ def init_forward_metadata( extend_prefix_lens_cpu=extend_prefix_lens_cpu, extend_seq_lens_cpu=extend_seq_lens_cpu, ) + # Drafter: also fill decode_metadata so step 1+ multi-step has + # metadata under EXTEND/MIXED target. seq_lens is the drafter's + # live alias buffer (wrapper pre-writes before this call). + if self.is_draft: + self._init_decode_metadata(bs, req_pool_indices, seq_lens, req_to_page) return - spec_num_tokens = num_tokens // bs if bs > 0 else 1 - if spec_num_tokens > 1: + if self.spec_num_tokens > 1: self._init_multi_token_metadata( - bs, spec_num_tokens, req_pool_indices, seq_lens, req_to_page + bs, self.spec_num_tokens, req_pool_indices, seq_lens, req_to_page ) if self.is_draft: # Drafter's N-1 single-token steps after the first. @@ -511,7 +514,6 @@ def init_cuda_graph_state(self, max_bs: int, seq_lens_buf: torch.Tensor): def init_forward_metadata_capture_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode, @@ -521,9 +523,8 @@ def init_forward_metadata_capture_cuda_graph( f"trtllm CUDA graph capture not supported for {forward_mode}" ) - spec_num_tokens = num_tokens // bs if bs > 0 else 1 - if spec_num_tokens > 1: - self._init_multi_token_metadata_capture(bs, spec_num_tokens, seq_lens) + if self.spec_num_tokens > 1: + self._init_multi_token_metadata_capture(bs, self.spec_num_tokens, seq_lens) if self.is_draft: self._init_decode_metadata_capture(bs, seq_lens) else: @@ -564,7 +565,6 @@ def _init_multi_token_metadata_capture( def init_forward_metadata_replay_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode, diff --git a/python/tokenspeed/runtime/layers/attention/backends/trtllm_mla.py b/python/tokenspeed/runtime/layers/attention/backends/trtllm_mla.py index 5e13937b3..405e6f364 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/trtllm_mla.py +++ b/python/tokenspeed/runtime/layers/attention/backends/trtllm_mla.py @@ -27,6 +27,7 @@ from __future__ import annotations import logging +from contextlib import contextmanager from dataclasses import dataclass from typing import TYPE_CHECKING @@ -99,6 +100,7 @@ class TRTLLMMLAChunkedPrefillMetadata: @dataclass class TRTLLMMLADecodeMetadata: + num_extends: int = 0 block_kv_indices: torch.Tensor | None = None max_seq_len_k: int | None = None seq_lens_k: torch.Tensor | None = None @@ -180,7 +182,7 @@ def _create_block_kv_indices( def init_forward_metadata( self, bs: int, - num_tokens: int, + num_extends: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode, @@ -190,27 +192,44 @@ def init_forward_metadata( ): if forward_mode.is_extend_or_mixed(): self._init_prefill_metadata( - seq_lens, - req_pool_indices=req_pool_indices, + seq_lens[:num_extends], + req_pool_indices=req_pool_indices[:num_extends], req_to_page=req_to_page, extend_prefix_lens=kwargs.pop("extend_prefix_lens"), extend_prefix_lens_cpu=kwargs.pop("extend_prefix_lens_cpu"), extend_seq_lens=kwargs.pop("extend_seq_lens"), extend_seq_lens_cpu=kwargs.pop("extend_seq_lens_cpu"), ) - else: + # Under is_draft, also fill decode_metadata under any forward_mode so + # the drafter's multi-step loop has metadata. Wrapper pre-writes + # draft_seq_lens before calling here, so `seq_lens` aliases the + # drafter's live buffer for step-1+ advances. + if ( + forward_mode.is_decode() + or forward_mode.is_mixed() + or (forward_mode.is_extend() and self.is_draft) + ): self._init_decode_metadata( - bs, req_pool_indices, seq_lens, forward_mode, req_to_page, spec_info + bs, num_extends, req_pool_indices, seq_lens, req_to_page ) + @contextmanager + def override_num_extends(self, num_extends: int): + assert self.forward_decode_metadata is not None + prev = self.forward_decode_metadata.num_extends + self.forward_decode_metadata.num_extends = num_extends + try: + yield + finally: + self.forward_decode_metadata.num_extends = prev + def _init_decode_metadata( self, bs: int, + num_extends: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, - forward_mode: ForwardMode, req_to_page: torch.Tensor, - spec_info=None, ): # For target_verify, the draft tokens have already been written to the KV # cache. The seq_lens passed in should already reflect the full context. @@ -225,9 +244,10 @@ def _init_decode_metadata( seq_lens.dtype == torch.int32 ), f"seq_lens must be int32, got {seq_lens.dtype}" self.forward_decode_metadata = TRTLLMMLADecodeMetadata( + num_extends=num_extends, block_kv_indices=block_kv_indices, max_seq_len_k=self.max_context_len, - seq_lens_k=seq_lens[:bs], + seq_lens_k=seq_lens, ) def _init_prefill_metadata( @@ -309,7 +329,6 @@ def init_cuda_graph_state(self, max_bs: int, seq_lens_buf: torch.Tensor): def init_forward_metadata_capture_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode, @@ -319,16 +338,18 @@ def init_forward_metadata_capture_cuda_graph( f"trtllm_mla CUDA graph capture not supported for {forward_mode}" ) - metadata = TRTLLMMLADecodeMetadata() max_blocks = self._calc_padded_blocks(self.max_context_len) block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_blocks] # For capture we don't have req_to_page yet; just zero-fill the block indices. - # The actual indices will be filled on replay. - metadata.block_kv_indices = block_kv_indices - metadata.max_seq_len_k = self.max_context_len - # seq_lens_k aliases seq_lens_buf (set in init_cuda_graph_state). - metadata.seq_lens_k = self.cuda_graph_seq_lens_buf[:bs] + # The actual indices will be filled on replay. seq_lens_k aliases + # seq_lens_buf (set in init_cuda_graph_state). + metadata = TRTLLMMLADecodeMetadata( + num_extends=0, + block_kv_indices=block_kv_indices, + max_seq_len_k=self.max_context_len, + seq_lens_k=self.cuda_graph_seq_lens_buf[:bs], + ) self.decode_cuda_graph_metadata[bs] = metadata self.forward_decode_metadata = metadata @@ -336,7 +357,6 @@ def init_forward_metadata_capture_cuda_graph( def init_forward_metadata_replay_cuda_graph( self, bs: int, - num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, forward_mode: ForwardMode = None, @@ -391,26 +411,29 @@ def forward_decode( ) metadata = self.forward_decode_metadata - spec_num_tokens = q.shape[0] // bs if bs > 0 else 1 + num_extends = metadata.num_extends + q_len_per_req = q.shape[0] // bs if bs > 0 else 1 - if spec_num_tokens > 1 and self.is_draft: + if q_len_per_req > 1 and self.is_draft: # First draft step catching up its KV after verify: one query entry per token; # per-token seq_lens advance by 1 so each successive token sees its own KV write. query = q.view(-1, layer.tp_q_head_num, layer.head_dim).unsqueeze(1) - block_tables = metadata.block_kv_indices.repeat_interleave( - spec_num_tokens, dim=0 + block_tables = metadata.block_kv_indices[num_extends:].repeat_interleave( + q_len_per_req, dim=0 + ) + base_lens = metadata.seq_lens_k[num_extends:].repeat_interleave( + q_len_per_req ) - base_lens = metadata.seq_lens_k.repeat_interleave(spec_num_tokens) offsets = torch.arange( - spec_num_tokens, device=base_lens.device, dtype=base_lens.dtype + q_len_per_req, device=base_lens.device, dtype=base_lens.dtype ).repeat(bs) seq_lens = base_lens + offsets - max_seq_len = metadata.max_seq_len_k + spec_num_tokens + max_seq_len = metadata.max_seq_len_k + q_len_per_req else: # Plain decode (q_len=1) or bs-grouped multi-token decode. query = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim) - block_tables = metadata.block_kv_indices - seq_lens = metadata.seq_lens_k + block_tables = metadata.block_kv_indices[num_extends:] + seq_lens = metadata.seq_lens_k[num_extends:] max_seq_len = metadata.max_seq_len_k if self.data_type == torch.float8_e4m3fn: @@ -459,6 +482,7 @@ def forward_extend_chunked( seq_lens, batch_size, causal, + out: torch.Tensor | None = None, ): if causal: step_counter = getattr(self, "step_counter", None) @@ -476,18 +500,19 @@ def forward_extend_chunked( k = k.to(torch.float8_e4m3fn) v = v.to(torch.float8_e4m3fn) - # The ragged path does not support FP8 output. - out_dtype = self.q_data_type - if out_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - out_dtype = torch.bfloat16 - - out = torch.empty( - q.shape[0], - q.shape[1], - v.shape[2], - device=q.device, - dtype=out_dtype, - ) + if out is None: + # The ragged path does not support FP8 output. + out_dtype = self.q_data_type + if out_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + out_dtype = torch.bfloat16 + + out = torch.empty( + q.shape[0], + q.shape[1], + v.shape[2], + device=q.device, + dtype=out_dtype, + ) result = trtllm_ragged_attention_deepseek( query=q, diff --git a/python/tokenspeed/runtime/layers/attention/configs/base.py b/python/tokenspeed/runtime/layers/attention/configs/base.py index fac41cec8..d6556045a 100644 --- a/python/tokenspeed/runtime/layers/attention/configs/base.py +++ b/python/tokenspeed/runtime/layers/attention/configs/base.py @@ -56,7 +56,7 @@ class BaseAttnConfig: max_graph_bs: int kv_cache_quant_method: str speculative_num_steps: int = 0 - speculative_num_draft_tokens: int = 0 + speculative_num_draft_tokens: int = 1 is_draft: bool = False @classmethod diff --git a/python/tokenspeed/runtime/layers/attention/configs/mha.py b/python/tokenspeed/runtime/layers/attention/configs/mha.py index 414fe387b..22fcb7f42 100644 --- a/python/tokenspeed/runtime/layers/attention/configs/mha.py +++ b/python/tokenspeed/runtime/layers/attention/configs/mha.py @@ -39,6 +39,12 @@ class MHAConfig(BaseAttnConfig): def generate( cls, server_args: ServerArgs, model_config: ModelConfig, is_draft: bool = False ): + kwargs = {} + if server_args.speculative_algorithm is not None: + kwargs.update( + speculative_num_steps=server_args.speculative_num_steps, + speculative_num_draft_tokens=server_args.speculative_num_draft_tokens, + ) return cls( device=server_args.device, context_len=model_config.context_len, @@ -58,9 +64,8 @@ def generate( // (server_args.data_parallel_size or server_args.mapping.attn.dp_size), max_graph_bs=server_args.max_cudagraph_capture_size, kv_cache_quant_method=server_args.kv_cache_quant_method, - speculative_num_steps=server_args.speculative_num_steps, - speculative_num_draft_tokens=server_args.speculative_num_draft_tokens, is_draft=is_draft, + **kwargs, ) def cache_cell_size(self) -> int: diff --git a/python/tokenspeed/runtime/layers/attention/configs/mla.py b/python/tokenspeed/runtime/layers/attention/configs/mla.py index a71e813ce..28ae5929f 100644 --- a/python/tokenspeed/runtime/layers/attention/configs/mla.py +++ b/python/tokenspeed/runtime/layers/attention/configs/mla.py @@ -46,6 +46,12 @@ class MLAConfig(BaseAttnConfig): def generate( cls, server_args: ServerArgs, model_config: ModelConfig, is_draft: bool = False ): + kwargs = {} + if server_args.speculative_algorithm is not None: + kwargs.update( + speculative_num_steps=server_args.speculative_num_steps, + speculative_num_draft_tokens=server_args.speculative_num_draft_tokens, + ) return cls( device=server_args.device, context_len=model_config.context_len, @@ -65,8 +71,6 @@ def generate( max_bs=server_args.max_num_seqs // (server_args.data_parallel_size or server_args.mapping.attn.dp_size), kv_cache_quant_method=server_args.kv_cache_quant_method, - speculative_num_steps=server_args.speculative_num_steps, - speculative_num_draft_tokens=server_args.speculative_num_draft_tokens, is_draft=is_draft, kv_lora_rank=model_config.kv_lora_rank, qk_nope_head_dim=model_config.qk_nope_head_dim, @@ -74,6 +78,7 @@ def generate( v_head_dim=model_config.v_head_dim, scaling=model_config.scaling, kv_cache_dim=model_config.kv_lora_rank + model_config.qk_rope_head_dim, + **kwargs, ) def cache_cell_size(self) -> int: diff --git a/python/tokenspeed/runtime/layers/attention/registry.py b/python/tokenspeed/runtime/layers/attention/registry.py index f32871347..7324bc7c9 100644 --- a/python/tokenspeed/runtime/layers/attention/registry.py +++ b/python/tokenspeed/runtime/layers/attention/registry.py @@ -218,10 +218,13 @@ def _create_hybrid_linear_attn( config, ) - # Create mamba/linear attention backend - config.speculative_num_draft_tokens = getattr( - server_args, "speculative_num_draft_tokens", 0 - ) + # Create mamba/linear attention backend. Only propagate the configured + # verify width when spec-dec is actually enabled — matches MLAConfig / + # MHAConfig.generate. Otherwise the BaseAttnConfig sentinel (1) wins so + # non-spec hybrid decode doesn't get misclassified as target verify / + # draft extend by `self.spec_num_tokens > 1`. + if server_args.speculative_algorithm is not None: + config.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens # Create KV cache pool (only for full attention layers) num_full_attn_layers = len(full_attn_layers) diff --git a/python/tokenspeed/runtime/layers/logits_processor.py b/python/tokenspeed/runtime/layers/logits_processor.py index 039e49e30..0b7a26865 100755 --- a/python/tokenspeed/runtime/layers/logits_processor.py +++ b/python/tokenspeed/runtime/layers/logits_processor.py @@ -73,6 +73,7 @@ class LogitsProcessorOutput: class LogitsMetadata: forward_mode: ForwardMode capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL + gather_ids: torch.Tensor | None = None extend_return_logprob: bool = False extend_return_top_logprob: bool = False @@ -104,10 +105,6 @@ class LogitsMetadata: global_num_tokens_for_logprob_cpu: torch.Tensor | None = None global_num_tokens_for_logprob_gpu: torch.Tensor | None = None - # for padding - padded_static_len: int = -1 - last_index_offsets: torch.Tensor | None = None - @classmethod def from_forward_context( cls, @@ -117,9 +114,8 @@ def from_forward_context( return cls( forward_mode=ctx.forward_mode, capture_hidden_mode=ctx.capture_hidden_mode, + gather_ids=ctx.gather_ids, extend_seq_lens=input_lengths, - padded_static_len=ctx.padded_static_len, - last_index_offsets=ctx.last_index_offsets, ) @@ -213,30 +209,15 @@ def forward( ) -> LogitsProcessorOutput: # Get the last hidden states and last logits for the next token prediction if not logits_metadata.extend_return_logprob: - if logits_metadata.forward_mode.is_extend_or_mixed(): - # Prefill: last token of each request via cumulative seq lens. - last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 - pruned_states = hidden_states[last_index] - if aux_hidden_states is not None: - aux_pruned_states = [ - hidden[last_index] for hidden in aux_hidden_states - ] - elif logits_metadata.padded_static_len > 0: - # Padded per-request layout: pick the last valid token per - # request using the precomputed offsets. - last_index = ( - logits_metadata.last_index_offsets + logits_metadata.extend_seq_lens - ) - pruned_states = hidden_states[last_index] + gather_ids = logits_metadata.gather_ids + if gather_ids is None: + pruned_states = hidden_states if aux_hidden_states is not None: - aux_pruned_states = [ - hidden[last_index] for hidden in aux_hidden_states - ] + aux_pruned_states = list(aux_hidden_states) else: - # One row per request already — no indexing needed. - pruned_states = hidden_states + pruned_states = hidden_states[gather_ids] if aux_hidden_states is not None: - aux_pruned_states = [hidden for hidden in aux_hidden_states] + aux_pruned_states = [h[gather_ids] for h in aux_hidden_states] sample_indices = None input_logprob_indices = None diff --git a/python/tokenspeed/runtime/models/deepseek_v3.py b/python/tokenspeed/runtime/models/deepseek_v3.py index 7212b5ee1..9c8b5e634 100644 --- a/python/tokenspeed/runtime/models/deepseek_v3.py +++ b/python/tokenspeed/runtime/models/deepseek_v3.py @@ -26,7 +26,8 @@ import re from collections.abc import Iterable -from typing import Any +from dataclasses import replace +from typing import Any, Tuple import torch import torch.nn.functional as F @@ -66,6 +67,7 @@ from tokenspeed.runtime.distributed.comm_manager import CommManager from tokenspeed.runtime.execution.context import ForwardContext from tokenspeed.runtime.execution.cuda_graph_wrapper import get_is_capture_mode +from tokenspeed.runtime.execution.forward_batch_info import ForwardMode from tokenspeed.runtime.layers.activation import SiluAndMul from tokenspeed.runtime.layers.attention.mla_fp8_utils import ( mla_fused_rope_fp8_quantize, @@ -84,10 +86,7 @@ from tokenspeed.runtime.layers.moe.utils import RoutingMethodType from tokenspeed.runtime.layers.paged_attention import PagedAttention from tokenspeed.runtime.layers.quantization.base_config import QuantizationConfig -from tokenspeed.runtime.layers.quantization.utils import ( - block_dequant, - should_ignore_quant_layer, -) +from tokenspeed.runtime.layers.quantization.utils import block_dequant from tokenspeed.runtime.layers.rotary_embedding import get_rope from tokenspeed.runtime.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -328,12 +327,12 @@ def forward( num_global_tokens: int, max_num_tokens_per_gpu: int, ) -> torch.Tensor: - num_tokens = hidden_states.shape[0] + num_tokens = hidden_states.size(0) with self.stream_fork.scope(enable=get_is_capture_mode()) as fork: # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) - if hidden_states.shape[0] > 0: + if num_tokens > 0: topk_output = self.topk(hidden_states, router_logits) else: topk_output = self.topk.empty_topk_output( @@ -601,30 +600,6 @@ def __init__( self.w_kc = None self.w_vc = None - self.dense_1_unqaunted = ( - self.check_unquanted(self.q_b_proj) if hasattr(self, "q_b_proj") else True - ) - - def check_unquanted(self, module) -> bool: - return module.quant_config is None or should_ignore_quant_layer( - module.prefix, - ignored_layers=getattr(module.quant_config, "ignored_layers", []), - ) - - def no_absorb(self, ctx: ForwardContext) -> bool: - if self.attention_backend in self._MLA_KERNEL_BACKENDS: - # MLA kernel backends: do not absorb when ragged prefill is enabled. - return ( - not global_server_args_dict["mla_disable_ragged"] - and ctx.forward_mode.is_extend_or_mixed() - and ctx.padded_static_len == -1 - ) - else: - # Triton: skip absorption on pure prefill, absorb otherwise. - # `extend_prefix_lens_cpu` is not on ForwardContext, so the - # "no prefix cache" guard is conservative — always absorb. - return ctx.forward_mode.is_extend_or_mixed() and ctx.padded_static_len == -1 - def forward( self, positions: torch.Tensor, @@ -633,148 +608,102 @@ def forward( out_cache_loc: torch.Tensor, comm_manager: CommManager, block_scale: torch.Tensor | None = None, - can_run_flashinfer_fusion: bool = False, - ) -> torch.Tensor: - if self.no_absorb(ctx): - return self.forward_normal_chunked( - positions, hidden_states, ctx, out_cache_loc, comm_manager, block_scale - ) - else: - return self.forward_absorb( - positions, - hidden_states, - ctx, - out_cache_loc, - comm_manager, - block_scale, - can_run_flashinfer_fusion, - ) - - def forward_normal( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - ctx: ForwardContext, - out_cache_loc: torch.Tensor, - comm_manager: CommManager, ) -> torch.Tensor: if self.q_lora_rank is not None: - qkv = self.fused_qkv_a_proj_with_mqa(hidden_states) + qkv = self.fused_qkv_a_proj_with_mqa( + hidden_states, block_scale, torch.bfloat16 + ) qkv = comm_manager.pre_attn_comm(qkv, ctx) - q, latent_cache = qkv.split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 + q_a, latent_cache = qkv.split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, ) - q = self.q_a_layernorm(q) - q = self.q_b_proj(q)[0] + kv_a = latent_cache[..., : self.kv_lora_rank] + q_norm = torch.empty_like(q_a) + if q_a.size(0) > 0: + self.fused_qk_layernorm( + input_q_a=q_a, input_kv_a=kv_a, output_q_a=q_norm + ) + q = self.q_b_proj(q_norm)[0] else: + hidden_states = comm_manager.pre_attn_comm(hidden_states, ctx) q = self.q_proj(hidden_states)[0] latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] - - q = q.view(-1, self.num_local_heads, self.qk_head_dim) - q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - - kv_a, k_pe = latent_cache.split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + kv_a = latent_cache[..., : self.kv_lora_rank] + self.kv_a_layernorm(kv_a, inplace=True) + + num_decodes = ctx.bs - ctx.num_extends + num_decode_tokens = num_decodes * ctx.attn_backend.spec_num_tokens + num_prefill_tokens = q.size(0) - num_decode_tokens + attn_output = torch.empty( + q.size(0), + self.num_local_heads * self.v_head_dim, + dtype=q.dtype, + device=q.device, ) - kv_a = self.kv_a_layernorm(kv_a) - kv = self.kv_b_proj(kv_a)[0] - kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope = kv[..., : self.qk_nope_head_dim] - v = kv[..., self.qk_nope_head_dim :] - if self.rotary_emb is not None: - q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - latent_cache[:, : self.kv_lora_rank] = kv_a - latent_cache[:, self.kv_lora_rank :] = k_pe - latent_cache = latent_cache.unsqueeze(1) - - # Save latent cache - ctx.token_to_kv_pool.set_kv_buffer( - self.attn_mha, out_cache_loc, latent_cache, None - ) + if ctx.num_extends > 0: + prefill_ctx = replace( + ctx, + bs=ctx.num_extends, + input_num_tokens=num_prefill_tokens, + forward_mode=ForwardMode.EXTEND, + ) + self.forward_normal_chunked( + positions[:num_prefill_tokens], + q[:num_prefill_tokens], + latent_cache[:num_prefill_tokens], + prefill_ctx, + out_cache_loc[:num_prefill_tokens], + attn_output[:num_prefill_tokens], + ) - q[..., self.qk_nope_head_dim :] = q_pe - k = torch.empty_like(q) - k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim :] = k_pe.unsqueeze(1) - attn_output = self.attn_mha(q, k, v, ctx, out_cache_loc, save_kv_cache=False) + if ctx.num_extends < ctx.bs: + decode_ctx = replace( + ctx, + bs=num_decodes, + num_extends=0, + input_num_tokens=num_decode_tokens, + forward_mode=ForwardMode.DECODE, + ) + self.forward_absorb( + positions[num_prefill_tokens:], + q[num_prefill_tokens:], + latent_cache[num_prefill_tokens:], + decode_ctx, + out_cache_loc[num_prefill_tokens:], + attn_output[num_prefill_tokens:], + ) - attn_output = attn_output.view(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output def forward_absorb( self, positions: torch.Tensor, - hidden_states: torch.Tensor, + q: torch.Tensor, + latent_cache: torch.Tensor, ctx: ForwardContext, out_cache_loc: torch.Tensor, - comm_manager: CommManager, - block_scale: torch.Tensor | None = None, - can_run_flashinfer_fusion: bool = False, + output: torch.Tensor, ) -> torch.Tensor: Q, K = self.forward_absorb_qkv_proj( - hidden_states, + q, + latent_cache, positions, ctx, out_cache_loc, - comm_manager, - block_scale, - can_run_flashinfer_fusion, ) - output = self.forward_absorb_attn_o_proj(Q, K, ctx, out_cache_loc) - return output + return self.forward_absorb_attn_v_proj(Q, K, ctx, out_cache_loc, output) def forward_absorb_qkv_proj( self, - hidden_states, + q: torch.Tensor, + latent_cache: torch.Tensor, positions, ctx: ForwardContext, out_cache_loc: torch.Tensor, - comm_manager: CommManager, - block_scale: torch.Tensor | None = None, - can_run_flashinfer_fusion: bool | None = None, - ): - if self.q_lora_rank is not None: - qkv = self.fused_qkv_a_proj_with_mqa( - hidden_states, block_scale, torch.bfloat16 - ) - - if can_run_flashinfer_fusion and self.layer_id != 0: - qkv, q_norm, k_nope, block_scale = ( - self.fused_qk_layernorm.forward_with_allgather_fusion( - self.mapping.attn.tp_rank, - self.mapping.attn.tp_group, - qkv, - ctx.input_num_tokens, - fuse_block_quant_fp8=not self.dense_1_unqaunted, - ) - ) - latent_cache = qkv[..., self.q_lora_rank :] - q = self.q_b_proj(q_norm, block_scale, torch.bfloat16)[0] - else: - qkv = comm_manager.pre_attn_comm(qkv, ctx) - q, latent_cache = qkv.split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - dim=-1, - ) - k_nope = latent_cache[..., : self.kv_lora_rank] - - # fused layernorm - q_norm = torch.empty_like(q) - if q.size(0) > 0: - self.fused_qk_layernorm( - input_q_a=q, input_kv_a=k_nope, output_q_a=q_norm - ) - - q = self.q_b_proj(q_norm)[0] - else: - hidden_states = comm_manager.pre_attn_comm(hidden_states, ctx) - q = self.q_proj(hidden_states)[0] - latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] - k_nope = latent_cache[..., : self.kv_lora_rank] - self.kv_a_layernorm(k_nope, inplace=True) - + ) -> Tuple[torch.Tensor, torch.Tensor]: q = q.view(-1, self.num_local_heads, self.qk_head_dim) q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) @@ -785,7 +714,7 @@ def forward_absorb_qkv_proj( dtype=q_nope.dtype, device=q_nope.device, ) - # k_nope do the RMSNorm inplace, so latent_cache contain k_nope after norm and k_pe *before* rotate + # latent_cache contains normalized kv_a and k_pe before rotate. K = latent_cache.unsqueeze(1) q_nope_out_view = Q[..., : self.kv_lora_rank] torch.bmm( @@ -863,13 +792,14 @@ def forward_absorb_qkv_proj( return Q, K - def forward_absorb_attn_o_proj( + def forward_absorb_attn_v_proj( self, Q, K, ctx: ForwardContext, out_cache_loc: torch.Tensor, - ): + output: torch.Tensor, + ) -> torch.Tensor: # MLA kernel backends: KV cache already written in forward_absorb_qkv_proj. # Other backends: write via fused_set_kv_buffer or let backend handle it. if self.attention_backend in self._MLA_KERNEL_BACKENDS: @@ -886,74 +816,40 @@ def forward_absorb_attn_o_proj( save_kv_cache=need_save_kv, ) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) - attn_bmm_output = torch.empty( - attn_output.size(0), - self.num_local_heads, - self.v_head_dim, - dtype=attn_output.dtype, - device=attn_output.device, - ) + output_view = output.view(-1, self.num_local_heads, self.v_head_dim) torch.bmm( attn_output.transpose(0, 1), self.w_vc, - out=attn_bmm_output.transpose(0, 1), + out=output_view.transpose(0, 1), ) - output, _ = self.o_proj(attn_bmm_output.flatten(1, 2)) return output def forward_normal_chunked( self, positions: torch.Tensor, - hidden_states: torch.Tensor, + q: torch.Tensor, + latent_cache: torch.Tensor, ctx: ForwardContext, out_cache_loc: torch.Tensor, - comm_manager: CommManager, - block_scale: torch.Tensor | None = None, - ): + output: torch.Tensor, + ) -> torch.Tensor: q, k, v = self.forward_normal_chunked_kv_prepare( - positions, hidden_states, ctx, out_cache_loc, comm_manager, block_scale + positions, q, latent_cache, ctx, out_cache_loc ) - output = self.forward_normal_chunked_kv_core(q, k, v, ctx, out_cache_loc) - return output + return self.forward_normal_chunked_kv_core(q, k, v, ctx, out_cache_loc, output) def forward_normal_chunked_kv_prepare( self, positions: torch.Tensor, - hidden_states: torch.Tensor, + q: torch.Tensor, + latent_cache: torch.Tensor, ctx: ForwardContext, out_cache_loc: torch.Tensor, - comm_manager: CommManager, - block_scale: torch.Tensor | None = None, - ): - if self.q_lora_rank is not None: - qkv = self.fused_qkv_a_proj_with_mqa( - hidden_states, block_scale, torch.bfloat16 - ) - qkv = comm_manager.pre_attn_comm(qkv, ctx) - q, kv_a, k_pe = qkv.split( - [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 - ) - q_norm = torch.empty_like(q) - kv_a_norm = torch.empty_like(kv_a) - if q.size(0) > 0: - self.fused_qk_layernorm( - input_q_a=q, - input_kv_a=kv_a, - output_q_a=q_norm, - output_kv_a=kv_a_norm, - ) - q = self.q_b_proj(q_norm)[0] - kv_a = kv_a_norm - k_pe = k_pe.unsqueeze(1) - else: - hidden_states = comm_manager.pre_attn_comm(hidden_states, ctx) - q = self.q_proj(hidden_states)[0] - latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] - kv_a, k_pe = latent_cache.split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 - ) - kv_a = self.kv_a_layernorm(kv_a) - k_pe = k_pe.unsqueeze(1) + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + kv_a, k_pe = latent_cache.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.unsqueeze(1) q = q.view(-1, self.num_local_heads, self.qk_head_dim) q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) @@ -1041,7 +937,8 @@ def forward_normal_chunked_kv_core( v: torch.Tensor, ctx: ForwardContext, out_cache_loc: torch.Tensor, - ): + output: torch.Tensor, + ) -> torch.Tensor: attn_backend = ctx.attn_backend chunk_meta = attn_backend.chunked_prefill_metadata token_to_kv_pool = ctx.token_to_kv_pool @@ -1054,9 +951,11 @@ def forward_normal_chunked_kv_core( # Causal self-attention over the new chunk tokens. q_lens == kv_lens == # extend_seq_lens, so cum_seq_lens_q and cum_seq_lens_kv alias the same - # cum_extend_seq_lens. - num_extends = chunk_meta.extend_seq_lens.shape[0] - accum_output, accum_lse = attn_backend.forward_extend_chunked( + # cum_extend_seq_lens. Causal pass writes directly into output; each + # chunk's merge accumulates in place via merge_state(inplace=True). + num_extends = chunk_meta.extend_seq_lens.size(0) + output_view = output.view(-1, self.num_local_heads, self.v_head_dim) + _, accum_lse = attn_backend.forward_extend_chunked( q, k, v, @@ -1069,6 +968,7 @@ def forward_normal_chunked_kv_core( seq_lens=chunk_meta.extend_seq_lens, batch_size=num_extends, causal=True, + out=output_view, ) # Always read KV cache as BF16 for kv_b_proj (weight is BF16), even if Q is FP8. @@ -1104,7 +1004,7 @@ def forward_normal_chunked_kv_core( [k_nope, k_pe.expand(-1, self.num_local_heads, -1)], dim=-1 ) - output, lse = attn_backend.forward_extend_chunked( + chunk_output, lse = attn_backend.forward_extend_chunked( q, k, v, @@ -1119,12 +1019,15 @@ def forward_normal_chunked_kv_core( causal=False, ) - accum_output, accum_lse = merge_state( - accum_output, accum_lse, output, lse, enable_pdl=pdl_enabled() + merge_state( + output_view, + accum_lse, + chunk_output, + lse, + inplace=True, + enable_pdl=pdl_enabled(), ) - attn_output = accum_output.view(-1, self.num_local_heads * self.v_head_dim) - output, _ = self.o_proj(attn_output) return output @@ -1891,7 +1794,7 @@ def forward( ) comm_manager = self.midlayer.comm_manager - if comm_manager.should_fuse(hidden_states.shape[0]): + if comm_manager.should_fuse(hidden_states.size(0)): hidden_states_to_logits, hidden_states_to_aux, *_ = ( self.norm.forward_with_allreduce_fusion( self.mapping.dense.tp_rank, @@ -1980,13 +1883,12 @@ def prepare_model_kwargs( model_kwargs["captured_hidden_states"] = captured_hidden_states else: # During CUDA graph capture warmup, provide dummy hidden states. - num_tokens = input_ids.shape[0] target_hidden_size = getattr( self.config, "target_hidden_size", self.config.hidden_size ) num_fc = self.model.num_fc_input_dim model_kwargs["captured_hidden_states"] = torch.zeros( - num_tokens, + input_ids.size(0), target_hidden_size * num_fc, dtype=torch.bfloat16, device=input_ids.device, @@ -1997,7 +1899,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): remapped = [] for name, loaded_weight in weights: if "d2t" in name: - self.hot_token_id = loaded_weight + torch.arange(loaded_weight.shape[0]) + self.hot_token_id = loaded_weight + torch.arange(loaded_weight.size(0)) continue if "t2d" in name: continue diff --git a/python/tokenspeed/runtime/sampling/sampling_batch_info.py b/python/tokenspeed/runtime/sampling/sampling_batch_info.py index 5e8cac0e7..b3abc5692 100755 --- a/python/tokenspeed/runtime/sampling/sampling_batch_info.py +++ b/python/tokenspeed/runtime/sampling/sampling_batch_info.py @@ -77,20 +77,47 @@ class SamplingBatchInfo: # Device device: str = "cuda" + def __getitem__(self, s: slice) -> SamplingBatchInfo: + """Row-slice batch-indexed fields; pool/scalar fields pass through. + + Used by hybrid-batch samplers (MIXED + spec-dec) that apply + different sampler ops to a prefix vs suffix of rows. Only ``slice`` + is supported — int indexing would yield 0-dim tensors and break + downstream gathers. + + ``is_all_greedy`` is inherited from the parent; when ``top_ks`` is + populated the slice refines it from the sliced tensor (one GPU + sync, only on the disagg slice path). + """ + if not isinstance(s, slice): + raise TypeError( + f"SamplingBatchInfo only supports slice indexing, got {type(s).__name__}" + ) + + def _slice(t): + return t[s] if t is not None else None + + return dataclasses.replace( + self, + temperatures=_slice(self.temperatures), + top_ps=_slice(self.top_ps), + top_ks=_slice(self.top_ks), + min_ps=_slice(self.min_ps), + is_all_greedy=self.is_all_greedy, + req_pool_indices=_slice(self.req_pool_indices), + vocab_mask=_slice(self.vocab_mask), + grammars=_slice(self.grammars), + ) + @classmethod def from_schedule_batch( cls, batch: ScheduleBatch, vocab_size: int ) -> SamplingBatchInfo: reqs = batch.reqs device = batch.device - temperatures = ( - torch.tensor( - [r.sampling_params.temperature for r in reqs], - dtype=torch.float, - ) - .view(-1, 1) - .to(device, non_blocking=True) - ) + temperatures = torch.tensor( + [r.sampling_params.temperature for r in reqs], dtype=torch.float + ).to(device, non_blocking=True) top_ps = torch.tensor( [r.sampling_params.top_p for r in reqs], dtype=torch.float ).to(device, non_blocking=True) diff --git a/test/ci/eval/kimi-k2.5-nvfp4-evalscope-aime25.yaml b/test/ci/eval/kimi-k2.5-nvfp4-evalscope-aime25.yaml index 80e0d8a1f..0375764b4 100644 --- a/test/ci/eval/kimi-k2.5-nvfp4-evalscope-aime25.yaml +++ b/test/ci/eval/kimi-k2.5-nvfp4-evalscope-aime25.yaml @@ -18,6 +18,7 @@ server: --model nvidia/Kimi-K2.5-NVFP4 --tp 4 --max-model-len 80000 + --enable-mixed-batch --trust-remote-code --attention-backend tokenspeed_mla --moe-backend flashinfer_trtllm @@ -25,9 +26,9 @@ server: --kv-cache-dtype fp8 --weight-loader-prefetch-checkpoints --speculative-algorithm EAGLE3 - --speculative-draft-model-path lightseekorg/kimi-k2.5-eagle3 + --speculative-draft-model-path lightseekorg/kimi-k2.5-eagle3-mla --speculative-num-steps 1 - --drafter-attention-backend trtllm + --drafter-attention-backend tokenspeed_mla --enable-cache-report --host 127.0.0.1 --port 8000 diff --git a/test/ci/eval/kimi-k2.5-nvfp4-evalscope-gpqa-diamond.yaml b/test/ci/eval/kimi-k2.5-nvfp4-evalscope-gpqa-diamond.yaml index a57048ed8..a1b3cd5fe 100644 --- a/test/ci/eval/kimi-k2.5-nvfp4-evalscope-gpqa-diamond.yaml +++ b/test/ci/eval/kimi-k2.5-nvfp4-evalscope-gpqa-diamond.yaml @@ -17,6 +17,7 @@ server: --model nvidia/Kimi-K2.5-NVFP4 --tp 4 --max-model-len 80000 + --enable-mixed-batch --trust-remote-code --attention-backend tokenspeed_mla --moe-backend flashinfer_trtllm @@ -24,9 +25,9 @@ server: --kv-cache-dtype fp8 --weight-loader-prefetch-checkpoints --speculative-algorithm EAGLE3 - --speculative-draft-model-path lightseekorg/kimi-k2.5-eagle3 + --speculative-draft-model-path lightseekorg/kimi-k2.5-eagle3-mla --speculative-num-steps 1 - --drafter-attention-backend trtllm + --drafter-attention-backend tokenspeed_mla --enable-cache-report --host 127.0.0.1 --port 8000 diff --git a/test/ci/eval/kimi-k2.5-nvfp4-evalscope-gsm8k.yaml b/test/ci/eval/kimi-k2.5-nvfp4-evalscope-gsm8k.yaml index f85701177..64693b3fb 100644 --- a/test/ci/eval/kimi-k2.5-nvfp4-evalscope-gsm8k.yaml +++ b/test/ci/eval/kimi-k2.5-nvfp4-evalscope-gsm8k.yaml @@ -17,6 +17,7 @@ server: --model nvidia/Kimi-K2.5-NVFP4 --tp 4 --max-model-len 80000 + --enable-mixed-batch --trust-remote-code --attention-backend tokenspeed_mla --moe-backend flashinfer_trtllm @@ -24,9 +25,9 @@ server: --kv-cache-dtype fp8 --weight-loader-prefetch-checkpoints --speculative-algorithm EAGLE3 - --speculative-draft-model-path lightseekorg/kimi-k2.5-eagle3 + --speculative-draft-model-path lightseekorg/kimi-k2.5-eagle3-mla --speculative-num-steps 1 - --drafter-attention-backend trtllm + --drafter-attention-backend tokenspeed_mla --enable-cache-report --host 127.0.0.1 --port 8000 diff --git a/test/ci/eval/kimi-k2.5-nvfp4-evalscope-mmlu.yaml b/test/ci/eval/kimi-k2.5-nvfp4-evalscope-mmlu.yaml index 3d960eb79..3227a470b 100644 --- a/test/ci/eval/kimi-k2.5-nvfp4-evalscope-mmlu.yaml +++ b/test/ci/eval/kimi-k2.5-nvfp4-evalscope-mmlu.yaml @@ -17,6 +17,7 @@ server: --model nvidia/Kimi-K2.5-NVFP4 --tp 4 --max-model-len 80000 + --enable-mixed-batch --trust-remote-code --attention-backend tokenspeed_mla --moe-backend flashinfer_trtllm @@ -24,9 +25,9 @@ server: --kv-cache-dtype fp8 --weight-loader-prefetch-checkpoints --speculative-algorithm EAGLE3 - --speculative-draft-model-path lightseekorg/kimi-k2.5-eagle3 + --speculative-draft-model-path lightseekorg/kimi-k2.5-eagle3-mla --speculative-num-steps 1 - --drafter-attention-backend trtllm + --drafter-attention-backend tokenspeed_mla --enable-cache-report --host 127.0.0.1 --port 8000 diff --git a/test/runtime/test_deepseek_v4_config.py b/test/runtime/test_deepseek_v4_config.py index 797a2fe86..8f01065cc 100644 --- a/test/runtime/test_deepseek_v4_config.py +++ b/test/runtime/test_deepseek_v4_config.py @@ -460,8 +460,8 @@ def init_forward_metadata_replay_cuda_graph(self, *args, **kwargs): }, ) - # num_tokens = padded_bs * max_tokens_per_req is passed as 2nd positional. - self.assertEqual(captured["args"][1], 4) + # padded_bs is the first positional arg. + self.assertEqual(captured["args"][0], 4) self.assertEqual(captured["kwargs"]["actual_bs"], 0) self.assertEqual( captured["kwargs"]["paged_cache_block_tables"]["v4.swa"].shape, @@ -968,17 +968,17 @@ def test_deepseek_v4_kv_pool_allocates_v4_cache_families(self): max_scheduled_tokens=1, ) - self.assertEqual(tuple(pool.get_swa_kv_buffer(0).shape), (7, 37440)) + self.assertEqual(tuple(pool.get_swa_kv_buffer(0).shape), (8, 37440)) self.assertIsNone(pool.compressed_kv_buffer[0]) self.assertEqual(tuple(pool.get_compressed_kv_buffer_2d(1).shape), (4, 37440)) - self.assertEqual(tuple(pool.get_compressor_state_buffer(1).shape), (7, 4, 2048)) + self.assertEqual(tuple(pool.get_compressor_state_buffer(1).shape), (8, 4, 2048)) self.assertEqual( - tuple(pool.get_compressor_state_buffer(2).shape), (35, 8, 1024) + tuple(pool.get_compressor_state_buffer(2).shape), (36, 8, 1024) ) self.assertEqual(pool.get_compressor_state_buffer(1).dtype, torch.float32) self.assertEqual(pool.get_compressor_state_buffer(2).dtype, torch.float32) self.assertEqual(tuple(pool.get_indexer_kv_buffer_2d(1).shape), (4, 64 * 68)) - self.assertEqual(tuple(pool.get_indexer_state_buffer(1).shape), (7, 4, 512)) + self.assertEqual(tuple(pool.get_indexer_state_buffer(1).shape), (8, 4, 512)) self.assertEqual(pool.get_indexer_state_buffer(1).dtype, torch.float32) def test_deepseek_v4_kv_pool_uses_compressed_storage_blocks_for_page256(self): @@ -1066,6 +1066,7 @@ def test_deepseek_v4_backend_preserves_compact_paged_cache_contract(self): attn_tp_size=1, dtype=torch.bfloat16, is_draft=False, + speculative_num_draft_tokens=1, head_dim=512, context_len=4096, ) @@ -1075,7 +1076,6 @@ def test_deepseek_v4_backend_preserves_compact_paged_cache_contract(self): backend.init_forward_metadata( bs=2, - num_tokens=2, req_pool_indices=torch.tensor([0, 1], dtype=torch.int64), seq_lens=torch.tensor([200, 80], dtype=torch.int32), forward_mode=ForwardMode.DECODE, @@ -1100,6 +1100,7 @@ def test_deepseek_v4_mixed_metadata_keeps_decode_rows_single_token(self): attn_tp_size=1, dtype=torch.bfloat16, is_draft=False, + speculative_num_draft_tokens=1, head_dim=512, context_len=4096, ) @@ -1107,7 +1108,6 @@ def test_deepseek_v4_mixed_metadata_keeps_decode_rows_single_token(self): backend.init_forward_metadata( bs=3, - num_tokens=10, req_pool_indices=torch.tensor([0, 1, 2], dtype=torch.int64), seq_lens=torch.tensor([7, 10, 4], dtype=torch.int32), forward_mode=ForwardMode.MIXED, @@ -1140,6 +1140,7 @@ def test_deepseek_v4_cuda_graph_refresh_keeps_compact_table_columns(self): attn_tp_size=1, dtype=torch.bfloat16, is_draft=False, + speculative_num_draft_tokens=1, head_dim=512, context_len=4096, ) @@ -1178,6 +1179,7 @@ def test_deepseek_v4_metadata_splits_named_cache_groups(self): attn_tp_size=1, dtype=torch.bfloat16, is_draft=False, + speculative_num_draft_tokens=1, head_dim=512, context_len=4096, ) @@ -1192,7 +1194,6 @@ def test_deepseek_v4_metadata_splits_named_cache_groups(self): backend.init_forward_metadata( bs=2, - num_tokens=2, req_pool_indices=torch.tensor([0, 1], dtype=torch.int64), seq_lens=torch.tensor([200, 80], dtype=torch.int32), forward_mode=ForwardMode.DECODE, @@ -1244,6 +1245,7 @@ def test_deepseek_v4_metadata_slice_preserves_compact_base_offsets(self): attn_tp_size=1, dtype=torch.bfloat16, is_draft=False, + speculative_num_draft_tokens=1, head_dim=512, context_len=4096, ) @@ -1444,13 +1446,13 @@ def test_deepseek_v4_mixed_metadata_splits_prefill_and_decode(self): attn_tp_size=1, dtype=torch.bfloat16, is_draft=False, + speculative_num_draft_tokens=1, head_dim=576, context_len=256, ) ) backend.init_forward_metadata( bs=3, - num_tokens=5, req_pool_indices=torch.tensor([0, 1, 2], dtype=torch.int32), seq_lens=torch.tensor([5, 9, 12], dtype=torch.int32), forward_mode=ForwardMode.MIXED, @@ -1530,13 +1532,13 @@ def test_deepseek_v4_mixed_metadata_accepts_prefill_prefix_lens_only(self): attn_tp_size=1, dtype=torch.bfloat16, is_draft=False, + speculative_num_draft_tokens=1, head_dim=576, context_len=256, ) ) backend.init_forward_metadata( bs=4, - num_tokens=8, req_pool_indices=torch.tensor([0, 1, 2, 3], dtype=torch.int32), seq_lens=torch.tensor([5, 9, 12, 6], dtype=torch.int32), forward_mode=ForwardMode.MIXED, @@ -1575,13 +1577,13 @@ def test_deepseek_v4_mixed_backend_slices_prefill_and_decode(self): attn_tp_size=1, dtype=torch.bfloat16, is_draft=False, + speculative_num_draft_tokens=1, head_dim=576, context_len=256, ) ) backend.init_forward_metadata( bs=3, - num_tokens=5, req_pool_indices=torch.tensor([0, 1, 2], dtype=torch.int32), seq_lens=torch.tensor([5, 9, 12], dtype=torch.int32), forward_mode=ForwardMode.MIXED, @@ -1669,6 +1671,7 @@ def test_deepseek_v4_decode_backend_maps_compressed_slots_batched(self): attn_tp_size=1, dtype=torch.bfloat16, is_draft=False, + speculative_num_draft_tokens=1, head_dim=512, context_len=128, ) @@ -1676,7 +1679,6 @@ def test_deepseek_v4_decode_backend_maps_compressed_slots_batched(self): seq_lens = torch.tensor([70, 3], dtype=torch.int32) backend.init_forward_metadata( bs=2, - num_tokens=2, req_pool_indices=torch.tensor([0, 1], dtype=torch.int64), seq_lens=seq_lens, forward_mode=ForwardMode.DECODE, @@ -1708,7 +1710,6 @@ def test_deepseek_v4_decode_backend_maps_compressed_slots_batched(self): seq_lens = torch.tensor([256, 129], dtype=torch.int32) backend.init_forward_metadata( bs=2, - num_tokens=2, req_pool_indices=torch.tensor([0, 1], dtype=torch.int64), seq_lens=seq_lens, forward_mode=ForwardMode.DECODE, @@ -1755,6 +1756,7 @@ def test_deepseek_v4_decode_backend_capture_ignores_warmup_cache(self): attn_tp_size=1, dtype=torch.bfloat16, is_draft=False, + speculative_num_draft_tokens=1, head_dim=512, context_len=128, ) @@ -1762,7 +1764,6 @@ def test_deepseek_v4_decode_backend_capture_ignores_warmup_cache(self): seq_lens = torch.tensor([128, 64], device=device, dtype=torch.int32) backend.init_forward_metadata( bs=2, - num_tokens=2, req_pool_indices=torch.tensor([0, 1], device=device, dtype=torch.int64), seq_lens=seq_lens, forward_mode=ForwardMode.DECODE, @@ -1816,6 +1817,7 @@ def test_deepseek_v4_c128a_prefill_local_compressed_indices_contract(self): attn_tp_size=1, dtype=torch.bfloat16, is_draft=False, + speculative_num_draft_tokens=1, head_dim=512, context_len=1024, ) @@ -1933,6 +1935,7 @@ def test_deepseek_v4_decode_backend_masks_padding_tokens(self): attn_tp_size=1, dtype=torch.bfloat16, is_draft=False, + speculative_num_draft_tokens=1, head_dim=512, context_len=128, ) @@ -1940,7 +1943,6 @@ def test_deepseek_v4_decode_backend_masks_padding_tokens(self): seq_lens = torch.tensor([70, 3], dtype=torch.int32) backend.init_forward_metadata( bs=2, - num_tokens=2, req_pool_indices=torch.tensor([0, 1], dtype=torch.int64), seq_lens=seq_lens, forward_mode=ForwardMode.DECODE, @@ -1997,6 +1999,7 @@ def test_deepseek_v4_cuda_graph_replay_marks_padding_tokens_invalid(self): attn_tp_size=1, dtype=torch.bfloat16, is_draft=False, + speculative_num_draft_tokens=1, head_dim=512, context_len=128, ) @@ -2004,7 +2007,6 @@ def test_deepseek_v4_cuda_graph_replay_marks_padding_tokens_invalid(self): backend.init_cuda_graph_state(max_bs=4) backend.init_forward_metadata_capture_cuda_graph( bs=4, - num_tokens=4, req_pool_indices=torch.arange(4, dtype=torch.int32), seq_lens=torch.ones(4, dtype=torch.int32), forward_mode=ForwardMode.DECODE, @@ -2012,7 +2014,6 @@ def test_deepseek_v4_cuda_graph_replay_marks_padding_tokens_invalid(self): backend.init_forward_metadata_replay_cuda_graph( bs=4, - num_tokens=4, actual_bs=2, req_pool_indices=torch.arange(4, dtype=torch.int32), seq_lens=torch.tensor([70, 3, 1, 1], dtype=torch.int32), diff --git a/test/runtime/test_server_args_attention_backends.py b/test/runtime/test_server_args_attention_backends.py index ad85391b4..6940d8be5 100644 --- a/test/runtime/test_server_args_attention_backends.py +++ b/test/runtime/test_server_args_attention_backends.py @@ -122,6 +122,7 @@ def test_mha_config_propagates_speculative_settings(self): block_size=64, max_cudagraph_capture_size=4, kv_cache_quant_method="none", + speculative_algorithm="EAGLE3", speculative_num_steps=3, speculative_num_draft_tokens=4, ) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/csrc/merge_state.cu b/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/csrc/merge_state.cu index 87d7c5ec0..1edc3839a 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/csrc/merge_state.cu +++ b/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/csrc/merge_state.cu @@ -27,53 +27,76 @@ namespace tokenspeed { using flashinfer::vec_t; namespace math = flashinfer::math; -template -__global__ void MergeStateKernel(DTypeIn* __restrict__ v_a, float* __restrict__ s_a, - DTypeIn* __restrict__ v_b, float* __restrict__ s_b, - DTypeO* __restrict__ v_merged, float* __restrict__ s_merged, - uint32_t num_heads, uint32_t head_dim, +// In-place safe: v_merged may alias v_a, s_merged may alias s_a. See block +// comment in the launcher for the contract. +template +__global__ void MergeStateKernel(DTypeIn* v_a, float* s_a, DTypeIn* v_b, float* s_b, + DTypeO* v_merged, float* s_merged, uint32_t num_heads, float lse_scale_log2, float lse_scale_inv) { -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.wait;"); -#endif + constexpr size_t kVecSize = std::max(16U / sizeof(DTypeIn), HeadDim / 32U); + constexpr size_t kBdx = HeadDim / kVecSize; + uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t pos = blockIdx.x; uint32_t head_idx = ty; +#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + // Load phase: snapshot every aliasable input into registers before any store fires. float s_a_val = s_a[pos * num_heads + head_idx] * lse_scale_log2; float s_b_val = s_b[pos * num_heads + head_idx] * lse_scale_log2; + vec_t v_a_vec, v_b_vec, v_merged_vec; + v_a_vec.cast_load(v_a + (pos * num_heads + head_idx) * HeadDim + tx * kVecSize); + v_b_vec.cast_load(v_b + (pos * num_heads + head_idx) * HeadDim + tx * kVecSize); + + // Compute phase: register-only. float s_max = max(s_a_val, s_b_val); s_a_val = math::ptx_exp2(s_a_val - s_max); s_b_val = math::ptx_exp2(s_b_val - s_max); float a_scale = s_a_val / (s_a_val + s_b_val); float b_scale = s_b_val / (s_a_val + s_b_val); - vec_t v_a_vec, v_b_vec, v_merged_vec; - v_a_vec.cast_load(v_a + (pos * num_heads + head_idx) * head_dim + tx * vec_size); - v_b_vec.cast_load(v_b + (pos * num_heads + head_idx) * head_dim + tx * vec_size); #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { + for (uint32_t i = 0; i < kVecSize; ++i) { v_merged_vec[i] = a_scale * v_a_vec[i] + b_scale * v_b_vec[i]; } - v_merged_vec.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); - s_merged[pos * num_heads + head_idx] = - (math::ptx_log2(s_a_val + s_b_val) + s_max) * lse_scale_inv; + + // v_merged store: per-lane disjoint slice, no cross-lane ordering needed. + v_merged_vec.cast_store(v_merged + (pos * num_heads + head_idx) * HeadDim + tx * kVecSize); + + // s_merged store: kBdx lanes share one slot. Sync so every lane's s_a load + // is complete before the writer fires, then a single lane writes. + if constexpr (kBdx <= 32) { + __syncwarp(); + } else { + __syncthreads(); + } + if (tx == 0) { + s_merged[pos * num_heads + head_idx] = + (math::ptx_log2(s_a_val + s_b_val) + s_max) * lse_scale_inv; + } #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif } +// Aliasing contract: v_merged may alias v_a, s_merged may alias s_a. The kernel +// reorders all aliasable inputs into a register-only snapshot phase, then +// stores. The s_merged write is single-writer per (pos, head_idx) and guarded +// by __syncwarp/__syncthreads so cross-lane s_a reads finish before the aliased +// store fires. template cudaError_t MergeState(DTypeIn* v_a, float* s_a, DTypeIn* v_b, float* s_b, DTypeO* v_merged, float* s_merged, uint32_t seq_len, uint32_t num_heads, uint32_t head_dim, float lse_scale_log2, float lse_scale_inv, bool enable_pdl, cudaStream_t stream = nullptr) { - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - constexpr uint32_t vec_size = std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U); - uint32_t bdx = HEAD_DIM / vec_size; + DISPATCH_HEAD_DIM(head_dim, HeadDim, { + constexpr size_t kVecSize = std::max(16U / sizeof(DTypeIn), HeadDim / 32U); + constexpr size_t kBdx = HeadDim / kVecSize; uint32_t bdy = num_heads; dim3 nblks(seq_len); - dim3 nthrs(bdx, bdy); - auto kernel = MergeStateKernel; + dim3 nthrs(static_cast(kBdx), bdy); + auto kernel = MergeStateKernel; cudaLaunchConfig_t config; config.gridDim = nblks; @@ -87,7 +110,7 @@ cudaError_t MergeState(DTypeIn* v_a, float* s_a, DTypeIn* v_b, float* s_b, DType config.attrs = attrs; FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, v_a, s_a, v_b, s_b, v_merged, s_merged, - num_heads, head_dim, lse_scale_log2, lse_scale_inv)); + num_heads, lse_scale_log2, lse_scale_inv)); }); return cudaSuccess; } diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/merge_state.py b/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/merge_state.py index 554c87092..41086a213 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/merge_state.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/merge_state.py @@ -60,6 +60,7 @@ def merge_state( v_b: torch.Tensor, s_b: torch.Tensor, *, + inplace: bool = False, lse_scale_log2: float = LSE_LN, enable_pdl: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -81,6 +82,9 @@ def merge_state( ---------- v_a, v_b : ``[seq_len, num_heads, head_dim]``, ``bfloat16`` or ``float16``. s_a, s_b : ``[seq_len, num_heads]``, must be ``float32``. + inplace : when ``True``, the merged output is written back into ``v_a`` and + ``s_a`` (mutated) and the same tensors are returned. When ``False``, + fresh buffers are allocated. lse_scale_log2 : multiplier mapping caller's LSE basis to log2. enable_pdl : opt into Programmatic Dependent Launch (Hopper+). Caller must also enable PDL on the upstream / downstream kernels for the overlap @@ -105,9 +109,12 @@ def merge_state( s_a.dtype == torch.float32 and s_b.dtype == torch.float32 ), f"merge_state expects fp32 LSE, got s_a={s_a.dtype} s_b={s_b.dtype}" - seq_len, num_heads, _ = v_a.shape - v_merged = torch.empty_like(v_a) - s_merged = torch.empty(seq_len, num_heads, dtype=torch.float32, device=v_a.device) + if inplace: + v_merged = v_a + s_merged = s_a + else: + v_merged = torch.empty_like(v_a) + s_merged = torch.empty_like(s_a) _load_merge_state_module().merge_state( v_a, s_a, diff --git a/tokenspeed-mla/python/tokenspeed_mla/mla_prefill.py b/tokenspeed-mla/python/tokenspeed_mla/mla_prefill.py index c89524cbf..da991f652 100644 --- a/tokenspeed-mla/python/tokenspeed_mla/mla_prefill.py +++ b/tokenspeed-mla/python/tokenspeed_mla/mla_prefill.py @@ -30,9 +30,7 @@ import logging import math import os - -LOG2_E = math.log2(math.exp(1.0)) # ≈ 1.4426950408889634 -from typing import Callable, Optional, Tuple +from typing import Optional, Tuple import cutlass import cutlass.cute as cute @@ -49,6 +47,7 @@ from tokenspeed_mla.utils import torch_to_cutlass_dtype logger = logging.getLogger(__name__) +LOG2_E = math.log2(math.exp(1.0)) # ≈ 1.4426950408889634 # Backend selection via env var. Values: "cutedsl" (default) or "binary" (AOT SO). _PREFILL_BACKEND_ENV = os.environ.get( @@ -297,6 +296,7 @@ def tokenspeed_mla_prefill( cum_seq_lens_q: Optional[torch.Tensor] = None, max_seq_len_q: Optional[int] = None, enable_pdl: bool = False, + out: Optional[torch.Tensor] = None, ) -> "torch.Tensor | Tuple[torch.Tensor, torch.Tensor]": """CuTe DSL FMHA prefill kernel for MLA on Blackwell SM100. @@ -304,6 +304,7 @@ def tokenspeed_mla_prefill( Q shape: [sum(q_lens), h_q, d_qk] K shape: [sum(kv_lens), h_k, d_qk] V shape: [sum(kv_lens), h_k, d_v] + If provided, out must be contiguous BF16 with shape [sum(q_lens), h_q, d_v]. """ total_q_tokens, h_q, d_qk = query.shape total_kv_tokens, h_k, _ = key.shape @@ -347,9 +348,21 @@ def tokenspeed_mla_prefill( # Output: BF16, same 5D layout. The kernel writes (out=0, lse=-inf) for # rows in batches where seqlen_k==0, so no pre-init is required. - o_torch = torch.empty( - total_q_tokens, h_q, d_v, dtype=out_torch_dtype, device=query.device - ) + if out is None: + o_torch = torch.empty( + total_q_tokens, h_q, d_v, dtype=out_torch_dtype, device=query.device + ) + else: + expected_shape = (total_q_tokens, h_q, d_v) + if out.shape != expected_shape: + raise ValueError(f"out shape must be {expected_shape}, got {out.shape}") + if out.dtype != out_torch_dtype: + raise TypeError(f"out dtype must be {out_torch_dtype}, got {out.dtype}") + if out.device != query.device: + raise ValueError(f"out device {out.device} must match query {query.device}") + if not out.is_contiguous(): + raise ValueError("out must be contiguous") + o_torch = out o_5d = o_torch.view(1, total_q_tokens, h_k, h_r, d_v) o_ct = _to_cute(o_5d, out_cutlass_dtype) diff --git a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp index 192e5d251..d2cee80ab 100644 --- a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -472,27 +473,30 @@ DecodeOperation Scheduler::applyEventAndGenerateOp(Request* request, fsm::Schedu std::tuple, std::variant, std::vector>> Scheduler::newForwardOperation(std::vector candidates) { auto priority = [&](const Request* req) -> int { - if (req->Is()) return 0; - if (req->Is()) return 1; - if (req->Is() || req->Is()) return 2; - if (req->Is()) return 3; - return 4; + if (req->Is()) return 1; + if (req->Is()) return 2; + if (req->Is() || req->Is()) { + // Decode-first if mixed-batch is enabled; prefill-first otherwise. + return config_.enable_mixed_prefill_decode ? 0 : 3; + } + if (req->Is()) return 4; + return 9; }; std::sort(candidates.begin(), candidates.end(), [&](const auto& a, const auto& b) { return priority(a) < priority(b); }); std::vector ops; std::int32_t token_budget = config_.max_scheduled_tokens; + bool pushed_prefill = false; auto push_op = [&](auto op, bool uses_pool_slot = false) { if (config_.role != Role::kD) { token_budget -= op.input_length; } + if constexpr (std::is_same_v, PrefillOperation>) { + pushed_prefill = true; + } ops.push_back(std::move(op)); }; - auto has_prefill_op = [&]() { - return std::any_of(ops.begin(), ops.end(), - [](const ForwardOperation& op) { return std::holds_alternative(op); }); - }; std::vector loadback_ops; auto simulated_free = hybrid_prefix_cache_ ? hybrid_prefix_cache_->InitialSimulatedFree() : std::map{}; @@ -519,14 +523,16 @@ Scheduler::newForwardOperation(std::vector candidates) { } } } else if (request->Is() || (request->Is() && config_.role != Role::kP)) { - // Prefill-first: skip ALL decode if any prefill was scheduled this round. - if (!config_.enable_mixed_prefill_decode && has_prefill_op()) break; + // If mixed-batch is disabled, skip ALL decode if any prefill was scheduled this round. + // If mixed-batch is enabled, the priority sort puts decodes first, so this + // branch is reached before any prefill push. + if (!config_.enable_mixed_prefill_decode && pushed_prefill) break; if (auto ev = scheduleDecode(request, simulated_free)) { push_op(applyEventAndGenerateOp(request, *ev)); } } else if (request->Is() && config_.role != Role::kP) { - if (!config_.enable_mixed_prefill_decode && has_prefill_op()) break; + if (!config_.enable_mixed_prefill_decode && pushed_prefill) break; if (auto ev = scheduleDecodeFromRetracted(request, simulated_free)) { std::vector loadback_diff = ev->GetLoadbackDiff(); diff --git a/tokenspeed-scheduler/python/tests/test_fsm_and_scheduling.py b/tokenspeed-scheduler/python/tests/test_fsm_and_scheduling.py index ab831137f..bdf4565f4 100644 --- a/tokenspeed-scheduler/python/tests/test_fsm_and_scheduling.py +++ b/tokenspeed-scheduler/python/tests/test_fsm_and_scheduling.py @@ -300,6 +300,27 @@ def test_mixed_prefill_decode_can_schedule_decode_with_new_prefill(self): assert len(op.input_ids) + len(op.decode_input_ids) == sum(op.input_lengths) assert op.sizes == [1, 0] + def test_mixed_prefill_decode_decode_not_starved_by_long_prefill(self): + """Decode-first priority: active decode is scheduled even when a long prefill would consume the full budget.""" + cfg = make_config(max_scheduled_tokens=16, max_batch_size=8) + cfg.enable_mixed_prefill_decode = True + s = Scheduler(cfg) + + submit(s, "r0", list(range(8))) + s.next_execution_plan() # r0 → PrefillDone + s.next_execution_plan() # r0 → Decoding + advance_forward(s, "r0", tokens=[99]) + + submit(s, "r1", list(range(32))) # 32 > budget=16 + plan = s.next_execution_plan() + op = plan.forward[0] + + # Layout is prefill-first/decode-second (FlatForwardOperation::stable_partition). + assert op.request_ids == ["r1", "r0"] + assert op.num_extends() == 1 + # r0 decode = 1 token; r1 prefill chunk takes the remaining 15. + assert op.input_lengths == [15, 1] + def test_max_batch_size_limits_scheduled_requests(self): """max_batch_size caps the number of requests per plan.""" s = Scheduler(make_config(max_scheduled_tokens=512, max_batch_size=2))