diff --git a/python/tokenspeed/runtime/engine/event_loop.py b/python/tokenspeed/runtime/engine/event_loop.py index 859002b47..e3ab975f4 100644 --- a/python/tokenspeed/runtime/engine/event_loop.py +++ b/python/tokenspeed/runtime/engine/event_loop.py @@ -640,6 +640,16 @@ def _setup_layerwise_loadback(self, execution_plan) -> None: if host_exec is not None: self.model_executor.execution_stream.wait_stream(host_exec.write_stream) + def _flush_mamba_retract_states(self, forward_op) -> None: + """Flush draft->working mamba states on an iteration with no forward op (no-op without a drafter or a mamba pool). NOTE(review): retract itself is not checked here; confirm an idle iteration is the retract case.""" + if forward_op is not None: + return + if self.model_executor.drafter is None: + return + if self.model_executor.runtime_states.mamba_pool is None: + return + self.model_executor.flush_mamba_draft_to_working_on_retract() + # ------------------------------------------------------------------ # Helpers # ------------------------------------------------------------------ @@ -930,6 +940,7 @@ def event_loop(self): self._submit_cache_ops(execution_plan) forward_op = self._get_forward_op(execution_plan) + self._flush_mamba_retract_states(forward_op) stats = self._get_scheduler_stats() num_iter_tokens = ( @@ -1044,6 +1055,7 @@ def event_loop_overlap(self): self._submit_cache_ops(execution_plan) forward_op = self._get_forward_op(execution_plan) + self._flush_mamba_retract_states(forward_op) stats = self._get_scheduler_stats() num_iter_tokens = ( @@ -1112,12 +1124,6 @@ def event_loop_overlap(self): curr_results = None if forward_op is not None: - if forward_op.num_extends() <= 0: - # Overlap dispatch may schedule one extra decode before - # the previous result is committed. Snapshot the completed - # working state before this decode mutates the same slot; - # the snapshot helper only copies block-aligned states.
- self.model_executor.snapshot_mamba_checkpoints_for_op(forward_op) curr_results, _ = self._dispatch_forward( forward_op, sampling_params_list, diff --git a/python/tokenspeed/runtime/execution/model_executor.py b/python/tokenspeed/runtime/execution/model_executor.py index 2f1c61d12..0f6896beb 100644 --- a/python/tokenspeed/runtime/execution/model_executor.py +++ b/python/tokenspeed/runtime/execution/model_executor.py @@ -291,6 +291,8 @@ def __init__( self.execution_stream = torch.cuda.Stream() self.log_step = 0 self._seen_prefill_ids: set[str] = set() + self._prev_decode_bs: int = 0 + self._sentinel_neg1 = torch.tensor(-1, device=self.device, dtype=torch.int64) # Decode stats — accumulated from synced results (no GPU sync needed) self.num_generated_tokens = 0 self.num_decode_steps = 0 @@ -467,48 +469,141 @@ def accumulate_decode_stats(self, results: ModelExecutionResult, bs: int): self.num_generated_tokens += int(results.output_lengths.sum().item()) self.num_decode_steps += bs - def snapshot_mamba_checkpoints_for_op(self, forward_op) -> None: - """Snapshot completed decode working states into checkpoint slots.""" - if self.runtime_states.mamba_pool is None or forward_op.num_extends() > 0: + def _snapshot_mamba_checkpoints( + self, + accept_lengths: torch.Tensor, + bs: int, + num_extends: int, + ) -> None: + """Snapshot mamba states to checkpoint slots at page boundaries. + + Called after ``_update_runtime_state`` on the execution stream so + ``valid_cache_lengths`` already reflects the accepted tokens. + + Non-MTP (accept_length == 1): + The working slot holds the up-to-date state for the new + cache_length. Pass the kernel page_size so it copies only + when the new length is page-aligned. + + MTP (accept_length > 1): + cache_length may jump over a page boundary. The intermediate + state lives in ``mamba_output_indices[req, step]``. 
Boundary + detection and source-slot selection are done entirely on GPU + with -1 sentinels so the snapshot kernel skips invalid entries + via its bounds check — no GPU-to-CPU sync, preserving + overlap-schedule pipelining. + """ + if self.runtime_states.mamba_pool is None or num_extends > 0: return - if not getattr(forward_op, "mamba_pool_indices", None): + if not self.input_buffers.has_mamba: return - # CPU-side pre-filter - src_list = [] - dst_list = [] - req_list = [] - for i in range(len(forward_op.request_ids)): - pool_idx = forward_op.mamba_pool_indices[i] - ckpt_idx = forward_op.mamba_track_pool_indices[i] - if pool_idx != -1 and ckpt_idx != -1: - src_list.append(pool_idx) - dst_list.append(ckpt_idx) - req_list.append(forward_op.request_pool_indices[i]) - - num_valid = len(src_list) - if num_valid == 0: + req_pool_indices = self.input_buffers.req_pool_indices_buf[:bs] + track_indices = self.input_buffers.mamba_track_pool_indices_buf[:bs] + page_size = self.config.block_size + dev = req_pool_indices.device + sentinel = self._sentinel_neg1 + + if self.drafter is not None: + # -- MTP path: find the output slot at the crossed boundary -- + backend = getattr( + self.attn_backend, "linear_attn_backend", self.attn_backend + ) + fm = getattr(backend, "forward_metadata", None) + if fm is None: + return + output_indices = fm.mamba_output_indices + if output_indices is None: + return + + new_cl = self.runtime_states.valid_cache_lengths[req_pool_indices] + old_cl = new_cl - accept_lengths[:bs].to(device=dev, dtype=new_cl.dtype) + first_boundary = ((old_cl // page_size) + 1) * page_size + + step_raw = first_boundary - old_cl - 1 + max_col = output_indices.shape[1] - 1 + step = step_raw.clamp(min=0, max=max_col).to(torch.int64) + + req_range = torch.arange(bs, device=dev) + src_raw = output_indices[req_range, step].to(torch.int64) + dst_raw = track_indices.to(device=dev, dtype=torch.int64) + + invalid = ( + (first_boundary > new_cl) + | (dst_raw < 0) + | (src_raw < 0) + 
| (src_raw == dst_raw) + | (step_raw < 0) + ) + src = torch.where(invalid, sentinel, src_raw) + dst = torch.where(invalid, sentinel, dst_raw) + + self.runtime_states.snapshot_mamba_checkpoints( + src, + dst, + cache_lengths=None, + page_size=0, + num_valid=bs, + ) + else: + # -- Non-MTP path: working slot IS the up-to-date state -- + src_raw = self.input_buffers.mamba_pool_indices_buf[:bs].to( + device=dev, dtype=torch.int64 + ) + dst_raw = track_indices.to(device=dev, dtype=torch.int64) + + invalid = (src_raw < 0) | (dst_raw < 0) | (src_raw == dst_raw) + src = torch.where(invalid, sentinel, src_raw) + dst = torch.where(invalid, sentinel, dst_raw) + + cache_lengths = self.runtime_states.valid_cache_lengths[req_pool_indices] + self.runtime_states.snapshot_mamba_checkpoints( + src, + dst, + cache_lengths=cache_lengths, + page_size=page_size, + num_valid=bs, + ) + + def flush_mamba_draft_to_working_on_retract(self) -> None: + """Copy accepted draft mamba state -> working slot for all previous-batch requests. + + Called from event_loop when no forward op is scheduled (presumably the retract case; TODO confirm). + Uses the previous decode iteration's input_buffers (assumed still valid, i.e. + no later forward has overwritten them; TODO confirm). + Runs on execution_stream to respect ordering with previous forward writes.
+ """ + bs = self._prev_decode_bs + if bs <= 0: return - t_src = torch.tensor(src_list, dtype=torch.int64, device="cpu", pin_memory=True) - t_dst = torch.tensor(dst_list, dtype=torch.int64, device="cpu", pin_memory=True) - t_req = torch.tensor(req_list, dtype=torch.int64, device="cpu", pin_memory=True) - - src_buf = torch.empty(num_valid, dtype=torch.int64, device=self.device) - dst_buf = torch.empty(num_valid, dtype=torch.int64, device=self.device) - req_buf = torch.empty(num_valid, dtype=torch.int64, device=self.device) - src_buf.copy_(t_src, non_blocking=True) - dst_buf.copy_(t_dst, non_blocking=True) - req_buf.copy_(t_req, non_blocking=True) - - cache_lengths = self.runtime_states.valid_cache_lengths[req_buf] - self.runtime_states.snapshot_mamba_checkpoints( - src_buf, - dst_buf, - cache_lengths, - self.config.block_size, - num_valid, - ) + backend = getattr(self.attn_backend, "linear_attn_backend", self.attn_backend) + pool = getattr(backend, "pool", None) + if pool is None: + return + + sentinel = self._sentinel_neg1 + + with torch.cuda.stream(self.execution_stream): + req_pool_indices = self.input_buffers.req_pool_indices_buf[:bs] + working = self.input_buffers.mamba_pool_indices_buf[:bs] + + src_raw = pool.current_input_indices[req_pool_indices.clamp(0).long()].to( + dtype=torch.int64 + ) + dst_raw = working.to(dtype=torch.int64) + + invalid = (src_raw < 0) | (dst_raw < 0) | (src_raw == dst_raw) + src = torch.where(invalid, sentinel, src_raw) + dst = torch.where(invalid, sentinel, dst_raw) + + self.runtime_states.snapshot_mamba_checkpoints( + src, + dst, + cache_lengths=None, + page_size=0, + num_valid=bs, + ) def execute_forward_op_with_log( self, @@ -822,6 +917,9 @@ def execute_forward_op( bs = len(forward_op.request_ids) forward_mode = ForwardMode.from_num_extends(num_extends, bs) + if num_extends <= 0: + self._prev_decode_bs = bs + if self.runtime_states.mamba_pool is not None and ( num_extends > 0 or has_retract ): @@ -834,9 +932,21 @@ def 
execute_forward_op( self.runtime_states.zero_mamba_states( mamba_pool_indices, mamba_cow_src, - self.input_buffers.extend_prefix_lens_buf[:bs], - bs, + self.input_buffers.extend_prefix_lens_buf[:num_extends], + num_extends, ) + if hasattr(self.attn_backend, "reset_current_inputs"): + self.attn_backend.reset_current_inputs( + self.input_buffers.req_pool_indices_buf[:num_extends], + mamba_pool_indices[:num_extends], + ) + elif has_retract: + if hasattr(self.attn_backend, "reset_current_inputs"): + retract_mask = mamba_cow_src[:bs] >= 0 + self.attn_backend.reset_current_inputs( + self.input_buffers.req_pool_indices_buf[:bs][retract_mask], + mamba_pool_indices[:bs][retract_mask], + ) grammar_completion = None @@ -967,6 +1077,11 @@ def execute_forward_op( input_lengths=self.input_buffers.input_lengths_buf[:bs], is_extend=num_extends > 0, ) + self._snapshot_mamba_checkpoints( + output_lengths, + bs, + num_extends, + ) with nvtx_range("output_d2h", color="green"): output_tokens = output_tokens.to("cpu", non_blocking=True) diff --git a/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py b/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py index 837e4d54c..a8b5133ac 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py +++ b/python/tokenspeed/runtime/layers/attention/backends/hybrid_linear_attn.py @@ -41,9 +41,6 @@ fused_sigmoid_gating_delta_rule_update, ) from tokenspeed.runtime.layers.attention.linear.gdn import fused_gdn_gating -from tokenspeed.runtime.layers.attention.linear.mamba_state_scatter_triton import ( - fused_mamba_state_scatter_with_mask, -) if TYPE_CHECKING: from tokenspeed.runtime.layers.attention.configs.base import BaseAttnConfig @@ -55,6 +52,8 @@ class MambaForwardMetadata: query_start_loc: torch.Tensor | None mamba_cache_indices: torch.Tensor + mamba_output_indices: Optional[torch.Tensor] = None + mamba_req_pool_indices: Optional[torch.Tensor] = None extend_prefix_lens: 
Optional[torch.Tensor] = None # Pre-computed src/dst indices for extracting Mamba prefix-cache snapshots. track_ssm_h_src: Optional[torch.Tensor] = None @@ -128,6 +127,7 @@ def __init__( device: str, page_size: int = 1, speculative_num_draft_tokens: int = 0, + max_req_pool_size: int = 0, ): self.size = size self.device = device @@ -135,55 +135,143 @@ def __init__( self.page_size = page_size self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layer_ids)} self.is_kda_cache = False + self.max_req_pool_size = max_req_pool_size + + # Base slots (working + checkpoint) are allocated by C++ scheduler. + # Python-only draft rows live after the scheduler-owned range and are + # addressed by normal row indices in the same tensors. + self.base_size = size + self.speculative_num_draft_tokens = speculative_num_draft_tokens + self.current_input_size = ( + max_req_pool_size + 1 if max_req_pool_size > 0 else size + ) + self.draft_slots_per_req = max(0, speculative_num_draft_tokens - 1) + self.draft_base = size + self.draft_total_slots = self.current_input_size * self.draft_slots_per_req + total_size = size + self.draft_total_slots + self.total_size = total_size - # Allocate conv state: (num_mamba_layers, size+1, conv_dim, state_len) + # Allocate conv state: (num_mamba_layers, total_size, conv_dim, state_len) self.conv_state = torch.zeros( num_mamba_layers, - size + 1, + total_size, *conv_state_shape, dtype=conv_dtype, device=device, ) - # Allocate temporal/SSM state: (num_mamba_layers, size+1, heads, key_dim, val_dim) + # Allocate temporal/SSM state: (num_mamba_layers, total_size, heads, key_dim, val_dim) self.ssm_state = torch.zeros( num_mamba_layers, - size + 1, + total_size, *temporal_state_shape, dtype=ssm_dtype, device=device, ) - # Speculative decoding intermediate caches - if speculative_num_draft_tokens > 0: - self.intermediate_ssm_state_cache = torch.empty( - num_mamba_layers, - size + 1, - speculative_num_draft_tokens, - *temporal_state_shape, - 
dtype=ssm_dtype, - device=device, - ) - self.intermediate_conv_window_cache = torch.zeros( - num_mamba_layers, - size + 1, - speculative_num_draft_tokens, - *conv_state_shape, - dtype=conv_dtype, - device=device, - ) - self.mamba_cache = ( - self.conv_state, - self.ssm_state, - self.intermediate_ssm_state_cache, - self.intermediate_conv_window_cache, - ) - else: - self.mamba_cache = (self.conv_state, self.ssm_state) + self.mamba_cache = (self.conv_state, self.ssm_state) + + self.current_input_indices = torch.full( + (self.current_input_size,), -1, dtype=torch.int32, device=device + ) def get_mamba_indices(self, mamba_pool_indices: torch.Tensor) -> torch.Tensor: """Return mamba cache indices directly (allocated by C++ scheduler).""" return mamba_pool_indices.to(torch.int32) + def get_mtp_output_indices( + self, + req_pool_indices: torch.Tensor, + working_indices: torch.Tensor, + draft_token_num: int, + out: torch.Tensor | None = None, + ) -> torch.Tensor: + """Build per-request target-verify outputs: [working, draft0, ...].""" + bs = working_indices.shape[0] + if out is not None: + output_indices = out + output_indices.fill_(-1) + else: + output_indices = torch.full( + (bs, draft_token_num), + -1, + dtype=torch.int32, + device=working_indices.device, + ) + if draft_token_num <= 0: + return output_indices + + working = working_indices.to(torch.int32) + valid = working >= 0 + output_indices[:, 0] = torch.where(valid, working, -1) + + if draft_token_num > 1 and self.draft_slots_per_req > 0: + req = req_pool_indices[:bs].to(torch.int32) + steps = torch.arange( + draft_token_num - 1, dtype=torch.int32, device=working.device + ) + draft = ( + self.draft_base + + req[:, None] * self.draft_slots_per_req + + steps[None, :] + ) + output_indices[:, 1:] = torch.where( + valid[:, None] & (req >= 0)[:, None], + draft, + -1, + ) + return output_indices + + def get_current_input_indices( + self, + req_pool_indices: torch.Tensor, + working_indices: torch.Tensor, + cow_src_indices: 
torch.Tensor | None = None, + ) -> torch.Tensor: + """Return the row each request reads at the start of target verify: the stored current-input row if set, else the working row (or the COW source when given and the row is still the working row); -1 for invalid rows.""" + req_pool_indices = req_pool_indices[: working_indices.shape[0]].to(torch.int32) + working_indices = working_indices.to(torch.int32) + # Validity only tests >= 0 (upper bounds presumably guaranteed by the scheduler; TODO confirm); the lookup index below is still clamped into range. + valid = (working_indices >= 0) & (req_pool_indices >= 0) + safe_req_indices = req_pool_indices.clamp(0, self.current_input_size - 1).long() + stored = self.current_input_indices[safe_req_indices] + current = torch.where(valid & (stored >= 0), stored, working_indices) + current = torch.where(valid, current, torch.full_like(current, -1)) + if cow_src_indices is not None: + cow_src_indices = cow_src_indices[: working_indices.shape[0]].to( + torch.int32 + ) + current = torch.where( + (cow_src_indices >= 0) & valid & (current == working_indices), + cow_src_indices, + current, + ) + return current + + def reset_current_inputs( + self, req_pool_indices: torch.Tensor, working_indices: torch.Tensor + ) -> None: + """Point each request's current-input entry at its working slot (freshly allocated/reused slots become canonical). NOTE(review): indices are not validated here; confirm callers never pass negative request indices.""" + req_pool_indices = req_pool_indices[: working_indices.shape[0]].to(torch.int32) + working_indices = working_indices.to(torch.int32) + self.current_input_indices[req_pool_indices.long()] = working_indices + + def update_current_inputs_after_verify( + self, + req_pool_indices: torch.Tensor, + output_indices: torch.Tensor, + accepted_lengths: torch.Tensor, + ) -> None: + if output_indices is None or output_indices.numel() == 0: + return + n = accepted_lengths.shape[0] + req_pool_indices = req_pool_indices[:n].to(torch.int32) + accepted_lengths = accepted_lengths.clamp( + min=1, max=output_indices.shape[1] + ).to(torch.int32) + rows = torch.arange(n, device=accepted_lengths.device, dtype=torch.long) + selected = output_indices[rows, (accepted_lengths - 1).long()].to(torch.int32) + self.current_input_indices[req_pool_indices.long()] = selected + def get_mamba_params(self,
layer_id: int): """Return per-layer cache slices.""" internal_idx = self.mamba_map[layer_id] @@ -222,6 +310,7 @@ def __init__(self, config: BaseAttnConfig): self.query_start_loc_list = [] self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None + self.output_indices_list = [] self.speculative_num_draft_tokens = getattr( config, "speculative_num_draft_tokens", 0 ) @@ -230,6 +319,12 @@ def __init__(self, config: BaseAttnConfig): def set_pool(self, pool: SimpleMambaPool): self.pool = pool + def reset_current_inputs( + self, req_pool_indices: torch.Tensor, working_indices: torch.Tensor + ): + if self.pool is not None: + self.pool.reset_current_inputs(req_pool_indices, working_indices) + def init_forward_metadata( self, bs: int, @@ -255,6 +350,22 @@ def init_forward_metadata( forward_mode.is_decode_or_idle() and self.is_draft and spec_num_tokens > 1 ) + mamba_output_indices = None + if is_target_verify: + draft_token_num = int( + kwargs.get("tokens_per_req", self.speculative_num_draft_tokens) + ) + cow_src_indices = kwargs.get("mamba_cow_src_indices") + mamba_input_indices = self.pool.get_current_input_indices( + req_pool_indices[:bs], mamba_cache_indices, cow_src_indices + ) + mamba_output_indices = self.pool.get_mtp_output_indices( + req_pool_indices[:bs], + mamba_cache_indices, + draft_token_num, + ) + mamba_cache_indices = mamba_input_indices + if forward_mode.is_decode_or_idle() and spec_num_tokens == 1: query_start_loc = torch.arange( 0, bs + 1, dtype=torch.int32, device=self.device @@ -368,6 +479,8 @@ def init_forward_metadata( self.forward_metadata = MambaForwardMetadata( query_start_loc=query_start_loc, mamba_cache_indices=mamba_cache_indices, + mamba_output_indices=mamba_output_indices, + mamba_req_pool_indices=req_pool_indices[:bs], extend_prefix_lens=kwargs.get("extend_prefix_lens"), track_ssm_h_src=track_ssm_h_src, track_ssm_h_dst=track_ssm_h_dst, @@ -437,6 +550,15 @@ def 
init_cuda_graph_state( self.query_start_loc_list.append( torch.empty((i + 2,), dtype=torch.int32, device=self.device) ) + if self.speculative_num_draft_tokens > 0: + self.output_indices_list.append( + torch.full( + (i + 1, self.speculative_num_draft_tokens), + self.pad_slot_id, + dtype=torch.int32, + device=self.device, + ) + ) self.cached_cuda_graph_decode_query_start_loc = torch.arange( 0, max_num_tokens + 1, dtype=torch.int32, device=self.device ) @@ -484,16 +606,38 @@ def init_forward_metadata_capture_cuda_graph( raise ValueError(f"Invalid forward mode: {forward_mode=}") mamba_pool_indices = kwargs.get("mamba_pool_indices") + # Reuse the pre-allocated [bs]-length buffer as mamba_indices so the + # capture path matches the replay path: zero allocation, single write. + padded_mamba_indices = self.state_indices_list[bs - 1] if mamba_pool_indices is not None: - mamba_indices = self.pool.get_mamba_indices(mamba_pool_indices[:bs]) + padded_mamba_indices[:bs].copy_( + self.pool.get_mamba_indices(mamba_pool_indices[:bs]) + ) else: - mamba_indices = self.pool.get_mamba_indices(req_pool_indices[:bs]) - self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices) + padded_mamba_indices[:bs].copy_( + self.pool.get_mamba_indices(req_pool_indices[:bs]) + ) + mamba_output_indices = None + if is_target_verify: + cow_src_indices = kwargs.get("mamba_cow_src_indices") + mamba_input_indices = self.pool.get_current_input_indices( + req_pool_indices[:bs], padded_mamba_indices, cow_src_indices + ) + mamba_output_indices = self.output_indices_list[bs - 1] + self.pool.get_mtp_output_indices( + req_pool_indices[:bs], + padded_mamba_indices, + self.speculative_num_draft_tokens, + out=mamba_output_indices, + ) + padded_mamba_indices.copy_(mamba_input_indices) self._qsl_dirty[bs - 1] = False self._qsl_last_mode[bs - 1] = (forward_mode, spec_num_tokens > 1) self.forward_metadata = MambaForwardMetadata( query_start_loc=self.query_start_loc_list[bs - 1], 
mamba_cache_indices=self.state_indices_list[bs - 1], + mamba_output_indices=mamba_output_indices, + mamba_req_pool_indices=req_pool_indices[:bs], ) def init_forward_metadata_replay_cuda_graph( @@ -511,14 +655,22 @@ def init_forward_metadata_replay_cuda_graph( real_bs = bs - num_padding req_pool_indices = req_pool_indices[:bs] + + # Reuse the pre-allocated [bs]-length buffer as the padded mamba_indices + # so downstream ops (get_mtp_output_indices, get_current_input_indices) + # see the full-batch shape with padding rows already set to -1. + # Zero extra allocations on this hot path. + padded_mamba_indices = self.state_indices_list[bs - 1] if mamba_pool_indices is not None: - mamba_indices = self.pool.get_mamba_indices(mamba_pool_indices[:real_bs]) + padded_mamba_indices[:real_bs].copy_( + self.pool.get_mamba_indices(mamba_pool_indices[:real_bs]) + ) else: - mamba_indices = self.pool.get_mamba_indices(req_pool_indices[:real_bs]) - - self.state_indices_list[bs - 1][:real_bs].copy_(mamba_indices) + padded_mamba_indices[:real_bs].copy_( + self.pool.get_mamba_indices(req_pool_indices[:real_bs]) + ) if num_padding > 0: - self.state_indices_list[bs - 1][real_bs:].fill_(self.pad_slot_id) + padded_mamba_indices[real_bs:].fill_(-1) spec_num_tokens = num_tokens // bs if bs > 0 else 1 is_target_verify = ( @@ -534,6 +686,22 @@ def init_forward_metadata_replay_cuda_graph( and spec_num_tokens > 1 ) + mamba_output_indices = None + if is_target_verify: + cow_src_indices = kwargs.get("mamba_cow_src_indices") + mamba_input_indices = self.pool.get_current_input_indices( + req_pool_indices, padded_mamba_indices, cow_src_indices + ) + mamba_output_indices = self.output_indices_list[bs - 1] + self.pool.get_mtp_output_indices( + req_pool_indices, + padded_mamba_indices, + self.speculative_num_draft_tokens, + out=mamba_output_indices, + ) + # mamba_input_indices already encodes padding via padded_mamba_indices. 
+ padded_mamba_indices.copy_(mamba_input_indices) + if num_padding == 0: need_copy = self._qsl_dirty[bs - 1] or self._qsl_last_mode[bs - 1] != ( forward_mode, @@ -571,6 +739,8 @@ def init_forward_metadata_replay_cuda_graph( self.forward_metadata = MambaForwardMetadata( query_start_loc=self.query_start_loc_list[bs - 1], mamba_cache_indices=self.state_indices_list[bs - 1], + mamba_output_indices=mamba_output_indices, + mamba_req_pool_indices=req_pool_indices, ) def get_cuda_graph_seq_len_fill_value(self): @@ -714,12 +884,8 @@ def forward_extend( draft_token_num = kwargs.get( "draft_token_num", self.speculative_num_draft_tokens ) - ( - conv_states, - ssm_states, - intermediate_state_cache, - intermediate_conv_window_cache, - ) = self.pool.get_mamba_params(layer_id) + conv_states, ssm_states = self.pool.get_mamba_params(layer_id) + output_indices = self.forward_metadata.mamba_output_indices batch_size = seq_len // draft_token_num mixed_qkv_reshaped = ( @@ -734,13 +900,13 @@ def forward_extend( bias, activation, conv_state_indices=cache_indices[:batch_size], - intermediate_conv_window=intermediate_conv_window_cache, + output_state_indices=output_indices[:batch_size], ) mixed_qkv = ( mixed_qkv_processed.transpose(1, 2).contiguous().view(seq_len, -1) ) else: - conv_states, ssm_states, *rest = self.pool.get_mamba_params(layer_id) + conv_states, ssm_states = self.pool.get_mamba_params(layer_id) extend_prefix_lens = kwargs.get("extend_prefix_lens") if extend_prefix_lens is None: extend_prefix_lens = self.forward_metadata.extend_prefix_lens @@ -812,8 +978,7 @@ def forward_extend( softplus_threshold=20.0, # target_verify specific parameters disable_state_update=True, - intermediate_states_buffer=intermediate_state_cache, - cache_steps=draft_token_num, + output_state_indices=self.forward_metadata.mamba_output_indices, ) else: beta = b.sigmoid() @@ -1051,30 +1216,27 @@ def forward_extend( q, k, v, layer, out_cache_loc, token_to_kv_pool, bs, **kwargs ) + def 
reset_current_inputs(self, *args, **kwargs): + if self.linear_attn_backend is None: + return + if hasattr(self.linear_attn_backend, "reset_current_inputs"): + self.linear_attn_backend.reset_current_inputs(*args, **kwargs) + def update_mamba_state_after_mtp_verify(self, accepted_length, model): - request_number = accepted_length.shape[0] - state_indices_tensor = ( - self.linear_attn_backend.forward_metadata.mamba_cache_indices[ - :request_number - ] + # mamba_cache_indices are input rows during target-verify. The first + # output row is always the scheduler-owned working slot, so use the + # output index table to update the next-round input pointer. + output_indices = self.linear_attn_backend.forward_metadata.mamba_output_indices + if output_indices is None: + return + req_pool_indices = ( + self.linear_attn_backend.forward_metadata.mamba_req_pool_indices ) - mamba_caches = self.linear_attn_backend.pool.get_mamba_params_all_layers() - ( - conv_states, - ssm_states, - intermediate_state_cache, - intermediate_conv_window_cache, - ) = mamba_caches - - fused_mamba_state_scatter_with_mask( - ssm_states, - intermediate_state_cache, - state_indices_tensor, - accepted_length, - ) - fused_mamba_state_scatter_with_mask( - conv_states, - intermediate_conv_window_cache, - state_indices_tensor, + if req_pool_indices is None: + return + request_number = accepted_length.shape[0] + self.linear_attn_backend.pool.update_current_inputs_after_verify( + req_pool_indices[:request_number], + output_indices[:request_number], accepted_length, ) diff --git a/python/tokenspeed/runtime/layers/attention/linear/causal_conv1d.py b/python/tokenspeed/runtime/layers/attention/linear/causal_conv1d.py index 912f2e54a..d7fd151e0 100755 --- a/python/tokenspeed/runtime/layers/attention/linear/causal_conv1d.py +++ b/python/tokenspeed/runtime/layers/attention/linear/causal_conv1d.py @@ -665,6 +665,7 @@ def _causal_conv1d_update_kernel( conv_state_indices_ptr, num_accepted_tokens_ptr, 
intermediate_conv_window_ptr, + output_state_indices_ptr, o_ptr, # (batch, dim, seqlen) # Matrix dimensions batch: int, @@ -686,6 +687,8 @@ def _causal_conv1d_update_kernel( stride_inter_step: tl.constexpr, stride_inter_dim: tl.constexpr, stride_inter_win: tl.constexpr, + stride_output_state_indices_seq: tl.constexpr, + stride_output_state_indices_step: tl.constexpr, stride_o_seq: tl.constexpr, stride_o_dim: tl.constexpr, stride_o_token: tl.constexpr, @@ -701,6 +704,7 @@ def _causal_conv1d_update_kernel( USE_PAD_SLOT: tl.constexpr, BLOCK_N: tl.constexpr, SAVE_INTERMEDIATE: tl.constexpr, + HAS_OUTPUT_STATE_INDICES: tl.constexpr, ): # ruff: noqa: E501 idx_seq = tl.program_id(0) @@ -763,7 +767,7 @@ def _causal_conv1d_update_kernel( x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim) # [BLOCK_N] - if not SAVE_INTERMEDIATE: + if not SAVE_INTERMEDIATE and not HAS_OUTPUT_STATE_INDICES: # STEP 2: update conv_state in place. Speculative verify uses # SAVE_INTERMEDIATE and scatters the accepted intermediate window after # verification, so writing the real conv_state here is both wrong and @@ -919,6 +923,24 @@ def _causal_conv1d_update_kernel( tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w) if KERNEL_WIDTH >= 4: tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w) + if HAS_OUTPUT_STATE_INDICES: + output_state_idx = tl.load( + output_state_indices_ptr + + idx_seq * stride_output_state_indices_seq + + idx_token * stride_output_state_indices_step + ).to(tl.int64) + if output_state_idx >= 0: + output_base = ( + conv_state_ptr + + output_state_idx * stride_conv_state_seq + + idx_feats * stride_conv_state_dim + ) + if KERNEL_WIDTH >= 2: + tl.store(output_base + 0 * stride_conv_state_tok, col0, mask=mask_w) + if KERNEL_WIDTH >= 3: + tl.store(output_base + 1 * stride_conv_state_tok, col1, mask=mask_w) + if KERNEL_WIDTH >= 4: + tl.store(output_base + 2 * stride_conv_state_tok, col2, mask=mask_w) def causal_conv1d_update( @@ -931,6 +953,7 @@ def 
causal_conv1d_update( conv_state_indices: torch.Tensor | None = None, num_accepted_tokens: torch.Tensor | None = None, intermediate_conv_window: torch.Tensor | None = None, + output_state_indices: torch.Tensor | None = None, pad_slot_id: int = PAD_SLOT_ID, metadata=None, validate_data=False, @@ -1003,7 +1026,10 @@ def causal_conv1d_update( stride_state_indices = ( conv_state_indices.stride(0) if conv_state_indices is not None else 0 ) - state_len = width - 1 + (seqlen - 1) # effective state_len needed + if output_state_indices is not None or intermediate_conv_window is not None: + state_len = width - 1 + else: + state_len = width - 1 + (seqlen - 1) # effective state_len needed np2_statelen = triton.next_power_of_2(state_len) def grid(META): @@ -1022,6 +1048,13 @@ def grid(META): ) else: stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0 + if output_state_indices is not None: + stride_output_state_indices_seq, stride_output_state_indices_step = ( + output_state_indices.stride(0), + output_state_indices.stride(1), + ) + else: + stride_output_state_indices_seq = stride_output_state_indices_step = 0 _causal_conv1d_update_kernel[grid]( # Pointers to matrices @@ -1033,6 +1066,7 @@ def grid(META): conv_state_indices, num_accepted_tokens, intermediate_conv_window if intermediate_conv_window is not None else x, + output_state_indices if output_state_indices is not None else x, out, # Matrix dimensions batch, @@ -1054,6 +1088,8 @@ def grid(META): stride_inter_step, stride_inter_dim, stride_inter_win, + stride_output_state_indices_seq, + stride_output_state_indices_step, stride_o_seq, stride_o_dim, stride_o_token, @@ -1069,6 +1105,7 @@ def grid(META): USE_PAD_SLOT=pad_slot_id is not None, BLOCK_N=256, SAVE_INTERMEDIATE=intermediate_conv_window is not None, + HAS_OUTPUT_STATE_INDICES=output_state_indices is not None, ) if unsqueeze: out = out.squeeze(-1) diff --git a/python/tokenspeed/runtime/layers/attention/linear/fused_sigmoid_gating_recurrent.py 
b/python/tokenspeed/runtime/layers/attention/linear/fused_sigmoid_gating_recurrent.py index 1f410e0f8..544ea1114 100755 --- a/python/tokenspeed/runtime/layers/attention/linear/fused_sigmoid_gating_recurrent.py +++ b/python/tokenspeed/runtime/layers/attention/linear/fused_sigmoid_gating_recurrent.py @@ -53,6 +53,7 @@ def fused_sigmoid_gating_delta_rule_update_kernel( # Parameters for target_verify support (unused for decode) intermediate_states_buffer, cache_steps, + output_state_indices, retrieve_parent_token_ptr, stride_retrieve_parent_token_seq: tl.constexpr, stride_retrieve_parent_token_token: tl.constexpr, @@ -77,6 +78,7 @@ def fused_sigmoid_gating_delta_rule_update_kernel( # Optional flags for target_verify support (default False for decode) DISABLE_STATE_UPDATE: tl.constexpr = False, CACHE_INTERMEDIATE_STATES: tl.constexpr = False, + HAS_OUTPUT_STATE_INDICES: tl.constexpr = False, HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr = False, ): """ @@ -185,7 +187,18 @@ def fused_sigmoid_gating_delta_rule_update_kernel( tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) # Cache intermediate states if enabled - if CACHE_INTERMEDIATE_STATES: + if HAS_OUTPUT_STATE_INDICES: + out_idx = tl.load(output_state_indices + i_n * T + step_idx).to(tl.int64) + if out_idx >= 0: + output_ptr = ( + h0_source + + out_idx * HV * K * V + + i_hv * K * V + + o_k[:, None] * V + + o_v[None, :] + ) + tl.store(output_ptr, b_h.to(output_ptr.dtype.element_ty), mask=mask_h) + elif CACHE_INTERMEDIATE_STATES: if cache_idx >= 0: step_offset = step_idx * HV * K * V cache_ptr = ( @@ -242,6 +255,7 @@ def fused_sigmoid_gating_delta_rule_update( disable_state_update: bool = False, intermediate_states_buffer: torch.Tensor | None = None, cache_steps: int | None = None, + output_state_indices: torch.Tensor | None = None, retrieve_parent_token: torch.Tensor | None = None, ): """ @@ -302,6 +316,7 @@ def fused_sigmoid_gating_delta_rule_update( cu_seqlens=cu_seqlens, 
intermediate_states_buffer=intermediate_states_buffer, cache_steps=0 if cache_steps is None else cache_steps, + output_state_indices=output_state_indices, retrieve_parent_token_ptr=retrieve_parent_token, stride_retrieve_parent_token_seq=stride_retrieve_parent_token_seq, stride_retrieve_parent_token_token=stride_retrieve_parent_token_token, @@ -324,6 +339,7 @@ def fused_sigmoid_gating_delta_rule_update( IS_VARLEN=cu_seqlens is not None, DISABLE_STATE_UPDATE=disable_state_update, CACHE_INTERMEDIATE_STATES=intermediate_states_buffer is not None, + HAS_OUTPUT_STATE_INDICES=output_state_indices is not None, HAS_EAGLE_TREE_CUSTOM_ATTN_MASK=retrieve_parent_token is not None, num_warps=num_warps, num_stages=num_stages, diff --git a/python/tokenspeed/runtime/layers/attention/linear/mamba_state_scatter_triton.py b/python/tokenspeed/runtime/layers/attention/linear/mamba_state_scatter_triton.py index 95d036d63..702171aeb 100644 --- a/python/tokenspeed/runtime/layers/attention/linear/mamba_state_scatter_triton.py +++ b/python/tokenspeed/runtime/layers/attention/linear/mamba_state_scatter_triton.py @@ -132,11 +132,12 @@ def _mamba_state_snapshot_kernel( In-place copy kernel: pool[:, dst[i], :] = pool[:, src[i], :] Skips copy if page_size > 0 and cache_lengths[i] % page_size != 0. - Grid: (num_valid, num_layers, ceil(elem_per_entry / BLOCK_SIZE)) + Grid: (num_valid, num_layers) — loops over elem_per_entry internally. + Invalid entries early-return wasting only 1 block instead of + ceil(elem_per_entry / BLOCK_SIZE) blocks. 
""" pid_req = tl.program_id(0) pid_layer = tl.program_id(1).to(tl.int64) - pid_block = tl.program_id(2).to(tl.int64) src_idx = tl.load(src_indices_ptr + pid_req).to(tl.int64) dst_idx = tl.load(dst_indices_ptr + pid_req).to(tl.int64) @@ -160,12 +161,11 @@ def _mamba_state_snapshot_kernel( src_offset = pid_layer * layer_stride + src_idx * req_stride dst_offset = pid_layer * layer_stride + dst_idx * req_stride - start = pid_block * BLOCK_SIZE - offsets = start + tl.arange(0, BLOCK_SIZE) - mask = offsets < elem_per_entry - - data = tl.load(pool_ptr + src_offset + offsets, mask=mask) - tl.store(pool_ptr + dst_offset + offsets, data, mask=mask) + for start in tl.static_range(0, elem_per_entry, BLOCK_SIZE): + offsets = start + tl.arange(0, BLOCK_SIZE) + mask = offsets < elem_per_entry + data = tl.load(pool_ptr + src_offset + offsets, mask=mask) + tl.store(pool_ptr + dst_offset + offsets, data, mask=mask) def fused_mamba_state_snapshot( @@ -224,8 +224,8 @@ def fused_mamba_state_snapshot( cache_lengths = src_indices # unused; kernel skips when page_size==0 page_size = 0 - BLOCK_SIZE = 1024 - grid = (num_valid, num_layers, triton.cdiv(elem_per_entry, BLOCK_SIZE)) + BLOCK_SIZE = 8192 + grid = (num_valid, num_layers) _mamba_state_snapshot_kernel[grid]( pool, @@ -238,6 +238,7 @@ def fused_mamba_state_snapshot( req_stride, pool_size, BLOCK_SIZE=BLOCK_SIZE, + num_warps=8, ) diff --git a/python/tokenspeed/runtime/layers/attention/registry.py b/python/tokenspeed/runtime/layers/attention/registry.py index 294789b1e..bef2d434d 100644 --- a/python/tokenspeed/runtime/layers/attention/registry.py +++ b/python/tokenspeed/runtime/layers/attention/registry.py @@ -285,6 +285,12 @@ def _create_hybrid_linear_attn( if server_args.speculative_algorithm is not None else 0 ), + max_req_pool_size=( + server_args.max_num_seqs + // max( + server_args.data_parallel_size or server_args.mapping.attn.dp_size, 1 + ) + ), ) linear_attn_backend.set_pool(mamba_pool)