diff --git a/docs/serving/deepseek-v4.md b/docs/serving/deepseek-v4.md index 68be87341..b984a50a8 100644 --- a/docs/serving/deepseek-v4.md +++ b/docs/serving/deepseek-v4.md @@ -6,7 +6,7 @@ DP=4 + expert parallel + mega_moe + FP8 KV cache (B200, 4× SM100): ```bash CUDA_VISIBLE_DEVICES=0,1,2,3 tokenspeed serve deepseek-ai/DeepSeek-V4-Flash \ - --host localhost --port 30100 \ + --host localhost --port 8000 \ --dist-init-addr 127.0.0.1:4013 \ --trust-remote-code \ --data-parallel-size 4 \ @@ -68,7 +68,7 @@ GSM8K 5-shot, 50 samples is the standard quick-validation harness for V4: ```bash HF_DATASETS_TRUST_REMOTE_CODE=1 lm_eval run \ --model local-completions \ - --model_args "model=deepseek-ai/DeepSeek-V4-Flash,base_url=http://127.0.0.1:30100/v1/completions,tokenized_requests=False,tokenizer_backend=None,num_concurrent=4,max_retries=1,timeout=600,max_gen_toks=256" \ + --model_args "model=deepseek-ai/DeepSeek-V4-Flash,base_url=http://127.0.0.1:8000/v1/completions,tokenized_requests=False,tokenizer_backend=None,num_concurrent=4,max_retries=1,timeout=600,max_gen_toks=256" \ --tasks gsm8k --num_fewshot 5 --limit 50 --batch_size 1 \ --gen_kwargs temperature=0 ``` diff --git a/python/tokenspeed/runtime/engine/schedule_batch.py b/python/tokenspeed/runtime/engine/schedule_batch.py index ae96fd961..6aa384d30 100755 --- a/python/tokenspeed/runtime/engine/schedule_batch.py +++ b/python/tokenspeed/runtime/engine/schedule_batch.py @@ -269,10 +269,12 @@ def prealloc_for_draft_decode(self, is_disaggregation_decode: bool = False): return out_cache_loc = torch.concat(out_cache_loc_list) out_cache_loc = out_cache_loc.to(self.device, non_blocking=True) - req_indices = torch.tensor(req_indices, dtype=torch.int32).to( + req_indices = torch.tensor(req_indices, dtype=torch.int64).to( self.device, non_blocking=True ) - start_offsets = self.req_to_token_pool.alloced_lens[req_indices] + start_offsets = torch.index_select( + self.req_to_token_pool.alloced_lens, 0, req_indices + ) end_offsets = start_offsets + num_tokens_pre_alloc assign_req_to_token_pool[(bs,)]( req_indices, diff --git a/python/tokenspeed/runtime/execution/cache_loc_kernel.py b/python/tokenspeed/runtime/execution/cache_loc_kernel.py index acc6ef74f..14552ed26 100644 --- a/python/tokenspeed/runtime/execution/cache_loc_kernel.py +++ b/python/tokenspeed/runtime/execution/cache_loc_kernel.py @@ -128,13 +128,14 @@ def update_req_to_page( def compute_out_cache_loc_kernel( # Input pointers req_pool_indices_ptr, # [batch_size] - input_lengths_ptr, # [batch_size] - valid_cache_lengths_ptr, # [req_pool_size+1] + input_lengths_ptr, # [batch_size] or None for uniform mode + cache_start_ptr, # [batch_size] req_to_pages_ptr, # [req_pool_size+1, max_pages] - cumsum_lengths_ptr, # [batch_size] - cumulative sum of input_lengths for output offset + cumsum_lengths_ptr, # [batch_size] or None for uniform mode # Output pointer out_cache_loc_ptr, # [total_tokens] # Scalars + uniform_input_length, # used when input_lengths_ptr is None page_size: tl.constexpr, max_pages: tl.constexpr, BLOCK_SIZE: tl.constexpr, @@ -143,7 +144,7 @@ def compute_out_cache_loc_kernel( Unified kernel to compute out_cache_loc for both prefill and decode. For each token in each request, compute: - position = valid_cache_length + token_offset_in_seq + position = cache_start[req_idx] + token_offset_in_seq page_idx = position // page_size offset_in_page = position % page_size page_id = req_to_pages[req_pool_idx, page_idx] @@ -151,30 +152,39 @@ def compute_out_cache_loc_kernel( For decode, input_lengths are all 1. For prefill, input_lengths vary. + + When all requests share the same input_length (the multi-step drafter + case), callers pass ``input_lengths_ptr=None`` (and ``cumsum_lengths_ptr=None``) + together with ``uniform_input_length`` set to the shared length. Triton + specializes the kernel on the None-ness of the pointers at JIT time and + dead-code-eliminates the corresponding GMEM reads. """ # Program ID represents which request we're processing req_idx = tl.program_id(0) - # Load request metadata + # Load request metadata. req_pool_idx = tl.load(req_pool_indices_ptr + req_idx) - seq_len = tl.load(input_lengths_ptr + req_idx) - valid_cache_len = tl.load(valid_cache_lengths_ptr + req_pool_idx) - - # Get output offset from cumsum - # Always load from cumsum, use 0 index for first request to ensure type consistency - offset_idx = tl.where(req_idx > 0, req_idx - 1, 0) - output_offset = tl.load(cumsum_lengths_ptr + offset_idx) - # Zero out offset for first request - output_offset = tl.where(req_idx > 0, output_offset, 0) + valid_cache_len = tl.load(cache_start_ptr + req_idx) + + if input_lengths_ptr is not None: + input_length = tl.load(input_lengths_ptr + req_idx) + # Always load from cumsum, use 0 index for first request to ensure type consistency + offset_idx = tl.where(req_idx > 0, req_idx - 1, 0) + output_offset = tl.load(cumsum_lengths_ptr + offset_idx) + # Zero out offset for first request + output_offset = tl.where(req_idx > 0, output_offset, 0) + else: + input_length = uniform_input_length + output_offset = req_idx * uniform_input_length # Process tokens in blocks - num_blocks = tl.cdiv(seq_len, BLOCK_SIZE) + num_blocks = tl.cdiv(input_length, BLOCK_SIZE) for block_idx in range(num_blocks): block_start = block_idx * BLOCK_SIZE # Compute token offsets within this block token_offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = token_offsets < seq_len + mask = token_offsets < input_length # Compute logical positions positions = valid_cache_len + token_offsets @@ -200,10 +210,10 @@ def compute_out_cache_loc( out_cache_loc_ptr, req_pool_indices: torch.Tensor, # [batch_size] input_lengths: torch.Tensor, # [batch_size] - valid_cache_lengths: torch.Tensor, # [req_pool_size+1] + cache_start: torch.Tensor, # [batch_size] req_to_pages: torch.Tensor, # [req_pool_size+1, max_pages] page_size: int, -) -> torch.Tensor: +) -> None: batch_size = req_pool_indices.shape[0] max_pages = req_to_pages.shape[1] @@ -215,10 +225,45 @@ def compute_out_cache_loc( compute_out_cache_loc_kernel[grid]( req_pool_indices, input_lengths, - valid_cache_lengths, + cache_start, req_to_pages, cumsum_lengths, out_cache_loc_ptr, + 0, # uniform_input_length unused when input_lengths_ptr is not None + page_size=page_size, + max_pages=max_pages, + BLOCK_SIZE=BLOCK_SIZE, + ) + + +def compute_out_cache_loc_uniform( + out_cache_loc_ptr, + req_pool_indices: torch.Tensor, # [batch_size] + uniform_input_length: int, + cache_start: torch.Tensor, # [batch_size] + req_to_pages: torch.Tensor, # [req_pool_size+1, max_pages] + page_size: int, +) -> None: + """Specialized entry point when every request has the same ``input_length``. + + Skips the per-call ``torch.full`` + ``cumsum`` host-side work and the + corresponding GMEM reads inside the kernel. Used by the multi-step drafter + where each request decodes exactly ``spec_num_steps - 1`` tokens. + """ + batch_size = req_pool_indices.shape[0] + max_pages = req_to_pages.shape[1] + + BLOCK_SIZE = 128 + grid = (batch_size,) + + compute_out_cache_loc_kernel[grid]( + req_pool_indices, + None, # input_lengths_ptr is None → kernel uses uniform_input_length + cache_start, + req_to_pages, + None, # cumsum_lengths_ptr is None → kernel computes offset analytically + out_cache_loc_ptr, + uniform_input_length, page_size=page_size, max_pages=max_pages, BLOCK_SIZE=BLOCK_SIZE, @@ -261,7 +306,7 @@ def flatten_and_to_device(data, dtype=torch.int32): forward_op.new_occupied_pages, dtype=torch.int32 ) request_pool_indices = flatten_and_to_device( - forward_op.request_pool_indices, dtype=torch.int32 + forward_op.request_pool_indices, dtype=torch.int64 ) update_req_to_page( req_to_page=req_to_page, diff --git a/python/tokenspeed/runtime/execution/context.py b/python/tokenspeed/runtime/execution/context.py index baf68d401..e4fe3f922 100644 --- a/python/tokenspeed/runtime/execution/context.py +++ b/python/tokenspeed/runtime/execution/context.py @@ -59,3 +59,4 @@ class ForwardContext: # --- logits processor --- keep_full_logits: bool = False + last_index_offsets: torch.Tensor | None = None diff --git a/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py b/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py index 4402c79eb..e637607ae 100644 --- a/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py +++ b/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py @@ -246,13 +246,13 @@ def __init__( try: draft_attn_backend.init_cuda_graph_state( self.max_bs, - self.drafter.draft_seq_lens, + self.drafter.draft_seq_lens_buf, paged_cache_group_specs=draft_paged_cache_group_specs, max_tokens_per_req=self.max_tokens_per_req, ) except TypeError: draft_attn_backend.init_cuda_graph_state( - self.max_bs, self.drafter.draft_seq_lens + self.max_bs, self.drafter.draft_seq_lens_buf ) self.graphs: dict[int, torch.cuda.CUDAGraph] = {} diff --git a/python/tokenspeed/runtime/execution/drafter/eagle.py b/python/tokenspeed/runtime/execution/drafter/eagle.py index b8fc18b09..110fb0afc 100644 --- a/python/tokenspeed/runtime/execution/drafter/eagle.py +++ b/python/tokenspeed/runtime/execution/drafter/eagle.py @@ -26,7 +26,9 @@ import torch from typing_extensions import override -from tokenspeed.runtime.execution.cache_loc_kernel import compute_out_cache_loc +from tokenspeed.runtime.execution.cache_loc_kernel import ( + compute_out_cache_loc_uniform, +) from tokenspeed.runtime.execution.context import ForwardContext from tokenspeed.runtime.execution.drafter.base import BaseDrafter from tokenspeed.runtime.execution.forward_batch_info import ( @@ -103,14 +105,34 @@ def __init__( self.dp_size = draft_model_runner.mapping.attn.dp_size self.world_size = draft_model_runner.mapping.world_size - # Pool-indexed scratch for compute_out_cache_loc. - self.draft_seq_lens_pool = torch.zeros_like( - self.runtime_states.valid_cache_lengths - ) - # Drafter-owned alias source for the draft attn backend; advanced in # place during multi-step decode. - self.draft_seq_lens = torch.zeros_like(self.input_buffers.seq_lens_buf) + self.draft_seq_lens_buf = torch.zeros_like(self.input_buffers.seq_lens_buf) + + # Persistent output buffer for the draft step's compute_out_cache_loc. + self.draft_out_cache_loc_buf = torch.empty( + (self.input_buffers.max_bs * (spec_num_steps - 1),), + dtype=torch.int32, + device=self.device, + ) + + # Per-request input length is always 1 in multi-step decode (one token per request). + self.draft_input_lengths_buf = torch.ones( + (self.input_buffers.max_bs,), + dtype=torch.int32, + device=self.device, + ) + + # Precomputed `arange(max_bs) * spec_num_tokens - 1`, sliced and passed + # via ForwardContext for the padded-static-len last-token selection in + # LogitsProcessor. + self.last_index_offsets_buf = ( + torch.arange( + self.input_buffers.max_bs, dtype=torch.int64, device=self.device + ) + * spec_num_tokens + - 1 + ) # ------------------------------------------------------------------ # Internal helpers @@ -120,31 +142,6 @@ def _map_hot(self, ids: torch.Tensor) -> torch.Tensor: """Map token ids through hot_token_ids if available, otherwise return as-is.""" return self.hot_token_ids[ids] if self.hot_token_ids is not None else ids - def _compute_draft_cache_locs( - self, - bs: int, - req_pool_indices: torch.Tensor, - cache_start: torch.Tensor, - ) -> torch.Tensor: - """Write slots for steps 1..N-1; shape (bs, spec_num_steps - 1).""" - out_cache_locs = torch.empty( - (bs * (self.spec_num_steps - 1),), dtype=torch.int32, device=self.device - ) - # Scatter cache_start into the pool-indexed buffer. - self.draft_seq_lens_pool.zero_() - self.draft_seq_lens_pool[req_pool_indices] = cache_start - compute_out_cache_loc( - out_cache_loc_ptr=out_cache_locs, - req_pool_indices=req_pool_indices, - input_lengths=torch.full( - (bs,), self.spec_num_steps - 1, device=self.device - ), - valid_cache_lengths=self.draft_seq_lens_pool, - req_to_pages=self.req_to_page, - page_size=self.page_size, - ) - return out_cache_locs.view(bs, self.spec_num_steps - 1) - def _get_first_step_input( self, forward_mode: ForwardMode, @@ -209,6 +206,9 @@ def _run_first_step( forward_mode=draft_first_mode, capture_hidden_mode=CaptureHiddenMode.LAST, padded_static_len=self.spec_num_tokens if is_decode_like else -1, + last_index_offsets=( + self.last_index_offsets_buf[:bs] if is_decode_like else None + ), keep_full_logits=False, global_num_tokens=draft_input.global_num_tokens, global_bs=draft_input.global_bs, @@ -229,7 +229,7 @@ def _run_multi_step_decode( self, bs: int, draft_ids: torch.Tensor, - draft_tokens: torch.Tensor, + next_tokens: torch.Tensor, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput, ) -> None: @@ -240,19 +240,31 @@ def _run_multi_step_decode( # vc+input_lengths (= seq_lens_buf). if draft_input.forward_mode.is_target_verify(): cache_start = ( - self.runtime_states.valid_cache_lengths[req_pool_indices] + self.runtime_states.valid_cache_lengths.index_select( + 0, req_pool_indices + ) + draft_input.accept_lengths ) else: cache_start = self.input_buffers.seq_lens_buf[:bs].clone() - cache_locs = self._compute_draft_cache_locs(bs, req_pool_indices, cache_start) + # Write cache slots for steps 1..N-1. + cache_locs = self.draft_out_cache_loc_buf[: bs * (self.spec_num_steps - 1)] + compute_out_cache_loc_uniform( + out_cache_loc_ptr=cache_locs, + req_pool_indices=req_pool_indices, + uniform_input_length=self.spec_num_steps - 1, + cache_start=cache_start, + req_to_pages=self.req_to_page, + page_size=self.page_size, + ) + cache_locs_trans = cache_locs.view(bs, self.spec_num_steps - 1).t().contiguous() # +1 is the kernel's read-inclusive convention; advanced per iter. - draft_seq_lens = self.draft_seq_lens[:bs] - draft_seq_lens.copy_(cache_start + 1) + draft_seq_lens = self.draft_seq_lens_buf[:bs] + torch.add(cache_start, 1, out=draft_seq_lens) - input_lengths = torch.ones((bs,), device=self.device, dtype=torch.int32) + input_lengths = self.draft_input_lengths_buf[:bs] positions = cache_start.clone() for i in range(1, self.spec_num_steps): @@ -284,36 +296,23 @@ def _run_multi_step_decode( all_decode_or_idle=draft_input.all_decode_or_idle, ) - out_cache_loc = cache_locs[:, i - 1].contiguous() - with nvtx_range("draft_forward", color="red"): logits_output = self.draft_model_runner.forward( ctx=ctx, input_ids=self._map_hot(draft_ids), positions=positions, - out_cache_loc=out_cache_loc, + out_cache_loc=cache_locs_trans[i - 1], input_lengths=input_lengths, captured_hidden_states=logits_output.hidden_states, ) with nvtx_range("draft_sample", color="yellow"): draft_ids = torch.argmax(logits_output.next_token_logits, dim=-1) - draft_tokens[:, i] = self._map_hot(draft_ids) - positions.add_(1) - draft_seq_lens.add_(1) - - def _get_last_verified_ids( - self, bs: int, forward_mode: ForwardMode, draft_input: EagleDraftInput - ) -> torch.Tensor: - - if forward_mode == ForwardMode.EXTEND: - # Last verified id is simply the base output for each request - return draft_input.base_model_output[:bs] - else: - # Pick the last accepted token per request from the flattened base output - req_offsets = torch.arange(bs, device=self.device) * self.spec_num_tokens - indices = req_offsets + draft_input.accept_lengths - 1 - return draft_input.base_model_output[indices] + # Column 0 holds last_verified_ids; drafter writes step `i` into column `i + 1`. + next_tokens[:, i + 1] = self._map_hot(draft_ids) + if i + 1 < self.spec_num_steps: + positions.add_(1) + draft_seq_lens.add_(1) # ------------------------------------------------------------------ # Public entry point (type-based dispatch from ModelExecutor) @@ -343,43 +342,33 @@ def draft( bs = draft_input.accept_lengths.shape[0] - draft_tokens = torch.empty( - (bs, self.spec_num_steps), + # Layout: column 0 holds the last verified id (the base model's accepted token); + # columns 1..spec_num_steps hold the drafter's speculative tokens. + next_tokens = torch.empty( + (bs, self.spec_num_steps + 1), dtype=torch.int32, device=self.device, ) - # Seed the draft attn backend's aliased seq_lens for the first step. - self.draft_seq_lens[:bs].copy_(self.input_buffers.seq_lens_buf[:bs]) - - # First draft step. - logits_output = self._run_first_step(bs, draft_input) - - # In decode mode the draft model processes spec_num_tokens tokens - # per request (padded). The logits processor returns logits for ALL - # tokens. Select only the last valid token per request. - logits = logits_output.next_token_logits - - if logits.shape[0] != bs and ( - draft_input.forward_mode.is_decode_or_idle() - or draft_input.forward_mode.is_target_verify() - or draft_input.forward_mode.is_draft_extend() - ): - # logits shape: [bs * spec_num_tokens, vocab] - # Select last token per request using accept_lengths - last_indices = ( - torch.arange(bs, device=logits.device) * self.spec_num_tokens - + draft_input.accept_lengths - - 1 + # Last verified id per request → next_tokens[:, 0]. + if draft_input.forward_mode == ForwardMode.EXTEND: + next_tokens[:, 0] = draft_input.base_model_output[:bs] + else: + indices = self.last_index_offsets_buf[:bs] + draft_input.accept_lengths + torch.index_select( + draft_input.base_model_output, 0, indices, out=next_tokens[:, 0] ) - logits_output.next_token_logits = logits[last_indices] + # Seed the draft attn backend's aliased seq_lens for the first step. + self.draft_seq_lens_buf[:bs].copy_(self.input_buffers.seq_lens_buf[:bs]) - if logits_output.hidden_states is not None: - logits_output.hidden_states = logits_output.hidden_states[last_indices] + # First draft step. LogitsProcessor prunes `[bs * spec_num_tokens, ...]` + # down to `[bs, ...]` via padded_static_len, so logits/hidden_states + # arrive here already aligned to one row per request. + logits_output = self._run_first_step(bs, draft_input) draft_ids = torch.argmax(logits_output.next_token_logits, dim=-1) - draft_tokens[:, 0] = self._map_hot(draft_ids) + next_tokens[:, 1] = self._map_hot(draft_ids) # Draft step 2+ (multi-step decode). if self.spec_num_steps > 1: @@ -393,10 +382,10 @@ def draft( skip = self.dp_size == 1 and self.input_buffers.all_extends_mid_chunk if not skip: self._run_multi_step_decode( - bs, draft_ids, draft_tokens, logits_output, draft_input + bs, draft_ids, next_tokens, logits_output, draft_input ) - return draft_tokens + return next_tokens @override @nvtx_range("drafter", color="purple") @@ -419,10 +408,5 @@ def run( all_decode_or_idle=base_ctx.all_decode_or_idle, ) - draft_tokens = self.draft(draft_input) - - last_verified_ids = self._get_last_verified_ids( - base_ctx.bs, base_ctx.forward_mode, draft_input - ) - - return torch.cat([last_verified_ids.unsqueeze(1), draft_tokens], dim=1) + # next_tokens layout: column 0 = last verified id, columns 1.. = drafter tokens. + return self.draft(draft_input) diff --git a/python/tokenspeed/runtime/execution/input_buffer.py b/python/tokenspeed/runtime/execution/input_buffer.py index 88feca724..2bae604ec 100644 --- a/python/tokenspeed/runtime/execution/input_buffer.py +++ b/python/tokenspeed/runtime/execution/input_buffer.py @@ -67,7 +67,7 @@ def __init__( self.shifted_prefill_ids_buf = torch.ones_like(self.input_ids_buf) self.input_lengths_buf = torch.ones((max_num_tokens,), dtype=torch.int32) self.positions_buf = torch.arange(0, max_num_tokens, dtype=torch.int64) - self.req_pool_indices_buf = torch.zeros((max_bs,), dtype=torch.int32) + self.req_pool_indices_buf = torch.zeros((max_bs,), dtype=torch.int64) self.seq_lens_buf = torch.ones((max_bs,), dtype=torch.int32) # Initialise to dummy_kv_slot so that padding positions (never # written by compute_out_cache_loc) always point to the reserved @@ -165,28 +165,29 @@ def fill_input_buffers( req_pool_indices_device = self.req_pool_indices_buf[:batch_size] input_lengths_device = self.input_lengths_buf[:batch_size] + valid_cache_lengths = runtime_states.valid_cache_lengths.index_select( + 0, req_pool_indices_device + ) + # Compute out_cache_loc using Triton kernel compute_out_cache_loc( out_cache_loc_ptr=self.out_cache_loc_buf[:total_tokens], req_pool_indices=req_pool_indices_device, input_lengths=input_lengths_device, - valid_cache_lengths=runtime_states.valid_cache_lengths, + cache_start=valid_cache_lengths, req_to_pages=req_to_page, page_size=self.page_size, ) - cached_prefix_lens = runtime_states.valid_cache_lengths[ - self.req_pool_indices_buf[:batch_size] - ] # Compute positions. In mixed batches, prefill rows use their extend # prefix lengths while decode rows use the current valid cache lengths. prefill_prefix_lens = self.extend_prefix_lens_buf[:num_extends] if num_extends == 0: - prefix_lens = cached_prefix_lens + prefix_lens = valid_cache_lengths elif num_extends == batch_size: prefix_lens = prefill_prefix_lens else: - prefix_lens = cached_prefix_lens.clone() + prefix_lens = valid_cache_lengths.clone() prefix_lens[:num_extends].copy_(prefill_prefix_lens) positions, _ = compute_position_triton( extend_prefix_lens=prefix_lens, @@ -269,7 +270,7 @@ def fill_input_buffers( non_blocking=True, ) - self.seq_lens_buf[:batch_size].copy_(input_lengths_device + cached_prefix_lens) + self.seq_lens_buf[:batch_size].copy_(input_lengths_device + valid_cache_lengths) # Reset positions beyond total_tokens to the dummy KV slot so that any # CUDA graph replay with a larger (padded) batch size writes padding diff --git a/python/tokenspeed/runtime/execution/runtime_states.py b/python/tokenspeed/runtime/execution/runtime_states.py index 5ed6a903c..448373fbb 100644 --- a/python/tokenspeed/runtime/execution/runtime_states.py +++ b/python/tokenspeed/runtime/execution/runtime_states.py @@ -62,9 +62,7 @@ def __init__( def update_valid_cache_length( self, req_pool_indices: torch.Tensor, increment_lengths: torch.Tensor ) -> None: - self.valid_cache_lengths.index_add_( - 0, req_pool_indices.long(), increment_lengths - ) + self.valid_cache_lengths.index_add_(0, req_pool_indices, increment_lengths) def reset_states( self, diff --git a/python/tokenspeed/runtime/layers/attention/backends/trtllm.py b/python/tokenspeed/runtime/layers/attention/backends/trtllm.py index 3bae9734c..9137bdb92 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/trtllm.py +++ b/python/tokenspeed/runtime/layers/attention/backends/trtllm.py @@ -615,7 +615,7 @@ def init_forward_metadata_replay_cuda_graph( torch.index_select( req_to_page[:, : self.max_num_pages], 0, - req_pool_indices[:bs].long(), + req_pool_indices[:bs], out=self.cuda_graph_page_table[:bs, : self.max_num_pages], ) diff --git a/python/tokenspeed/runtime/layers/logits_processor.py b/python/tokenspeed/runtime/layers/logits_processor.py index a879c63fc..86cc02d36 100755 --- a/python/tokenspeed/runtime/layers/logits_processor.py +++ b/python/tokenspeed/runtime/layers/logits_processor.py @@ -35,7 +35,6 @@ ) from tokenspeed.runtime.layers.vocab_parallel_embedding import VocabParallelEmbedding from tokenspeed.runtime.utils import get_colorful_logger -from tokenspeed.runtime.utils.env import global_server_args_dict logger = get_colorful_logger(__name__) @@ -107,6 +106,7 @@ class LogitsMetadata: # for padding padded_static_len: int = -1 + last_index_offsets: torch.Tensor | None = None @classmethod def from_forward_context( @@ -119,6 +119,7 @@ def from_forward_context( capture_hidden_mode=ctx.capture_hidden_mode, extend_seq_lens=input_lengths, padded_static_len=ctx.padded_static_len, + last_index_offsets=ctx.last_index_offsets, ) @@ -231,14 +232,8 @@ def forward( # If padding_static length is 5 and extended_seq_lens is [2, 3], # then our batch looks like [t00, t01, p, p, p, t10, t11, t12, p, p] # and this retrieves t01 and t12, which are the valid last tokens - idx = torch.arange( - len(logits_metadata.extend_seq_lens), - device=logits_metadata.extend_seq_lens.device, - ) last_index = ( - idx * logits_metadata.padded_static_len - + logits_metadata.extend_seq_lens - - 1 + logits_metadata.last_index_offsets + logits_metadata.extend_seq_lens ) pruned_states = hidden_states[last_index] if aux_hidden_states is not None: diff --git a/python/tokenspeed/runtime/sampling/backends/flashinfer_full.py b/python/tokenspeed/runtime/sampling/backends/flashinfer_full.py index fbabd6734..1d7eb8fc9 100644 --- a/python/tokenspeed/runtime/sampling/backends/flashinfer_full.py +++ b/python/tokenspeed/runtime/sampling/backends/flashinfer_full.py @@ -197,7 +197,7 @@ def _apply_penalties_and_bias( align with flat logits. """ - pool_idx = sampling_info.req_pool_indices.long() + pool_idx = sampling_info.req_pool_indices if num_tokens_per_req > 1: @@ -246,7 +246,7 @@ def _accumulate_counts( weights is int32; 0 masks invalid rows, 1 accumulates.""" self._counts.index_put_( - (pool_idx.long(), tokens.long()), + (pool_idx, tokens.long()), weights.to(torch.int32), accumulate=True, ) @@ -320,7 +320,7 @@ def sample( if raw_logprobs is not None: logits_output.next_token_logprobs = raw_logprobs.gather( - -1, sampled.long().unsqueeze(-1) + -1, sampled.unsqueeze(-1) ).squeeze(-1) # Accumulate sampled tokens into counts (greedy path accumulates too @@ -440,8 +440,7 @@ def verify( accepted_tokens = predict.long().gather(0, safe_positions.view(-1)) pool_idx_expanded = ( - sampling_info.req_pool_indices.long() - .unsqueeze(-1) + sampling_info.req_pool_indices.unsqueeze(-1) .expand(-1, num_tokens_per_req) .reshape(-1) ) @@ -455,7 +454,7 @@ def verify( if raw_logprobs is not None: logits_output.next_token_logprobs = raw_logprobs.gather( - -1, predict.long().unsqueeze(-1) + -1, predict.unsqueeze(-1) ).squeeze(-1) return predict, accept_length diff --git a/python/tokenspeed/runtime/sampling/sampling_batch_info.py b/python/tokenspeed/runtime/sampling/sampling_batch_info.py index c702a14ee..5e8cac0e7 100755 --- a/python/tokenspeed/runtime/sampling/sampling_batch_info.py +++ b/python/tokenspeed/runtime/sampling/sampling_batch_info.py @@ -62,7 +62,7 @@ class SamplingBatchInfo: # An event used for overlap schedule sampling_info_done: threading.Event | None = None - # int32[bs] — req_pool_idx per batch row. Sampling backends gather + # int64[bs] — req_pool_idx per batch row. Sampling backends gather # their pool-indexed scalar buffers (temperature / top_k / top_p / # seeds / penalties / logit_bias / counts) against this index. req_pool_indices: torch.Tensor | None = None diff --git a/python/tokenspeed/runtime/sampling/utils.py b/python/tokenspeed/runtime/sampling/utils.py index 922343ead..44e7e6103 100644 --- a/python/tokenspeed/runtime/sampling/utils.py +++ b/python/tokenspeed/runtime/sampling/utils.py @@ -76,7 +76,7 @@ def write_output_logprobs( """Fill logits_output.next_token_logprobs; callers gate on the enable flag.""" raw_logprobs = torch.log_softmax(logits, dim=-1) logits_output.next_token_logprobs = raw_logprobs.gather( - -1, tokens.long().unsqueeze(-1) + -1, tokens.unsqueeze(-1) ).squeeze(-1)