Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions python/tokenspeed/runtime/engine/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,16 @@ def _setup_layerwise_loadback(self, execution_plan) -> None:
if host_exec is not None:
self.model_executor.execution_stream.wait_stream(host_exec.write_stream)

def _flush_mamba_retract_states(self, forward_op) -> None:
"""Copy draft->working mamba states when retract occurred (no forward scheduled)."""
if forward_op is not None:
return
if self.model_executor.drafter is None:
return
if self.model_executor.runtime_states.mamba_pool is None:
return
self.model_executor.flush_mamba_draft_to_working_on_retract()
Comment on lines +643 to +651
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Gate retract-state flush on actual retract ops

_flush_mamba_retract_states runs whenever forward_op is None, but it never verifies that the current execution plan actually contains a retraction. In drafter+Mamba mode this means idle/no-forward iterations can repeatedly call flush_mamba_draft_to_working_on_retract() using stale previous-batch buffers, performing unintended state copies unrelated to any retract and potentially racing with cache maintenance on the same slots. Please gate this path on a real retract signal (for example, retraction-specific cache ops or a scheduler flag) instead of forward_op is None alone.

Useful? React with 👍 / 👎.


# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
Expand Down Expand Up @@ -930,6 +940,7 @@ def event_loop(self):
self._submit_cache_ops(execution_plan)

forward_op = self._get_forward_op(execution_plan)
self._flush_mamba_retract_states(forward_op)

stats = self._get_scheduler_stats()
num_iter_tokens = (
Expand Down Expand Up @@ -1044,6 +1055,7 @@ def event_loop_overlap(self):
self._submit_cache_ops(execution_plan)

forward_op = self._get_forward_op(execution_plan)
self._flush_mamba_retract_states(forward_op)

stats = self._get_scheduler_stats()
num_iter_tokens = (
Expand Down Expand Up @@ -1112,12 +1124,6 @@ def event_loop_overlap(self):

curr_results = None
if forward_op is not None:
if forward_op.num_extends() <= 0:
# Overlap dispatch may schedule one extra decode before
# the previous result is committed. Snapshot the completed
# working state before this decode mutates the same slot;
# the snapshot helper only copies block-aligned states.
self.model_executor.snapshot_mamba_checkpoints_for_op(forward_op)
curr_results, _ = self._dispatch_forward(
forward_op,
sampling_params_list,
Expand Down
193 changes: 154 additions & 39 deletions python/tokenspeed/runtime/execution/model_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ def __init__(
self.execution_stream = torch.cuda.Stream()
self.log_step = 0
self._seen_prefill_ids: set[str] = set()
self._prev_decode_bs: int = 0
self._sentinel_neg1 = torch.tensor(-1, device=self.device, dtype=torch.int64)
# Decode stats — accumulated from synced results (no GPU sync needed)
self.num_generated_tokens = 0
self.num_decode_steps = 0
Expand Down Expand Up @@ -467,48 +469,141 @@ def accumulate_decode_stats(self, results: ModelExecutionResult, bs: int):
self.num_generated_tokens += int(results.output_lengths.sum().item())
self.num_decode_steps += bs

def snapshot_mamba_checkpoints_for_op(self, forward_op) -> None:
"""Snapshot completed decode working states into checkpoint slots."""
if self.runtime_states.mamba_pool is None or forward_op.num_extends() > 0:
def _snapshot_mamba_checkpoints(
self,
accept_lengths: torch.Tensor,
bs: int,
num_extends: int,
) -> None:
"""Snapshot mamba states to checkpoint slots at page boundaries.

Called after ``_update_runtime_state`` on the execution stream so
``valid_cache_lengths`` already reflects the accepted tokens.

Non-MTP (accept_length == 1):
The working slot holds the up-to-date state for the new
cache_length. Pass the kernel page_size so it copies only
when the new length is page-aligned.

MTP (accept_length > 1):
cache_length may jump over a page boundary. The intermediate
state lives in ``mamba_output_indices[req, step]``. Boundary
detection and source-slot selection are done entirely on GPU
with -1 sentinels so the snapshot kernel skips invalid entries
via its bounds check — no GPU-to-CPU sync, preserving
overlap-schedule pipelining.
"""
if self.runtime_states.mamba_pool is None or num_extends > 0:
return
if not getattr(forward_op, "mamba_pool_indices", None):
if not self.input_buffers.has_mamba:
return

# CPU-side pre-filter
src_list = []
dst_list = []
req_list = []
for i in range(len(forward_op.request_ids)):
pool_idx = forward_op.mamba_pool_indices[i]
ckpt_idx = forward_op.mamba_track_pool_indices[i]
if pool_idx != -1 and ckpt_idx != -1:
src_list.append(pool_idx)
dst_list.append(ckpt_idx)
req_list.append(forward_op.request_pool_indices[i])

num_valid = len(src_list)
if num_valid == 0:
req_pool_indices = self.input_buffers.req_pool_indices_buf[:bs]
track_indices = self.input_buffers.mamba_track_pool_indices_buf[:bs]
page_size = self.config.block_size
dev = req_pool_indices.device
sentinel = self._sentinel_neg1

if self.drafter is not None:
# -- MTP path: find the output slot at the crossed boundary --
backend = getattr(
self.attn_backend, "linear_attn_backend", self.attn_backend
)
fm = getattr(backend, "forward_metadata", None)
if fm is None:
return
output_indices = fm.mamba_output_indices
if output_indices is None:
return

new_cl = self.runtime_states.valid_cache_lengths[req_pool_indices]
old_cl = new_cl - accept_lengths[:bs].to(device=dev, dtype=new_cl.dtype)
first_boundary = ((old_cl // page_size) + 1) * page_size

step_raw = first_boundary - old_cl - 1
max_col = output_indices.shape[1] - 1
step = step_raw.clamp(min=0, max=max_col).to(torch.int64)

req_range = torch.arange(bs, device=dev)
src_raw = output_indices[req_range, step].to(torch.int64)
dst_raw = track_indices.to(device=dev, dtype=torch.int64)

invalid = (
(first_boundary > new_cl)
| (dst_raw < 0)
| (src_raw < 0)
| (src_raw == dst_raw)
| (step_raw < 0)
)
src = torch.where(invalid, sentinel, src_raw)
dst = torch.where(invalid, sentinel, dst_raw)

self.runtime_states.snapshot_mamba_checkpoints(
src,
dst,
cache_lengths=None,
page_size=0,
num_valid=bs,
)
else:
# -- Non-MTP path: working slot IS the up-to-date state --
src_raw = self.input_buffers.mamba_pool_indices_buf[:bs].to(
device=dev, dtype=torch.int64
)
dst_raw = track_indices.to(device=dev, dtype=torch.int64)

invalid = (src_raw < 0) | (dst_raw < 0) | (src_raw == dst_raw)
src = torch.where(invalid, sentinel, src_raw)
dst = torch.where(invalid, sentinel, dst_raw)

cache_lengths = self.runtime_states.valid_cache_lengths[req_pool_indices]
self.runtime_states.snapshot_mamba_checkpoints(
src,
dst,
cache_lengths=cache_lengths,
page_size=page_size,
num_valid=bs,
)

def flush_mamba_draft_to_working_on_retract(self) -> None:
"""Copy accepted draft mamba state -> working slot for all previous-batch requests.

Called from event_loop when retract WriteBackOps are detected.
Uses the previous decode iteration's input_buffers (still valid since
no new forward has overwritten them).
Runs on execution_stream to respect ordering with previous forward writes.
"""
bs = self._prev_decode_bs
if bs <= 0:
return

t_src = torch.tensor(src_list, dtype=torch.int64, device="cpu", pin_memory=True)
t_dst = torch.tensor(dst_list, dtype=torch.int64, device="cpu", pin_memory=True)
t_req = torch.tensor(req_list, dtype=torch.int64, device="cpu", pin_memory=True)

src_buf = torch.empty(num_valid, dtype=torch.int64, device=self.device)
dst_buf = torch.empty(num_valid, dtype=torch.int64, device=self.device)
req_buf = torch.empty(num_valid, dtype=torch.int64, device=self.device)
src_buf.copy_(t_src, non_blocking=True)
dst_buf.copy_(t_dst, non_blocking=True)
req_buf.copy_(t_req, non_blocking=True)

cache_lengths = self.runtime_states.valid_cache_lengths[req_buf]
self.runtime_states.snapshot_mamba_checkpoints(
src_buf,
dst_buf,
cache_lengths,
self.config.block_size,
num_valid,
)
backend = getattr(self.attn_backend, "linear_attn_backend", self.attn_backend)
pool = getattr(backend, "pool", None)
if pool is None:
return

sentinel = self._sentinel_neg1

with torch.cuda.stream(self.execution_stream):
req_pool_indices = self.input_buffers.req_pool_indices_buf[:bs]
working = self.input_buffers.mamba_pool_indices_buf[:bs]

src_raw = pool.current_input_indices[req_pool_indices.clamp(0).long()].to(
dtype=torch.int64
)
dst_raw = working.to(dtype=torch.int64)

invalid = (src_raw < 0) | (dst_raw < 0) | (src_raw == dst_raw)
src = torch.where(invalid, sentinel, src_raw)
dst = torch.where(invalid, sentinel, dst_raw)

self.runtime_states.snapshot_mamba_checkpoints(
src,
dst,
cache_lengths=None,
page_size=0,
num_valid=bs,
)

def execute_forward_op_with_log(
self,
Expand Down Expand Up @@ -822,6 +917,9 @@ def execute_forward_op(
bs = len(forward_op.request_ids)
forward_mode = ForwardMode.from_num_extends(num_extends, bs)

if num_extends <= 0:
self._prev_decode_bs = bs

if self.runtime_states.mamba_pool is not None and (
num_extends > 0 or has_retract
):
Expand All @@ -834,9 +932,21 @@ def execute_forward_op(
self.runtime_states.zero_mamba_states(
mamba_pool_indices,
mamba_cow_src,
self.input_buffers.extend_prefix_lens_buf[:bs],
bs,
self.input_buffers.extend_prefix_lens_buf[:num_extends],
num_extends,
)
if hasattr(self.attn_backend, "reset_current_inputs"):
self.attn_backend.reset_current_inputs(
self.input_buffers.req_pool_indices_buf[:num_extends],
mamba_pool_indices[:num_extends],
)
elif has_retract:
if hasattr(self.attn_backend, "reset_current_inputs"):
retract_mask = mamba_cow_src[:bs] >= 0
self.attn_backend.reset_current_inputs(
self.input_buffers.req_pool_indices_buf[:bs][retract_mask],
mamba_pool_indices[:bs][retract_mask],
)

grammar_completion = None

Expand Down Expand Up @@ -967,6 +1077,11 @@ def execute_forward_op(
input_lengths=self.input_buffers.input_lengths_buf[:bs],
is_extend=num_extends > 0,
)
self._snapshot_mamba_checkpoints(
output_lengths,
bs,
num_extends,
)

with nvtx_range("output_d2h", color="green"):
output_tokens = output_tokens.to("cpu", non_blocking=True)
Expand Down
Loading
Loading