diff --git a/docs/serving/deepseek-v4.md b/docs/serving/deepseek-v4.md index 0eeada139..68be87341 100644 --- a/docs/serving/deepseek-v4.md +++ b/docs/serving/deepseek-v4.md @@ -17,6 +17,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 tokenspeed serve deepseek-ai/DeepSeek-V4-Flash \ --max-model-len 4096 \ --max-total-tokens 16384 \ --chunked-prefill-size 8192 \ + --enable-mixed-batch \ --gpu-memory-utilization 0.9 \ --disable-kvstore ``` @@ -29,6 +30,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 tokenspeed serve deepseek-ai/DeepSeek-V4-Flash \ | `--kv-cache-dtype fp8_e4m3` | V4 SWA cache rows are uint8-packed FP8 NoPE + BF16 RoPE + UE8M0 scale; FP8 e4m3 is the only supported KV dtype. | | `--moe-backend mega_moe` | Activates the DeepGEMM `fp8_fp4_mega_moe` fused experts. Requires `tokenspeed-deepgemm>=2.5.0.post20260424`. | | `--attention-use-fp4-indexer-cache` | Stores indexer keys as MXFP4 (`[values \| ue8m0 scales]`); the FP8 fallback path is reference-only. | +| `--enable-mixed-batch` | Enables mixed prefill/decode scheduling for V4 sparse attention. It is off by default globally because other backend paths do not all support mixed batches yet. | | `--trust-remote-code` | The HF config uses model-class architectures registered via remote code. | ## Block size diff --git a/python/tokenspeed/runtime/engine/event_loop.py b/python/tokenspeed/runtime/engine/event_loop.py index cccfe1e84..a0547e7b8 100644 --- a/python/tokenspeed/runtime/engine/event_loop.py +++ b/python/tokenspeed/runtime/engine/event_loop.py @@ -272,6 +272,9 @@ def __init__( f"(ratio={server_args.mamba_full_memory_ratio})." ) + enable_mixed_prefill_decode = ( + server_args.enable_mixed_batch and server_args.speculative_algorithm is None + ) scheduler_cfg = make_config( num_device_pages=self.max_total_num_tokens // server_args.block_size, max_scheduled_tokens=server_args.chunked_prefill_size, @@ -293,6 +296,7 @@ def __init__( mamba_cache_chunk_size=server_args.mamba_cache_chunk_size, mamba_pool_total_chunks=mamba_pool_total_chunks, paged_cache_groups=pool_to_paged_cache_groups(token_to_kv_pool), + enable_mixed_prefill_decode=enable_mixed_prefill_decode, ) logger.info( "Scheduler config: page_size=%s num_device_pages=%s " @@ -785,8 +789,10 @@ def _commit_forward_results( on_first_token=None, ): self.request_handler.forward_ct += 1 - forward_mode = ( - ForwardMode.EXTEND if forward_op.num_extends() > 0 else ForwardMode.DECODE + forward_mode = ForwardMode.from_num_extends( + forward_op.num_extends(), + len(forward_op.request_ids), + has_drafter=self.server_args.speculative_algorithm is not None, ) self.request_handler._profile_batch_predicate(forward_mode) @@ -859,12 +865,12 @@ def _dp_sync_and_check(self, forward_op) -> DpForwardMetadata: batch_size = len(forward_op.request_ids) if forward_op is not None else 0 if forward_op is None: forward_mode = ForwardMode.IDLE - elif forward_op.num_extends() > 0: - forward_mode = ForwardMode.EXTEND - elif self.server_args.speculative_algorithm is not None: - forward_mode = ForwardMode.TARGET_VERIFY else: - forward_mode = ForwardMode.DECODE + forward_mode = ForwardMode.from_num_extends( + forward_op.num_extends(), + batch_size, + has_drafter=self.server_args.speculative_algorithm is not None, + ) self._dp_local_info[0, 0] = num_tokens self._dp_local_info[0, 1] = batch_size diff --git a/python/tokenspeed/runtime/engine/generation_output_processor.py b/python/tokenspeed/runtime/engine/generation_output_processor.py index 1234db283..2bcb1891c 100644 --- a/python/tokenspeed/runtime/engine/generation_output_processor.py +++ b/python/tokenspeed/runtime/engine/generation_output_processor.py @@ -483,7 +483,8 @@ def post_process_forward_op( forward_op.input_lengths, forward_op.extend_prefix_lens, ) - is_decode_op = forward_op.num_extends() <= 0 + num_extends = forward_op.num_extends() + is_decode_op = num_extends <= 0 request_changes = [] stream_out_rids = [] @@ -504,6 +505,7 @@ def post_process_forward_op( if output_logprobs_list is not None else None ) + is_decode_slot = i >= num_extends if self.spec_num_tokens is not None and is_decode_op: pt += self.spec_num_tokens else: @@ -524,7 +526,7 @@ def post_process_forward_op( if on_first_token is not None and model_output_ids: on_first_token(forward_op.request_pool_indices[i], model_output_ids[0]) - if is_decode_op and self.spec_algorithm is not None: + if is_decode_slot and self.spec_algorithm is not None: request_state.spec_verify_ct += 1 # With the capturable grammar pipeline the matcher is @@ -597,7 +599,7 @@ def post_process_forward_op( else: stream_out_rids.append(rid) stream_out_states.append(request_state) - if is_decode_op: + if is_decode_slot: request_changes.append( make_update_reserve_tokens_event(rid, output_length) ) diff --git a/python/tokenspeed/runtime/engine/output_processor.py b/python/tokenspeed/runtime/engine/output_processor.py index bd1f96f2f..e185ab28a 100644 --- a/python/tokenspeed/runtime/engine/output_processor.py +++ b/python/tokenspeed/runtime/engine/output_processor.py @@ -268,6 +268,15 @@ def handle_batch_output( if recv_obj.output_multi_ids is not None: output_multi_ids = recv_obj.output_multi_ids[i] + if len(recv_obj.batch_accept_draft_tokens) > 0: + meta_info.update( + { + "accept_draft_tokens": recv_obj.batch_accept_draft_tokens[ + i + ] + } + ) + out_dict = { "text": state.text, "output_ids": output_token_ids, diff --git a/python/tokenspeed/runtime/engine/scheduler_utils.py b/python/tokenspeed/runtime/engine/scheduler_utils.py index 820c0a17a..653a1f191 100644 --- a/python/tokenspeed/runtime/engine/scheduler_utils.py +++ b/python/tokenspeed/runtime/engine/scheduler_utils.py @@ -66,6 +66,7 @@ def make_config( mamba_cache_chunk_size: int = 64, mamba_pool_total_chunks: int = 0, paged_cache_groups: Sequence["PagedCacheGroupConfig"] | None = None, + enable_mixed_prefill_decode: bool = False, ) -> SchedulerConfig: cfg = SchedulerConfig() cfg.num_device_pages = num_device_pages @@ -92,6 +93,7 @@ def make_config( cfg.enable_mamba = enable_mamba cfg.mamba_cache_chunk_size = mamba_cache_chunk_size cfg.mamba_pool_total_chunks = mamba_pool_total_chunks + cfg.enable_mixed_prefill_decode = enable_mixed_prefill_decode if paged_cache_groups: cfg.paged_cache_groups = list(paged_cache_groups) return cfg diff --git a/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py b/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py index e3dcdc7b9..4402c79eb 100644 --- a/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py +++ b/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py @@ -546,6 +546,7 @@ def _pad_offsets_to_padded_bs( def _init_replay_metadata( self, padded_bs: int, + actual_bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, req_to_page: torch.Tensor, @@ -562,7 +563,7 @@ def _init_replay_metadata( "uses_paged_cache_groups", False, ): - actual_bs = next( + table_bs = next( ( int(table.shape[0]) for table in paged_cache_block_tables.values() @@ -572,7 +573,7 @@ def _init_replay_metadata( ) paged_cache_block_tables = self._pad_block_tables_to_padded_bs( paged_cache_block_tables, - actual_bs=actual_bs, + actual_bs=table_bs, padded_bs=padded_bs, ) kwargs["paged_cache_block_tables"] = paged_cache_block_tables @@ -585,6 +586,8 @@ def _init_replay_metadata( kwargs["paged_cache_block_table_base_offsets"] = ( paged_cache_block_table_base_offsets ) + if getattr(self.attn_backend, "uses_padded_decode_token_mask", False): + kwargs["actual_bs"] = actual_bs self.attn_backend.init_forward_metadata_replay_cuda_graph( padded_bs, req_pool_indices, @@ -785,6 +788,7 @@ def __call__( ) self._init_replay_metadata( padded_bs, + bs, req_pool_indices, seq_lens, req_to_page=req_to_page, @@ -831,6 +835,7 @@ def __call__( extend_prefix_lens_cpu=extend_prefix_lens_cpu, extend_seq_lens=extend_seq_lens, extend_seq_lens_cpu=extend_seq_lens_cpu, + num_extends=ctx.num_extends, positions=positions, out_cache_loc=out_cache_loc, global_num_tokens=ctx.global_num_tokens, diff --git a/python/tokenspeed/runtime/execution/forward_batch_info.py b/python/tokenspeed/runtime/execution/forward_batch_info.py index 07bc6ab1c..609003b0c 100755 --- a/python/tokenspeed/runtime/execution/forward_batch_info.py +++ b/python/tokenspeed/runtime/execution/forward_batch_info.py @@ -55,6 +55,9 @@ def is_extend(self): def is_decode(self): return self == ForwardMode.DECODE + def is_mixed(self): + return self == ForwardMode.MIXED + def is_idle(self): return self == ForwardMode.IDLE @@ -67,6 +70,19 @@ def is_draft_extend(self): def is_decode_or_idle(self): return self == ForwardMode.DECODE or self == ForwardMode.IDLE + @staticmethod + def from_num_extends( + num_extends: int, + batch_size: int, + *, + has_drafter: bool = False, + ) -> "ForwardMode": + if batch_size <= 0: + return ForwardMode.IDLE + if num_extends > 0: + return ForwardMode.MIXED if num_extends < batch_size else ForwardMode.EXTEND + return ForwardMode.TARGET_VERIFY if has_drafter else ForwardMode.DECODE + class CaptureHiddenMode(IntEnum): NULL = auto() diff --git a/python/tokenspeed/runtime/execution/input_buffer.py b/python/tokenspeed/runtime/execution/input_buffer.py index 37e04d952..88feca724 100644 --- a/python/tokenspeed/runtime/execution/input_buffer.py +++ b/python/tokenspeed/runtime/execution/input_buffer.py @@ -175,15 +175,19 @@ def fill_input_buffers( page_size=self.page_size, ) - valid_cache_lengths = runtime_states.valid_cache_lengths[ + cached_prefix_lens = runtime_states.valid_cache_lengths[ self.req_pool_indices_buf[:batch_size] ] - # Compute positions - prefix_lens = ( - self.extend_prefix_lens_buf[:num_extends] - if num_extends > 0 - else valid_cache_lengths - ) + # Compute positions. In mixed batches, prefill rows use their extend + # prefix lengths while decode rows use the current valid cache lengths. + prefill_prefix_lens = self.extend_prefix_lens_buf[:num_extends] + if num_extends == 0: + prefix_lens = cached_prefix_lens + elif num_extends == batch_size: + prefix_lens = prefill_prefix_lens + else: + prefix_lens = cached_prefix_lens.clone() + prefix_lens[:num_extends].copy_(prefill_prefix_lens) positions, _ = compute_position_triton( extend_prefix_lens=prefix_lens, extend_seq_lens=input_lengths_device, @@ -193,20 +197,55 @@ def fill_input_buffers( # Determine input_ids and forward_mode if num_extends > 0: + prefill_token_count = sum(forward_op.input_lengths[:num_extends]) input_ids_cpu = torch.tensor( forward_op.input_ids, device="cpu", pin_memory=True ) - self.input_ids_buf[:total_tokens].copy_( + self.input_ids_buf[:prefill_token_count].copy_( input_ids_cpu, non_blocking=True, ) shifted_ids_cpu = torch.tensor( forward_op.shifted_input_ids, device="cpu", pin_memory=True ) - self.shifted_prefill_ids_buf[:total_tokens].copy_( + self.shifted_prefill_ids_buf[:prefill_token_count].copy_( shifted_ids_cpu, non_blocking=True, ) + if num_extends < batch_size: + decode_req_pool_indices = req_pool_indices_device[ + num_extends:batch_size + ] + if forward_op.decode_input_ids is not None: + decode_count = batch_size - num_extends + if len(forward_op.decode_input_ids) != decode_count: + raise RuntimeError( + "mixed forward decode_input_ids length mismatch: " + f"got {len(forward_op.decode_input_ids)}, " + f"expected {decode_count}" + ) + decode_input_ids_tensor = torch.tensor( + forward_op.decode_input_ids, + dtype=torch.int32, + device="cpu", + pin_memory=True, + ).to(req_pool_indices_device.device, non_blocking=True) + mask = (decode_input_ids_tensor != -1).unsqueeze(1) + slot = runtime_states.future_input_map[decode_req_pool_indices, :1] + runtime_states.future_input_map[decode_req_pool_indices, :1] = ( + torch.where(mask, decode_input_ids_tensor.unsqueeze(1), slot) + ) + decode_ids = runtime_states.future_input_map[ + decode_req_pool_indices, :1 + ].flatten() + self.input_ids_buf[prefill_token_count:total_tokens].copy_( + decode_ids, + non_blocking=True, + ) + self.shifted_prefill_ids_buf[prefill_token_count:total_tokens].copy_( + decode_ids, + non_blocking=True, + ) else: # If the scheduler provides explicit decode input ids (!= -1), write # them into future_input_map before reading, so that they take effect @@ -230,7 +269,7 @@ def fill_input_buffers( non_blocking=True, ) - self.seq_lens_buf[:batch_size].copy_(input_lengths_device + valid_cache_lengths) + self.seq_lens_buf[:batch_size].copy_(input_lengths_device + cached_prefix_lens) # Reset positions beyond total_tokens to the dummy KV slot so that any # CUDA graph replay with a larger (padded) batch size writes padding diff --git a/python/tokenspeed/runtime/execution/model_executor.py b/python/tokenspeed/runtime/execution/model_executor.py index b224124fc..0eda2f0f6 100644 --- a/python/tokenspeed/runtime/execution/model_executor.py +++ b/python/tokenspeed/runtime/execution/model_executor.py @@ -829,14 +829,12 @@ def execute_forward_op( total_tokens=total_tokens, ) - if num_extends > 0: - forward_mode = ForwardMode.EXTEND - elif self.drafter is not None: - forward_mode = ForwardMode.TARGET_VERIFY - else: - forward_mode = ForwardMode.DECODE - bs = len(forward_op.request_ids) + forward_mode = ForwardMode.from_num_extends( + num_extends, + bs, + has_drafter=self.drafter is not None, + ) if self.runtime_states.mamba_pool is not None and ( num_extends > 0 or has_retract diff --git a/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py b/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py index 885ed39e5..f53edb274 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py +++ b/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py @@ -14,6 +14,9 @@ from __future__ import annotations import torch +from tokenspeed_kernel.ops.attention.triton.deepseek_v4 import ( + deepseek_v4_indexer_decode_metadata_compute, +) from tokenspeed.runtime.configs.model_config import AttentionArch from tokenspeed.runtime.execution.forward_batch_info import ForwardMode @@ -56,10 +59,167 @@ def _cu_seqlens(lengths: torch.Tensor) -> torch.Tensor: ) +def _decode_positions_from_metadata( + metadata: DeepseekV4ForwardMetadata, + num_tokens: int, +) -> torch.Tensor: + token_to_req = metadata.token_to_req_indices[:num_tokens].to(torch.int64) + query_starts = metadata.query_start_loc[token_to_req].to(torch.int64) + query_lens = metadata.query_lens[token_to_req].to(torch.int64) + seq_lens = metadata.seq_lens[token_to_req].to(torch.int64) + token_offsets = torch.arange( + num_tokens, + dtype=torch.int64, + device=metadata.seq_lens.device, + ) + return seq_lens - query_lens + token_offsets - query_starts + + +def _refresh_decode_indexer_plan_cache( + metadata: DeepseekV4ForwardMetadata, + *, + max_context_len: int, +) -> None: + """Pre-build decode-indexer plan tensors before per-layer parallel work. + + This keeps per-layer indexer calls read-only with respect to cached plan + buffers while compressor work may run on an auxiliary stream. + """ + cache = metadata.decode_indexer_plan_cache + if not cache: + return + refreshed_keys = metadata.decode_indexer_plan_refreshed_keys + refreshed_keys.clear() + for ( + compress_ratio, + cache_block_size, + num_tokens, + ), plan in list(cache.items()): + if num_tokens <= 0: + plan.context_lens.zero_() + plan.block_table.zero_() + plan.max_context_len = 0 + refreshed_keys.add((compress_ratio, cache_block_size, num_tokens)) + continue + positions = _decode_positions_from_metadata(metadata, num_tokens) + token_to_req_indices = metadata.token_to_req_indices[:num_tokens] + block_table = metadata.compressed_block_table( + compress_ratio, + cache_block_size, + ) + rows = int(block_table.shape[0]) if block_table.ndim >= 1 else 0 + cols = int(block_table.shape[1]) if block_table.ndim >= 2 else 0 + if rows <= 0 or cols <= 0: + plan.context_lens.zero_() + plan.block_table.zero_() + plan.max_context_len = 0 + refreshed_keys.add((compress_ratio, cache_block_size, num_tokens)) + continue + max_blocks = int(plan.block_table.shape[1]) + if max_context_len > 0: + derived_max_len = max( + 1, + (max_context_len + compress_ratio - 1) // compress_ratio, + ) + else: + derived_max_len = max( + 1, + (block_table.shape[1] * cache_block_size + compress_ratio - 1) + // compress_ratio, + ) + if plan.max_context_len != derived_max_len: + plan.max_context_len = derived_max_len + deepseek_v4_indexer_decode_metadata_compute( + positions=positions, + token_to_req_indices=token_to_req_indices, + block_table=block_table, + cache_block_size=cache_block_size, + compress_ratio=compress_ratio, + max_blocks=max_blocks, + out_context_lens=plan.context_lens, + out_block_tables=plan.block_table, + ) + if metadata.is_valid_token is not None: + valid = metadata.is_valid_token[:num_tokens].to( + device=plan.context_lens.device, + dtype=torch.bool, + ) + with torch.inference_mode(): + plan.context_lens.masked_fill_(~valid.view(num_tokens, 1), 0) + plan.block_table.masked_fill_( + ~valid.to(device=plan.block_table.device).view(num_tokens, 1), + 0, + ) + refreshed_keys.add((compress_ratio, cache_block_size, num_tokens)) + + +def _refresh_decode_indexer_schedule_metadata( + metadata: DeepseekV4ForwardMetadata, +) -> None: + if not metadata.decode_indexer_schedule_metadata: + return + try: + from tokenspeed_kernel.thirdparty import deep_gemm + except Exception: + return + get_metadata = getattr(deep_gemm, "get_paged_mqa_logits_metadata", None) + if get_metadata is None: + return + for ( + compress_ratio, + cache_block_size, + num_tokens, + ), schedule_metadata in list(metadata.decode_indexer_schedule_metadata.items()): + if num_tokens <= 0: + continue + key = (compress_ratio, cache_block_size, num_tokens) + decode_plan = metadata.decode_indexer_plan_cache.get(key) + context_lens = getattr(decode_plan, "context_lens", None) + if ( + context_lens is not None + and context_lens.shape == (num_tokens, 1) + and context_lens.dtype == torch.int32 + ): + context_lens = context_lens.contiguous() + else: + positions = _decode_positions_from_metadata(metadata, num_tokens) + compressed_lens = torch.div( + positions.to(torch.int32) + 1, + compress_ratio, + rounding_mode="floor", + ).clamp_min(0) + if metadata.is_valid_token is not None: + valid = metadata.is_valid_token[:num_tokens].to( + device=compressed_lens.device, + dtype=torch.bool, + ) + compressed_lens = torch.where( + valid, + compressed_lens, + torch.zeros_like(compressed_lens), + ) + context_lens = compressed_lens.view(num_tokens, 1).contiguous() + refreshed = get_metadata( + context_lens, + cache_block_size, + deep_gemm.get_num_sms(), + ) + if ( + schedule_metadata.shape == refreshed.shape + and schedule_metadata.device == refreshed.device + and schedule_metadata.dtype == refreshed.dtype + ): + with torch.inference_mode(): + schedule_metadata.copy_(refreshed) + else: + metadata.decode_indexer_schedule_metadata[key] = refreshed + + class DeepseekV4AttentionBackend(AttentionBackend): """Metadata owner for the model-local DeepSeek V4 attention path.""" uses_paged_cache_groups = True + uses_padded_decode_token_mask = True def __init__(self, config) -> None: super().__init__(config) @@ -127,12 +287,37 @@ def _query_lens( bs: int, seq_lens: torch.Tensor, forward_mode: ForwardMode | None, + num_extends: int, extend_seq_lens_cpu: torch.Tensor | None, extend_prefix_lens_cpu: torch.Tensor | None, extend_prefix_lens: torch.Tensor | None, ) -> torch.Tensor: if forward_mode is not None and forward_mode.is_decode_or_idle(): return torch.ones(bs, dtype=torch.int32, device=seq_lens.device) + if forward_mode is not None and forward_mode.is_mixed(): + lens = torch.ones(bs, dtype=torch.int32, device=seq_lens.device) + num_prefill_reqs = max(0, min(int(num_extends), bs)) + if num_prefill_reqs == 0: + return lens + if extend_seq_lens_cpu is not None and extend_seq_lens_cpu.numel() > 0: + lens[:num_prefill_reqs] = extend_seq_lens_cpu[:num_prefill_reqs].to( + seq_lens.device, dtype=torch.int32 + ) + elif extend_prefix_lens_cpu is not None: + prefix = extend_prefix_lens_cpu[:num_prefill_reqs].to( + seq_lens.device, dtype=torch.int32 + ) + lens[:num_prefill_reqs] = ( + seq_lens[:num_prefill_reqs].to(torch.int32) - prefix + ).clamp_min(0) + elif extend_prefix_lens is not None: + prefix = extend_prefix_lens[:num_prefill_reqs].to(torch.int32) + lens[:num_prefill_reqs] = ( + seq_lens[:num_prefill_reqs].to(torch.int32) - prefix + ).clamp_min(0) + else: + lens[:num_prefill_reqs] = seq_lens[:num_prefill_reqs].to(torch.int32) + return lens if extend_seq_lens_cpu is not None: return extend_seq_lens_cpu[:bs].to(seq_lens.device, dtype=torch.int32) if extend_prefix_lens_cpu is not None: @@ -143,6 +328,33 @@ def _query_lens( return (seq_lens[:bs].to(torch.int32) - prefix).clamp_min(0) return seq_lens[:bs].to(torch.int32) + def _query_lens_cpu( + self, + bs: int, + forward_mode: Optional[ForwardMode], + num_extends: int, + extend_seq_lens_cpu: Optional[torch.Tensor], + extend_prefix_lens_cpu: Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: + if forward_mode is not None and forward_mode.is_decode_or_idle(): + return torch.ones(bs, dtype=torch.int32) + if forward_mode is not None and forward_mode.is_mixed(): + lens = torch.ones(bs, dtype=torch.int32) + num_prefill_reqs = max(0, min(int(num_extends), bs)) + if num_prefill_reqs == 0: + return lens + if extend_seq_lens_cpu is None: + return None + lens[:num_prefill_reqs] = extend_seq_lens_cpu[:num_prefill_reqs].to( + dtype=torch.int32, device="cpu" + ) + return lens + if extend_seq_lens_cpu is not None: + return extend_seq_lens_cpu[:bs].to(dtype=torch.int32, device="cpu") + if extend_prefix_lens_cpu is not None: + return None + return None + def init_forward_metadata( self, bs: int, @@ -160,6 +372,8 @@ def init_forward_metadata( paged_cache_block_table_base_offsets = ( kwargs.pop("paged_cache_block_table_base_offsets", None) or {} ) + num_extends_arg = kwargs.pop("num_extends", None) + num_extends = bs if num_extends_arg is None else int(num_extends_arg) del num_tokens, kwargs device = seq_lens.device req_pool_indices = req_pool_indices[:bs] @@ -168,10 +382,51 @@ def init_forward_metadata( bs, seq_lens, forward_mode, + num_extends, extend_seq_lens_cpu, extend_prefix_lens_cpu, extend_prefix_lens, ) + if forward_mode is not None and forward_mode.is_mixed(): + num_prefill_reqs = max(0, min(num_extends, bs)) + elif forward_mode is not None and forward_mode.is_extend(): + num_prefill_reqs = bs + else: + num_prefill_reqs = 0 + query_lens_cpu = self._query_lens_cpu( + bs, + forward_mode, + num_extends, + extend_seq_lens_cpu, + extend_prefix_lens_cpu, + ) + seq_lens_cpu = None + if extend_prefix_lens_cpu is not None and query_lens_cpu is not None: + seq_lens_cpu = seq_lens[:bs].to(dtype=torch.int32, device="cpu") + prefix_count = min( + int(extend_prefix_lens_cpu.numel()), + ( + num_prefill_reqs + if forward_mode is not None and forward_mode.is_mixed() + else bs + ), + ) + if prefix_count: + seq_lens_cpu[:prefix_count] = ( + extend_prefix_lens_cpu[:prefix_count].to( + dtype=torch.int32, + device="cpu", + ) + + query_lens_cpu[:prefix_count] + ) + elif extend_seq_lens_cpu is not None and forward_mode is not None: + if forward_mode.is_extend(): + seq_lens_cpu = extend_seq_lens_cpu[:bs].to( + dtype=torch.int32, + device="cpu", + ) + elif forward_mode.is_mixed(): + seq_lens_cpu = seq_lens[:bs].to(dtype=torch.int32, device="cpu") max_seq_len = int(seq_lens.max().item()) if bs else 0 max_pages = (max_seq_len + self.page_size - 1) // self.page_size if req_to_page is None: @@ -210,6 +465,9 @@ def init_forward_metadata( ) req_ids = torch.arange(bs, device=device, dtype=torch.int32) token_to_req = torch.repeat_interleave(req_ids, query_lens.clamp_min(0)) + num_prefill_tokens = ( + int(query_lens[:num_prefill_reqs].sum().item()) if num_prefill_reqs else 0 + ) self.forward_metadata = DeepseekV4ForwardMetadata( page_size=self.page_size, req_pool_indices=req_pool_indices, @@ -218,6 +476,10 @@ def init_forward_metadata( query_lens=query_lens, query_start_loc=_cu_seqlens(query_lens), token_to_req_indices=token_to_req, + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + num_prefill_reqs=num_prefill_reqs, + num_prefill_tokens=num_prefill_tokens, forward_mode=forward_mode, paged_cache_block_tables=paged_cache_block_tables, paged_cache_block_table_base_offsets=base_offsets_on_device, @@ -271,6 +533,7 @@ def _update_decode_swa_metadata( block_table_base_offsets=metadata.swa_base_logical_page, window_size=window_size, block_size=block_size, + is_valid_token=metadata.is_valid_token, out_indices=metadata.decode_swa_indices, out_lens=metadata.decode_swa_lens, ) @@ -304,7 +567,7 @@ def _get_decode_swa_metadata( block_size=block_size, ) - def _decode_compressed_indices_and_lens( + def _decode_compressed_attention_indices_and_lens( self, positions: torch.Tensor, *, @@ -320,6 +583,12 @@ def _decode_compressed_indices_and_lens( num_tokens = positions.numel() req_idx = metadata.token_to_req_indices[:num_tokens].to(torch.int64) block_table = metadata.compressed_block_table(compress_ratio, block_size) + is_valid_token = ( + metadata.is_valid_token[:num_tokens] + if metadata.is_valid_token is not None + else None + ) + capturing = positions.is_cuda and torch.cuda.is_current_stream_capturing() if compress_ratio == 4: if topk_indices is None: raise RuntimeError("DeepSeek V4 CSA decode requires top-k indices") @@ -328,19 +597,41 @@ def _decode_compressed_indices_and_lens( token_to_req_indices=metadata.token_to_req_indices[:num_tokens], block_table=block_table, block_size=block_size, + is_valid_token=is_valid_token, ) return indices_2d.unsqueeze(1), lens - else: - width = self._dense_compressed_indices_width(compress_ratio) - compressed_lens = torch.div( - positions.to(torch.int64) + 1, - compress_ratio, - rounding_mode="floor", - ).clamp(0, width) - offsets = torch.arange(width, dtype=torch.int64, device=positions.device) - local = offsets[None, :].expand(num_tokens, -1) - valid = offsets[None, :] < compressed_lens[:, None] - lens = compressed_lens.to(torch.int32) + + cache_key = ( + int(compress_ratio), + int(block_size), + int(num_tokens), + int(positions.data_ptr()) if positions.numel() else 0, + ) + dense_indices_cache = metadata.decode_dense_compressed_indices_cache + capture_safe_keys = metadata.decode_dense_compressed_indices_capture_safe_keys + cached = dense_indices_cache.get(cache_key) + capture_cached = cache_key in capture_safe_keys + if cached is not None and (not capturing or capture_cached): + return cached + + width = self._dense_compressed_indices_width(compress_ratio) + compressed_lens = torch.div( + positions.to(torch.int64) + 1, + compress_ratio, + rounding_mode="floor", + ).clamp(0, width) + offsets = torch.arange(width, dtype=torch.int64, device=positions.device) + local = offsets[None, :].expand(num_tokens, -1) + valid = offsets[None, :] < compressed_lens[:, None] + if is_valid_token is not None: + valid = valid & is_valid_token.to(torch.bool)[:, None] + lens = compressed_lens.to(torch.int32) + if is_valid_token is not None: + lens = torch.where( + is_valid_token.to(torch.bool), + lens, + torch.zeros_like(lens), + ) safe_local = torch.where(valid, local, torch.zeros_like(local)) pages = torch.div(safe_local, block_size, rounding_mode="floor") @@ -353,6 +644,9 @@ def _decode_compressed_indices_and_lens( torch.full_like(slots, -1), ) indices = indices_2d.to(torch.int32).unsqueeze(1) + dense_indices_cache[cache_key] = (indices, lens) + if capturing: + capture_safe_keys.add(cache_key) return indices, lens def _dense_compressed_indices_width(self, compress_ratio: int) -> int: @@ -477,7 +771,7 @@ def forward_deepseek_v4_decode( block_size=token_to_kv_pool.swa_block_size, ) compressed_block_size = token_to_kv_pool.get_compressed_block_size(layer_id) - extra_indices, extra_lens = self._decode_compressed_indices_and_lens( + extra_indices, extra_lens = self._decode_compressed_attention_indices_and_lens( positions, compress_ratio=compress_ratio, block_size=compressed_block_size, @@ -518,6 +812,103 @@ def forward_deepseek_v4_decode( out = out.squeeze(1) return out[:, :num_local_heads] + def forward_deepseek_v4_mixed( + self, + *, + q: torch.Tensor, + positions: torch.Tensor, + token_to_kv_pool, + layer_id: int, + kind: str, + compress_ratio: int, + num_local_heads: int, + padded_heads: int, + head_dim: int, + window_size: int, + softmax_scale: float, + attn_sink: torch.Tensor, + topk_indices: Optional[torch.Tensor], + ) -> torch.Tensor: + metadata = self.forward_metadata + if metadata is None: + raise RuntimeError("DeepSeek V4 mixed attention requires forward metadata") + if metadata.forward_mode is None or not metadata.forward_mode.is_mixed(): + raise RuntimeError( + "forward_deepseek_v4_mixed only supports ForwardMode.MIXED" + ) + + num_prefill_reqs = metadata.num_prefill_reqs + num_prefill_tokens = metadata.num_prefill_tokens + num_decode_reqs = metadata.decode_req_count() + num_decode_tokens = metadata.decode_token_count() + out = q.new_empty((q.shape[0], num_local_heads, head_dim)) + saved_metadata = self.forward_metadata + try: + if num_prefill_tokens > 0: + self.forward_metadata = self._metadata_slice( + metadata, + req_start=0, + req_end=num_prefill_reqs, + token_start=0, + token_end=num_prefill_tokens, + forward_mode=ForwardMode.EXTEND, + ) + prefill_out = self.forward_deepseek_v4_prefill( + q=q[:num_prefill_tokens], + positions=positions[:num_prefill_tokens], + token_to_kv_pool=token_to_kv_pool, + layer_id=layer_id, + kind=kind, + compress_ratio=compress_ratio, + num_local_heads=num_local_heads, + padded_heads=padded_heads, + head_dim=head_dim, + window_size=window_size, + softmax_scale=softmax_scale, + attn_sink=attn_sink, + topk_indices=( + topk_indices[:num_prefill_tokens] + if topk_indices is not None + else None + ), + ) + with deepseek_v4_profile_scope(f"attn_{kind}_mixed_prefill_copy"): + out[:num_prefill_tokens].copy_(prefill_out) + if num_decode_tokens > 0: + decode_end = num_prefill_tokens + num_decode_tokens + self.forward_metadata = self._metadata_slice( + metadata, + req_start=num_prefill_reqs, + req_end=num_prefill_reqs + num_decode_reqs, + token_start=num_prefill_tokens, + token_end=decode_end, + forward_mode=ForwardMode.DECODE, + ) + decode_out = self.forward_deepseek_v4_decode( + q=q[num_prefill_tokens:decode_end], + positions=positions[num_prefill_tokens:decode_end], + token_to_kv_pool=token_to_kv_pool, + layer_id=layer_id, + kind=kind, + compress_ratio=compress_ratio, + num_local_heads=num_local_heads, + padded_heads=padded_heads, + head_dim=head_dim, + window_size=window_size, + softmax_scale=softmax_scale, + attn_sink=attn_sink, + topk_indices=( + topk_indices[num_prefill_tokens:decode_end] + if topk_indices is not None + else None + ), + ) + with deepseek_v4_profile_scope(f"attn_{kind}_mixed_decode_copy"): + out[num_prefill_tokens:decode_end].copy_(decode_out) + finally: + self.forward_metadata = saved_metadata + return out + def _prefill_gather_lens( self, *, @@ -700,6 +1091,10 @@ def _metadata_slice( key: offsets[req_start:req_end] for key, offsets in metadata.compressor_state_base_logical_pages.items() } + req_count = max(0, req_end - req_start) + token_count = max(0, token_end - token_start) + num_prefill_reqs = req_count if forward_mode.is_extend() else 0 + num_prefill_tokens = token_count if forward_mode.is_extend() else 0 return DeepseekV4ForwardMetadata( page_size=metadata.page_size, req_pool_indices=metadata.req_pool_indices[req_start:req_end], @@ -708,6 +1103,23 @@ def _metadata_slice( query_lens=metadata.query_lens[req_start:req_end], query_start_loc=_cu_seqlens(metadata.query_lens[req_start:req_end]), token_to_req_indices=token_to_req, + is_valid_token=( + metadata.is_valid_token[token_start:token_end] + if metadata.is_valid_token is not None + else None + ), + seq_lens_cpu=( + metadata.seq_lens_cpu[req_start:req_end] + if metadata.seq_lens_cpu is not None + else None + ), + query_lens_cpu=( + metadata.query_lens_cpu[req_start:req_end] + if metadata.query_lens_cpu is not None + else None + ), + num_prefill_reqs=num_prefill_reqs, + num_prefill_tokens=num_prefill_tokens, forward_mode=forward_mode, paged_cache_block_tables=paged_cache_block_tables, paged_cache_block_table_base_offsets=paged_cache_block_table_base_offsets, @@ -955,6 +1367,11 @@ def init_cuda_graph_state( dtype=torch.int32, device=self.device, ) + self._cuda_graph_is_valid_token = torch.ones( + max_bs, + dtype=torch.bool, + device=self.device, + ) def _refresh_cuda_graph_paged_cache_block_tables( self, @@ -1077,6 +1494,9 @@ def init_forward_metadata_capture_cuda_graph( query_lens=self._cuda_graph_query_lens[:bs], query_start_loc=self._cuda_graph_query_start_loc[: bs + 1], token_to_req_indices=self._cuda_graph_token_to_req[:bs], + is_valid_token=self._cuda_graph_is_valid_token[:bs], + seq_lens_cpu=None, + query_lens_cpu=None, forward_mode=forward_mode, paged_cache_block_tables=metadata_paged, paged_cache_block_table_base_offsets=metadata_base_offsets, @@ -1103,6 +1523,7 @@ def init_forward_metadata_replay_cuda_graph( paged_cache_block_table_base_offsets = ( kwargs.pop("paged_cache_block_table_base_offsets", None) or {} ) + actual_bs = max(0, min(int(kwargs.pop("actual_bs", bs)), bs)) del kwargs if forward_mode is not None and not forward_mode.is_decode_or_idle(): raise NotImplementedError( @@ -1118,6 +1539,9 @@ def init_forward_metadata_replay_cuda_graph( self._cuda_graph_token_to_req[:bs].copy_( torch.arange(bs, dtype=torch.int32, device=self.device) ) + self._cuda_graph_is_valid_token[:actual_bs].fill_(True) + if actual_bs < bs: + self._cuda_graph_is_valid_token[actual_bs:bs].fill_(False) if req_to_page is not None: self._cuda_graph_block_table[:bs, : self.max_num_pages].copy_( req_to_page[req_pool_indices[:bs], : self.max_num_pages] @@ -1159,6 +1583,8 @@ def init_forward_metadata_replay_cuda_graph( metadata.compressor_state_base_logical_pages = compressor_state_base metadata.indexer_state_block_table = indexer_state_block_table metadata.indexer_state_base_logical_page = indexer_state_base + metadata.num_prefill_reqs = 0 + metadata.num_prefill_tokens = 0 if ( forward_mode is not None and forward_mode.is_decode() @@ -1171,6 +1597,11 @@ def init_forward_metadata_replay_cuda_graph( block_size=self._decode_swa_block_size, ) metadata.refresh_decode_compressed_slot_mappings() + _refresh_decode_indexer_plan_cache( + metadata, + max_context_len=self.context_len, + ) + _refresh_decode_indexer_schedule_metadata(metadata) self.forward_metadata = metadata def advance_draft_forward_metadata(self): diff --git a/python/tokenspeed/runtime/layers/attention/deepseek_v4_ops.py b/python/tokenspeed/runtime/layers/attention/deepseek_v4_ops.py index e14ce6d04..6b057bb9b 100644 --- a/python/tokenspeed/runtime/layers/attention/deepseek_v4_ops.py +++ b/python/tokenspeed/runtime/layers/attention/deepseek_v4_ops.py @@ -2203,11 +2203,18 @@ def _deepseek_v4_compute_global_topk_indices_and_lens_kernel( token_to_req_indices_ptr, block_table_ptr, block_table_stride, + is_valid_token_ptr, + has_valid_token: tl.constexpr, block_size: tl.constexpr, topk: tl.constexpr, TRITON_BLOCK_SIZE: tl.constexpr, ): token_idx = tl.program_id(0) + if has_valid_token: + is_valid_token = tl.load(is_valid_token_ptr + token_idx) + if not is_valid_token: + tl.store(topk_lens_ptr + token_idx, 0) + return req_idx = tl.load(token_to_req_indices_ptr + token_idx) count = tl.zeros((), dtype=tl.int32) @@ -2245,6 +2252,7 @@ def deepseek_v4_compute_global_topk_indices_and_lens( token_to_req_indices: torch.Tensor, block_table: torch.Tensor, block_size: int, + is_valid_token: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Map local CSA top-k indices to global KV slots in one Triton kernel.""" @@ -2257,13 +2265,31 @@ def deepseek_v4_compute_global_topk_indices_and_lens( topk_lens = torch.empty(num_tokens, dtype=torch.int32, device=topk_indices.device) if num_tokens == 0: return global_topk_indices, topk_lens + if is_valid_token is not None: + is_valid_token = is_valid_token[:num_tokens].to( + device=topk_indices.device, + dtype=torch.bool, + ) if not topk_indices.is_cuda: valid = topk_indices >= 0 + if is_valid_token is not None: + valid = valid & is_valid_token[:, None] req_idx = token_to_req_indices[:num_tokens].to(torch.int64) + rows = int(block_table.shape[0]) if block_table.dim() >= 1 else 0 + cols = int(block_table.shape[1]) if block_table.dim() >= 2 else 0 + if rows <= 0 or cols <= 0: + global_topk_indices.fill_(-1) + topk_lens.zero_() + return global_topk_indices, topk_lens safe_local = torch.where(valid, topk_indices, torch.zeros_like(topk_indices)) block_indices = torch.div(safe_local, block_size, rounding_mode="floor") block_offsets = safe_local % block_size - block_numbers = block_table[req_idx[:, None], block_indices.long()] + req_valid = (req_idx >= 0) & (req_idx < rows) + block_valid = (block_indices >= 0) & (block_indices < cols) + valid = valid & req_valid[:, None] & block_valid + safe_req = req_idx.clamp(0, rows - 1) + safe_block = block_indices.long().clamp(0, cols - 1) + block_numbers = block_table[safe_req[:, None], safe_block] global_topk_indices.copy_( torch.where( valid, @@ -2273,6 +2299,8 @@ def deepseek_v4_compute_global_topk_indices_and_lens( ) topk_lens.copy_(valid.sum(dim=1, dtype=torch.int32)) return global_topk_indices, topk_lens + if is_valid_token is None: + is_valid_token = torch.empty(0, dtype=torch.bool, device=topk_indices.device) _deepseek_v4_compute_global_topk_indices_and_lens_kernel[(num_tokens,)]( global_topk_indices, @@ -2283,6 +2311,8 @@ def deepseek_v4_compute_global_topk_indices_and_lens( token_to_req_indices.to(torch.int32), block_table.to(torch.int32), block_table.stride(0), + is_valid_token, + is_valid_token.numel() != 0, block_size=block_size, topk=topk_indices.shape[-1], TRITON_BLOCK_SIZE=1024, @@ -2591,15 +2621,22 @@ def _deepseek_v4_decode_swa_indices_and_lens_kernel( query_start_loc_ptr, seq_lens_ptr, token_to_req_indices_ptr, + is_valid_token_ptr, block_table_ptr, block_table_base_offsets_ptr, block_table_stride, max_blocks_per_seq: tl.constexpr, + has_valid_token: tl.constexpr, window_size: tl.constexpr, block_size: tl.constexpr, candidate_block: tl.constexpr, ): token_idx = tl.program_id(0) + if has_valid_token: + is_valid = tl.load(is_valid_token_ptr + token_idx) + if not is_valid: + tl.store(swa_lens_ptr + token_idx, 0) + return req_idx = tl.load(token_to_req_indices_ptr + token_idx).to(tl.int32) query_start = tl.load(query_start_loc_ptr + req_idx).to(tl.int32) @@ -2647,6 +2684,7 @@ def deepseek_v4_decode_swa_indices_and_lens( window_size: int, block_size: int, block_table_base_offsets: torch.Tensor | None = None, + is_valid_token: torch.Tensor | None = None, out_indices: torch.Tensor | None = None, out_lens: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -2663,6 +2701,13 @@ def deepseek_v4_decode_swa_indices_and_lens( out_lens = torch.empty(num_tokens, dtype=torch.int32, device=seq_lens.device) if num_tokens == 0: return out_indices, out_lens + if is_valid_token is None: + is_valid_token = torch.empty(0, dtype=torch.bool, device=seq_lens.device) + else: + is_valid_token = is_valid_token[:num_tokens].to( + device=seq_lens.device, + dtype=torch.bool, + ) candidate_block = min(1024, triton.next_power_of_2(window_size)) _deepseek_v4_decode_swa_indices_and_lens_kernel[(num_tokens,)]( @@ -2672,6 +2717,7 @@ def deepseek_v4_decode_swa_indices_and_lens( query_start_loc.to(torch.int32), seq_lens.to(torch.int32), token_to_req_indices.to(torch.int32), + is_valid_token, block_table.to(torch.int32), ( block_table_base_offsets.to(torch.int32) @@ -2680,6 +2726,7 @@ def deepseek_v4_decode_swa_indices_and_lens( ), block_table.stride(0), block_table.shape[-1], + is_valid_token.numel() != 0, window_size=window_size, block_size=block_size, candidate_block=candidate_block, diff --git a/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py b/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py index bcec744af..bef078ef1 100644 --- a/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py +++ b/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py @@ -234,6 +234,15 @@ class DeepseekV4ForwardMetadata: query_lens: torch.Tensor query_start_loc: torch.Tensor token_to_req_indices: torch.Tensor + # Padding mask for CUDA graph replay rows; this is not mixed-batch state. + is_valid_token: Optional[torch.Tensor] = None + # CPU lens are retained for sparse prefill/indexer planning without + # forcing another device-to-host sync in the model path. + seq_lens_cpu: Optional[torch.Tensor] = None + query_lens_cpu: Optional[torch.Tensor] = None + # Cached split boundary derived from scheduler num_extends/query_lens. + num_prefill_reqs: int = 0 + num_prefill_tokens: int = 0 forward_mode: object = None decode_swa_indices: torch.Tensor | None = None decode_swa_lens: torch.Tensor | None = None @@ -257,9 +266,35 @@ class DeepseekV4ForwardMetadata: decode_compressed_slot_mappings: dict[tuple[int, int], torch.Tensor] = field( default_factory=dict ) + # Cache for dense compressed decode attention indices/lens. CSA decode uses + # dynamic top-k indices and does not populate this cache. + decode_dense_compressed_indices_cache: dict[ + tuple[int, int, int, int], tuple[torch.Tensor, torch.Tensor] + ] = field(default_factory=dict) + decode_dense_compressed_indices_capture_safe_keys: set[ + tuple[int, int, int, int] + ] = field(default_factory=set) decode_indexer_schedule_metadata: dict[tuple[int, int, int], torch.Tensor] = field( default_factory=dict ) + decode_indexer_plan_cache: dict[tuple[int, int, int], Any] = field( + default_factory=dict + ) + decode_indexer_plan_refreshed_keys: set[tuple[int, int, int]] = field( + default_factory=set + ) + prefill_indexer_plan_cache: dict[tuple[int, int, int], Any] = field( + default_factory=dict + ) + + def decode_req_count(self) -> int: + return max(0, int(self.req_pool_indices.shape[0]) - int(self.num_prefill_reqs)) + + def decode_token_count(self) -> int: + return max( + 0, + int(self.token_to_req_indices.shape[0]) - int(self.num_prefill_tokens), + ) def _use_decode_compressed_slot_cache(self, positions: torch.Tensor) -> bool: return ( diff --git a/python/tokenspeed/runtime/models/deepseek_v4.py b/python/tokenspeed/runtime/models/deepseek_v4.py index 9fe9db96d..4da61c90e 100644 --- a/python/tokenspeed/runtime/models/deepseek_v4.py +++ b/python/tokenspeed/runtime/models/deepseek_v4.py @@ -38,8 +38,6 @@ import torch import torch.nn.functional as F -import triton -import triton.language as tl try: # Optional dependency; the module-level wrapper imports the external @@ -49,7 +47,13 @@ except ImportError: deep_gemm = None # type: ignore[assignment] +from tokenspeed_kernel.ops.attention.triton.deepseek_v4 import ( + deepseek_v4_indexer_decode_metadata_compute, +) from tokenspeed_kernel.ops.gemm.fp8_utils import per_token_group_quant_fp8 +from tokenspeed_kernel.ops.moe.triton import ( + stage_deepseek_v4_mega_moe_inputs as _stage_deepseek_v4_mega_moe_inputs, +) from tokenspeed_kernel.ops.routing.cuda import dsv3_router_gemm from tokenspeed_kernel.platform import current_platform from tokenspeed_kernel.thirdparty.cuda import ( @@ -122,6 +126,7 @@ get_colorful_logger, set_weight_attrs, ) +from tokenspeed.runtime.utils.custom_ops import direct_register_custom_op from tokenspeed.runtime.utils.env import global_server_args_dict, pdl_enabled _platform = current_platform() @@ -717,27 +722,186 @@ def _deepseek_v4_indexer_topk_from_cache_batched( return topk +@dataclass(frozen=True) +class _DeepseekV4IndexerPrefillChunk: + token_start: int + token_end: int + req_start: int + req_end: int + query_start: int + query_end: int + skip_kv_gather: bool = False + + +@dataclass(frozen=True) +class _DeepseekV4IndexerPrefillMetadata: + chunk_bounds: torch.Tensor + chunk_plan: torch.Tensor + slots: torch.Tensor + cu_seq_lens: torch.Tensor + cu_start: torch.Tensor + cu_end: torch.Tensor + row_lens: torch.Tensor + + +@dataclass +class _DeepseekV4IndexerDecodeMetadata: + context_lens: torch.Tensor + block_table: torch.Tensor + max_context_len: int + + +def _deepseek_v4_indexer_prefill_max_logits_bytes( + max_logits_bytes: Optional[int] = None, +) -> int: + if max_logits_bytes is not None: + return max(1, int(max_logits_bytes)) + max_logits_mb = global_server_args_dict.get( + "deepseek_v4_indexer_prefill_max_logits_mb", + _DEEPSEEK_V4_INDEXER_PREFILL_MAX_LOGITS_MB, + ) + return max(1, int(max_logits_mb) * 1024 * 1024) + + +def _deepseek_v4_indexer_prefill_workspace_size( + seq_lens_cpu: torch.Tensor, + workspace_size: Optional[int] = None, +) -> int: + if workspace_size is not None: + return max(1, int(workspace_size)) + context_len = global_server_args_dict.get("context_length") + if isinstance(context_len, int) and context_len > 0: + return context_len * 40 + max_seq_len = int(seq_lens_cpu.max().item()) if seq_lens_cpu.numel() else 1 + return max(1, max_seq_len) * 40 + + +def _deepseek_v4_indexer_prefill_request_chunks( + *, + seq_lens_cpu: torch.Tensor, + query_lens_cpu: torch.Tensor, + compress_ratio: int, + num_tokens: int, + max_logits_bytes: Optional[int] = None, + workspace_size: Optional[int] = None, + request_offset: int = 0, +) -> list[_DeepseekV4IndexerPrefillChunk]: + """Build request/query-slice sparse-indexer prefill chunks.""" + + if num_tokens == 0: + return [] + + seq_lens = seq_lens_cpu.detach().cpu().to(torch.int64) + query_lens = query_lens_cpu.detach().cpu().to(torch.int64) + if seq_lens.numel() != query_lens.numel(): + return [] + + query_lens_list = [max(0, int(x)) for x in query_lens.tolist()] + if sum(query_lens_list) != num_tokens: + return [] + + compressed_seq_lens = torch.div( + seq_lens, + max(1, int(compress_ratio)), + rounding_mode="floor", + ) + compressed_seq_lens_list = [max(0, int(x)) for x in compressed_seq_lens.tolist()] + workspace_rows = _deepseek_v4_indexer_prefill_workspace_size( + seq_lens, + workspace_size, + ) + max_logits_elems = ( + _deepseek_v4_indexer_prefill_max_logits_bytes(max_logits_bytes) // 4 + ) + max_logits_elems = max(1, max_logits_elems) + + query_offsets = [0] + for query_len in query_lens_list: + query_offsets.append(query_offsets[-1] + query_len) + + chunks: list[_DeepseekV4IndexerPrefillChunk] = [] + n_reqs = len(query_lens_list) + end = 0 + while end < n_reqs: + start = end + chunk_m = 0 + chunk_n = 0 + while end < n_reqs: + q_len = query_lens_list[end] + seq_len = compressed_seq_lens_list[end] + new_m = chunk_m + q_len + new_n = chunk_n + seq_len + if new_n <= workspace_rows and new_m * new_n <= max_logits_elems: + chunk_m = new_m + chunk_n = new_n + end += 1 + else: + break + + if end == start: + chunk_m = query_lens_list[end] + chunk_n = compressed_seq_lens_list[end] + end += 1 + + if chunk_m <= 0: + continue + + req_start = start + request_offset + req_end = end + request_offset + max_q = max(1, max_logits_elems // chunk_n) if chunk_n > 0 else chunk_m + chunk_token_start = query_offsets[start] + for query_start in range(0, chunk_m, max_q): + query_end = min(query_start + max_q, chunk_m) + chunks.append( + _DeepseekV4IndexerPrefillChunk( + token_start=chunk_token_start + query_start, + token_end=chunk_token_start + query_end, + req_start=req_start, + req_end=req_end, + query_start=query_start, + query_end=query_end, + skip_kv_gather=query_start > 0, + ) + ) + return chunks + + def _deepseek_v4_indexer_prefill_topk_chunks( positions: torch.Tensor, compress_ratio: int, max_logits_bytes: int | None = None, + *, + seq_lens_cpu: Optional[torch.Tensor] = None, + query_lens_cpu: Optional[torch.Tensor] = None, ) -> list[tuple[int, int]]: num_tokens = positions.numel() if num_tokens == 0: return [] - if max_logits_bytes is None: - max_logits_mb = global_server_args_dict.get( - "deepseek_v4_indexer_prefill_max_logits_mb", - _DEEPSEEK_V4_INDEXER_PREFILL_MAX_LOGITS_MB, - ) - max_logits_bytes = max_logits_mb * 1024 * 1024 - max_logits_elems = max(1, int(max_logits_bytes) // 4) - compressed_lens = torch.div( - positions.to(torch.int64) + 1, - compress_ratio, - rounding_mode="floor", - ).clamp_min(0) - lengths = compressed_lens.detach().cpu().tolist() + max_logits_elems = max( + 1, + _deepseek_v4_indexer_prefill_max_logits_bytes(max_logits_bytes) // 4, + ) + lengths: Optional[list[int]] = None + if seq_lens_cpu is not None and query_lens_cpu is not None: + seq_lens_list = seq_lens_cpu.detach().cpu().tolist() + query_lens_list = query_lens_cpu.detach().cpu().tolist() + cpu_lengths: list[int] = [] + for seq_len, query_len in zip(seq_lens_list, query_lens_list): + total_len = int(seq_len) + query_len = max(0, int(query_len)) + prefix_len = max(0, total_len - query_len) + for query_offset in range(query_len): + cpu_lengths.append((prefix_len + query_offset + 1) // compress_ratio) + if len(cpu_lengths) == num_tokens: + lengths = cpu_lengths + + if lengths is None: + compressed_lens = torch.div( + positions.to(torch.int64) + 1, + compress_ratio, + rounding_mode="floor", + ).clamp_min(0) + lengths = compressed_lens.detach().cpu().tolist() chunks: list[tuple[int, int]] = [] end = 0 @@ -884,15 +1048,456 @@ def _deepseek_v4_gather_indexer_mxfp4_cache( return values, scales +def _deepseek_v4_gather_paged_indexer_mxfp4_cache_available() -> bool: + global _DEEPSEEK_V4_PAGED_GATHER_CHECKED + global _DEEPSEEK_V4_PAGED_GATHER_AVAILABLE + if _DEEPSEEK_V4_PAGED_GATHER_CHECKED: + return _DEEPSEEK_V4_PAGED_GATHER_AVAILABLE + try: + from tokenspeed_kernel.thirdparty.cuda.deepseek_v4_attention import ( + has_indexer_mxfp4_paged_gather, + ) + except Exception: + _DEEPSEEK_V4_PAGED_GATHER_AVAILABLE = False + else: + _DEEPSEEK_V4_PAGED_GATHER_AVAILABLE = bool(has_indexer_mxfp4_paged_gather()) + _DEEPSEEK_V4_PAGED_GATHER_CHECKED = True + return _DEEPSEEK_V4_PAGED_GATHER_AVAILABLE + + +def _deepseek_v4_gather_paged_indexer_mxfp4_cache( + cache_2d: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + block_size: int, + out: Optional[tuple[torch.Tensor, torch.Tensor]] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + value_bytes = DEEPSEEK_V4_INDEXER_DIM // 2 + scale_bytes = DEEPSEEK_V4_INDEXER_DIM // DEEPSEEK_V4_MXFP4_BLOCK_SIZE + if out is None: + total_rows = int(cu_seq_lens[-1].item()) if cu_seq_lens.numel() else 0 + values = torch.empty( + (total_rows, value_bytes), + dtype=torch.uint8, + device=cache_2d.device, + ) + scales = torch.empty( + (total_rows, scale_bytes), + dtype=torch.uint8, + device=cache_2d.device, + ) + else: + if out[0].shape[0] != out[1].shape[0]: + raise ValueError( + "DeepSeek V4 paged gather workspace value/scale rows must match, " + f"got values={out[0].shape[0]}, scales={out[1].shape[0]}" + ) + total_rows = int(out[0].shape[0]) + values = out[0][:total_rows] + scales = out[1][:total_rows] + if total_rows == 0: + return values.view(torch.int8), scales.view(torch.int32).squeeze(-1) + + if ( + cache_2d.is_cuda + and block_table.is_cuda + and cu_seq_lens.is_cuda + and _deepseek_v4_gather_paged_indexer_mxfp4_cache_available() + ): + from tokenspeed_kernel.thirdparty.cuda.deepseek_v4_attention import ( + indexer_mxfp4_paged_gather, + ) + + indexer_mxfp4_paged_gather( + kv_cache=cache_2d, + values_out=values, + scales_out=scales, + block_table=block_table, + cu_seq_lens=cu_seq_lens, + cache_block_size=block_size, + ) + return values.view(torch.int8), scales.view(torch.int32).squeeze(-1) + + exact_rows = int(cu_seq_lens[-1].item()) if cu_seq_lens.numel() else 0 + if exact_rows <= 0: + return values.view(torch.int8), scales.view(torch.int32).squeeze(-1) + + req_lens = torch.diff(cu_seq_lens.to(torch.int64)) + req_ids = torch.repeat_interleave( + torch.arange(req_lens.numel(), device=cache_2d.device, dtype=torch.int64), + req_lens.to(device=cache_2d.device), + output_size=exact_rows, + ) + cu_seq_lens_device = cu_seq_lens.to(device=cache_2d.device, dtype=torch.int64) + local = torch.arange(exact_rows, device=cache_2d.device, dtype=torch.int64) + local = local - cu_seq_lens_device[:-1][req_ids] + pages = torch.div(local, block_size, rounding_mode="floor") + page_offsets = local % block_size + block_table_device = block_table.to(device=cache_2d.device, dtype=torch.int64) + slots = block_table_device[req_ids, pages] * block_size + page_offsets + _deepseek_v4_gather_indexer_mxfp4_cache( + cache_2d, + slots, + block_size, + out=(values[:exact_rows], scales[:exact_rows]), + ) + return values.view(torch.int8), scales.view(torch.int32).squeeze(-1) + + +def _deepseek_v4_indexer_prefill_gather_plan( + *, + positions: torch.Tensor, + token_to_req_indices: torch.Tensor, + block_table: torch.Tensor, + cache_block_size: int, + compress_ratio: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]: + num_tokens = positions.numel() + device = positions.device + compressed_lens = torch.div( + positions.to(torch.int64) + 1, + compress_ratio, + rounding_mode="floor", + ).clamp_min(0) + if num_tokens == 0: + empty_i32 = torch.empty(0, dtype=torch.int32, device=device) + empty_i64 = torch.empty(0, dtype=torch.int64, device=device) + return empty_i64, empty_i32, empty_i32, empty_i32, 0 + + req_idx = token_to_req_indices[:num_tokens].to(torch.int64) + new_group = torch.ones(num_tokens, dtype=torch.bool, device=device) + if num_tokens > 1: + new_group[1:] = req_idx[1:] != req_idx[:-1] + group_starts = torch.nonzero(new_group, as_tuple=False).flatten() + group_ends = torch.empty_like(group_starts) + group_ends[:-1] = group_starts[1:] + group_ends[-1] = num_tokens + group_lengths = group_ends - group_starts + group_max_lens = compressed_lens[group_ends - 1].to(torch.int32) + + cu_seq_lens = torch.empty( + group_starts.numel() + 1, + dtype=torch.int32, + device=device, + ) + cu_seq_lens[:1] = 0 + torch.cumsum(group_max_lens, dim=0, out=cu_seq_lens[1:]) + total_k = int(cu_seq_lens[-1].item()) + row_lens = compressed_lens.to(torch.int32) + + group_for_token = torch.repeat_interleave( + torch.arange(group_starts.numel(), device=device, dtype=torch.int64), + group_lengths.to(torch.int64), + output_size=num_tokens, + ) + cu_start = cu_seq_lens[:-1][group_for_token] + cu_end = cu_start + row_lens + max_len = int(group_max_lens.max().item()) if group_max_lens.numel() else 0 + if total_k <= 0: + empty_i64 = torch.empty(0, dtype=torch.int64, device=device) + return empty_i64, cu_start, cu_end, row_lens, max_len + + group_ids = torch.repeat_interleave( + torch.arange(group_starts.numel(), device=device, dtype=torch.int64), + group_max_lens.to(torch.int64), + output_size=total_k, + ) + group_bases = cu_seq_lens[:-1][group_ids].to(torch.int64) + local = torch.arange(total_k, device=device, dtype=torch.int64) - group_bases + req_for_k = req_idx[group_starts][group_ids] + pages = torch.div(local, cache_block_size, rounding_mode="floor") + page_offsets = local % cache_block_size + page_ids = block_table[req_for_k, pages.long()].to(torch.int64) + slots = page_ids * cache_block_size + page_offsets + return slots, cu_start, cu_end, row_lens, max_len + + +def _deepseek_v4_indexer_prefill_request_gather_plan( + *, + seq_lens_cpu: torch.Tensor, + query_lens_cpu: torch.Tensor, + block_table: torch.Tensor, + cache_block_size: int, + compress_ratio: int, + req_start: int, + req_end: int, + query_start: int, + query_end: int, + build_slots: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]: + device = block_table.device + num_rows = max(0, int(query_end) - int(query_start)) + if num_rows == 0 or req_end <= req_start: + empty_i32 = torch.empty(0, dtype=torch.int32, device=device) + empty_i64 = torch.empty(0, dtype=torch.int64, device=device) + return empty_i64, empty_i32, empty_i32, empty_i32, 0 + + seq_lens_list = ( + seq_lens_cpu.detach().cpu().to(torch.int64)[req_start:req_end].tolist() + ) + query_lens_list = ( + query_lens_cpu.detach().cpu().to(torch.int64)[req_start:req_end].tolist() + ) + if len(seq_lens_list) != len(query_lens_list): + empty_i32 = torch.empty(0, dtype=torch.int32, device=device) + empty_i64 = torch.empty(0, dtype=torch.int64, device=device) + return empty_i64, empty_i32, empty_i32, empty_i32, 0 + + ratio = max(1, int(compress_ratio)) + seq_lens_list = [max(0, int(x)) for x in seq_lens_list] + query_lens_list = [max(0, int(x)) for x in query_lens_list] + compressed_lens_list = [seq_len // ratio for seq_len in seq_lens_list] + total_k = sum(compressed_lens_list) + + query_offsets: list[int] = [0] + for query_len in query_lens_list: + query_offsets.append(query_offsets[-1] + query_len) + + req_local_list: list[int] = [] + row_lens_list: list[int] = [] + req_local = 0 + last_req = max(0, len(query_lens_list) - 1) + for row_offset in range(int(query_start), int(query_end)): + while req_local < last_req and row_offset >= query_offsets[req_local + 1]: + req_local += 1 + local_query_offset = row_offset - query_offsets[req_local] + prefix_len = max(0, seq_lens_list[req_local] - query_lens_list[req_local]) + row_lens_list.append((prefix_len + local_query_offset + 1) // ratio) + req_local_list.append(req_local) + max_len = max(row_lens_list) if row_lens_list else 0 + + compressed_lens = torch.tensor( + compressed_lens_list, + dtype=torch.int64, + device=device, + ) + + cu_seq_lens = torch.empty( + compressed_lens.numel() + 1, + dtype=torch.int32, + device=device, + ) + cu_seq_lens[:1] = 0 + torch.cumsum(compressed_lens.to(torch.int32), dim=0, out=cu_seq_lens[1:]) + + req_local_tensor = torch.tensor(req_local_list, dtype=torch.int64, device=device) + row_lens = torch.tensor(row_lens_list, dtype=torch.int32, device=device) + cu_start = cu_seq_lens[:-1][req_local_tensor] + cu_end = cu_start + row_lens + + if total_k <= 0 or not build_slots: + empty_i64 = torch.empty(0, dtype=torch.int64, device=device) + return empty_i64, cu_start, cu_end, row_lens, max_len + + req_ids = torch.repeat_interleave( + torch.arange(req_start, req_end, device=device, dtype=torch.int64), + compressed_lens, + output_size=total_k, + ) + req_local_for_k = req_ids - int(req_start) + group_bases = cu_seq_lens[:-1][req_local_for_k].to(torch.int64) + local = torch.arange(total_k, device=device, dtype=torch.int64) - group_bases + pages = torch.div(local, cache_block_size, rounding_mode="floor") + page_offsets = local % cache_block_size + page_ids = block_table[req_ids, pages.long()].to(torch.int64) + slots = page_ids * cache_block_size + page_offsets + return slots, cu_start, cu_end, row_lens, max_len + + +def _deepseek_v4_indexer_prefill_chunk_total_rows( + *, + seq_lens_cpu: torch.Tensor, + compress_ratio: int, + req_start: int, + req_end: int, +) -> int: + ratio = max(1, int(compress_ratio)) + seq_lens = seq_lens_cpu.detach().cpu().to(torch.int64)[req_start:req_end].tolist() + return sum(max(0, int(seq_len)) // ratio for seq_len in seq_lens) + + +def _deepseek_v4_empty_indexer_prefill_metadata( + device: torch.device, +) -> _DeepseekV4IndexerPrefillMetadata: + return _DeepseekV4IndexerPrefillMetadata( + chunk_bounds=torch.empty((0, 7), dtype=torch.int64, device="cpu"), + chunk_plan=torch.empty((0, 7), dtype=torch.int64, device="cpu"), + slots=torch.empty(0, dtype=torch.int64, device=device), + cu_seq_lens=torch.empty(0, dtype=torch.int32, device=device), + cu_start=torch.empty(0, dtype=torch.int32, device=device), + cu_end=torch.empty(0, dtype=torch.int32, device=device), + row_lens=torch.empty(0, dtype=torch.int32, device=device), + ) + + +def _deepseek_v4_indexer_prefill_metadata( + *, + metadata: Any, + block_table: torch.Tensor, + cache_block_size: int, + compress_ratio: int, + num_prefill_tokens: int, +) -> _DeepseekV4IndexerPrefillMetadata: + device = block_table.device + if num_prefill_tokens <= 0: + return _deepseek_v4_empty_indexer_prefill_metadata(device) + + seq_lens_cpu = getattr(metadata, "seq_lens_cpu", None) + query_lens_cpu = getattr(metadata, "query_lens_cpu", None) + num_prefill_reqs = int(getattr(metadata, "num_prefill_reqs", 0) or 0) + if seq_lens_cpu is None or query_lens_cpu is None or num_prefill_reqs <= 0: + return _deepseek_v4_empty_indexer_prefill_metadata(device) + + seq_lens_cpu = seq_lens_cpu[:num_prefill_reqs] + query_lens_cpu = query_lens_cpu[:num_prefill_reqs] + cache_key = (compress_ratio, cache_block_size, num_prefill_tokens) + cache = getattr(metadata, "prefill_indexer_plan_cache", None) + cached = cache.get(cache_key) if cache is not None else None + if cached is not None and cached.slots.device == device: + return cached + + chunks = _deepseek_v4_indexer_prefill_request_chunks( + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + compress_ratio=compress_ratio, + num_tokens=num_prefill_tokens, + ) + if not chunks: + out = _deepseek_v4_empty_indexer_prefill_metadata(device) + if cache is not None: + cache[cache_key] = out + return out + + chunk_bounds_rows: list[list[int]] = [] + chunk_plan_rows: list[list[int]] = [] + slot_parts: list[torch.Tensor] = [] + cu_seq_lens_parts: list[torch.Tensor] = [] + cu_start_parts: list[torch.Tensor] = [] + cu_end_parts: list[torch.Tensor] = [] + row_lens_parts: list[torch.Tensor] = [] + slot_offset = 0 + cu_seq_offset = 0 + row_offset = 0 + for chunk in chunks: + slots, cu_start, cu_end, row_lens, max_len = ( + _deepseek_v4_indexer_prefill_request_gather_plan( + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + block_table=block_table, + cache_block_size=cache_block_size, + compress_ratio=compress_ratio, + req_start=chunk.req_start, + req_end=chunk.req_end, + query_start=chunk.query_start, + query_end=chunk.query_end, + build_slots=False, + ) + ) + slot_count = _deepseek_v4_indexer_prefill_chunk_total_rows( + seq_lens_cpu=seq_lens_cpu, + compress_ratio=compress_ratio, + req_start=chunk.req_start, + req_end=chunk.req_end, + ) + compressed_lens = torch.div( + seq_lens_cpu[chunk.req_start : chunk.req_end].to( + dtype=torch.int32, + device=device, + ), + max(1, int(compress_ratio)), + rounding_mode="floor", + ) + cu_seq_lens = torch.empty( + compressed_lens.numel() + 1, + dtype=torch.int32, + device=device, + ) + cu_seq_lens[:1] = 0 + torch.cumsum(compressed_lens, dim=0, out=cu_seq_lens[1:]) + slot_end = slot_offset + slot_count + cu_seq_end = cu_seq_offset + cu_seq_lens.numel() + row_end = row_offset + row_lens.numel() + chunk_bounds_rows.append( + [ + chunk.token_start, + chunk.token_end, + chunk.req_start, + chunk.req_end, + chunk.query_start, + chunk.query_end, + 1 if chunk.skip_kv_gather else 0, + ] + ) + chunk_plan_rows.append( + [ + slot_offset, + slot_end, + row_offset, + row_end, + max_len, + cu_seq_offset, + cu_seq_end, + ] + ) + if slots.numel() > 0: + slot_parts.append(slots) + cu_seq_lens_parts.append(cu_seq_lens) + cu_start_parts.append(cu_start) + cu_end_parts.append(cu_end) + row_lens_parts.append(row_lens) + slot_offset = slot_end + cu_seq_offset = cu_seq_end + row_offset = row_end + + out = _DeepseekV4IndexerPrefillMetadata( + chunk_bounds=torch.tensor(chunk_bounds_rows, dtype=torch.int64, device="cpu"), + chunk_plan=torch.tensor(chunk_plan_rows, dtype=torch.int64, device="cpu"), + slots=( + torch.cat(slot_parts, dim=0) + if slot_parts + else torch.empty(0, dtype=torch.int64, device=device) + ), + cu_seq_lens=( + torch.cat(cu_seq_lens_parts, dim=0) + if cu_seq_lens_parts + else torch.empty(0, dtype=torch.int32, device=device) + ), + cu_start=( + torch.cat(cu_start_parts, dim=0) + if cu_start_parts + else torch.empty(0, dtype=torch.int32, device=device) + ), + cu_end=( + torch.cat(cu_end_parts, dim=0) + if cu_end_parts + else torch.empty(0, dtype=torch.int32, device=device) + ), + row_lens=( + torch.cat(row_lens_parts, dim=0) + if row_lens_parts + else torch.empty(0, dtype=torch.int32, device=device) + ), + ) + if cache is not None: + cache[cache_key] = out + return out + + def _deepseek_v4_indexer_topk_from_logits( logits: torch.Tensor, lengths: torch.Tensor, topk_tokens: int, *, + next_n: int = 1, preserve_topk_order: bool = False, - out: torch.Tensor | None = None, + sort_preserved_topk: Optional[bool] = None, + row_starts: Optional[torch.Tensor] = None, + row_ends: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, ) -> torch.Tensor: - num_tokens = lengths.numel() + lengths_for_kernel = lengths.to(torch.int32).contiguous() + length_rows = lengths_for_kernel.reshape(-1) + num_tokens = length_rows.numel() if out is None: topk = torch.empty( (num_tokens, topk_tokens), @@ -908,11 +1513,33 @@ def _deepseek_v4_indexer_topk_from_logits( if max_len <= 0: return topk + row_starts_for_kernel: Optional[torch.Tensor] = None + row_ends_for_kernel: Optional[torch.Tensor] = None + if row_starts is not None or row_ends is not None: + if row_starts is None: + row_starts_for_kernel = torch.zeros_like(length_rows) + else: + row_starts_for_kernel = row_starts.to( + device=logits.device, dtype=torch.int32 + ).reshape(-1) + if row_ends is None: + row_ends_for_kernel = row_starts_for_kernel + length_rows + else: + row_ends_for_kernel = row_ends.to( + device=logits.device, dtype=torch.int32 + ).reshape(-1) + length_rows = (row_ends_for_kernel - row_starts_for_kernel).clamp_min(0) + + if sort_preserved_topk is None: + sort_preserved_topk = False + if preserve_topk_order: prefill_topk = _deepseek_v4_indexer_topk_from_logits_prefill_op( logits, - lengths.to(torch.int32).reshape(-1), + length_rows, topk_tokens, + row_starts=row_starts_for_kernel, + row_ends=row_ends_for_kernel, out=topk, ) if prefill_topk is not None: @@ -923,29 +1550,55 @@ def _deepseek_v4_indexer_topk_from_logits( fast_topk_v2( logits.contiguous(), - lengths.to(torch.int32).contiguous(), + lengths_for_kernel, topk, topk_tokens, + next_n, ) return topk offsets = torch.arange(max_len, device=logits.device, dtype=torch.int64) + if row_starts_for_kernel is not None and row_ends_for_kernel is not None: + row_starts_i64 = row_starts_for_kernel.to(torch.int64) + row_ends_i64 = row_ends_for_kernel.to(torch.int64) + valid = (offsets[None, :] >= row_starts_i64[:, None]) & ( + offsets[None, :] < row_ends_i64[:, None] + ) + masked_logits = logits.masked_fill(~valid, -float("inf")) + selected = min(int(length_rows.max().item()), topk_tokens) + if selected <= 0: + return topk + values, indices = torch.topk( + masked_logits, + k=selected, + dim=-1, + sorted=bool(sort_preserved_topk), + ) + indices = indices - row_starts_i64[:, None] + indices = torch.where( + torch.isfinite(values), + indices, + torch.full_like(indices, -1), + ).to(torch.int32) + topk[:, :selected] = indices + return topk + masked_logits = logits.masked_fill( - offsets[None, :] >= lengths[:, None], -float("inf") + offsets[None, :] >= length_rows[:, None], -float("inf") ) if preserve_topk_order: - for raw_len in torch.unique(lengths).tolist(): + for raw_len in torch.unique(length_rows).tolist(): num_compressed = int(raw_len) selected = min(num_compressed, topk_tokens) if selected <= 0: continue - row_mask = lengths == num_compressed + row_mask = length_rows == num_compressed token_topk = torch.topk( masked_logits[row_mask, :num_compressed], k=selected, dim=-1, - sorted=False, + sorted=sort_preserved_topk, ).indices topk[row_mask, :selected] = token_topk.to(torch.int32) return topk @@ -961,24 +1614,22 @@ def _deepseek_v4_indexer_topk_from_logits( return topk -_DEEPSEEK_V4_PREFILL_TOPK_OP_AVAILABLE = False -_DEEPSEEK_V4_PREFILL_TOPK_OP_CHECKED = False - - def _deepseek_v4_prefill_topk_op_available() -> bool: - global _DEEPSEEK_V4_PREFILL_TOPK_OP_AVAILABLE global _DEEPSEEK_V4_PREFILL_TOPK_OP_CHECKED + global _DEEPSEEK_V4_PREFILL_TOPK_OP_AVAILABLE if _DEEPSEEK_V4_PREFILL_TOPK_OP_CHECKED: return _DEEPSEEK_V4_PREFILL_TOPK_OP_AVAILABLE try: - from tokenspeed_kernel.thirdparty.cuda.deepseek_v4_attention import ( - has_indexer_topk_prefill, - ) + import tokenspeed_kernel.thirdparty.trtllm # noqa: F401 except Exception: _DEEPSEEK_V4_PREFILL_TOPK_OP_AVAILABLE = False else: - _DEEPSEEK_V4_PREFILL_TOPK_OP_AVAILABLE = bool(has_indexer_topk_prefill()) + trtllm_ops = getattr(torch.ops, "trtllm", None) + _DEEPSEEK_V4_PREFILL_TOPK_OP_AVAILABLE = trtllm_ops is not None and hasattr( + trtllm_ops, + "indexer_topk_prefill", + ) _DEEPSEEK_V4_PREFILL_TOPK_OP_CHECKED = True return _DEEPSEEK_V4_PREFILL_TOPK_OP_AVAILABLE @@ -988,9 +1639,11 @@ def _deepseek_v4_indexer_topk_from_logits_prefill_op( length_rows: torch.Tensor, topk_tokens: int, *, + row_starts: Optional[torch.Tensor] = None, + row_ends: Optional[torch.Tensor] = None, out: torch.Tensor, ) -> Optional[torch.Tensor]: - """Use the local CUDA prefill selector when the extension is available.""" + """Use the local TRT-LLM CUDA prefill selector.""" if not logits.is_cuda or logits.dtype != torch.float32: return None @@ -1001,49 +1654,48 @@ def _deepseek_v4_indexer_topk_from_logits_prefill_op( if num_rows == 0: return out[:0] logits = logits.contiguous() - row_starts = torch.zeros(num_rows, device=logits.device, dtype=torch.int32) - row_ends = length_rows.to(device=logits.device, dtype=torch.int32).reshape(-1) + if row_starts is None: + row_starts_for_kernel = torch.zeros( + num_rows, + device=logits.device, + dtype=torch.int32, + ) + else: + row_starts_for_kernel = ( + row_starts.to( + device=logits.device, + dtype=torch.int32, + ) + .reshape(-1) + .contiguous() + ) + if row_ends is None: + row_ends_for_kernel = ( + row_starts_for_kernel + + length_rows.to(device=logits.device, dtype=torch.int32).reshape(-1) + ).contiguous() + else: + row_ends_for_kernel = ( + row_ends.to( + device=logits.device, + dtype=torch.int32, + ) + .reshape(-1) + .contiguous() + ) topk = out[:num_rows] topk.fill_(-1) - from tokenspeed_kernel.thirdparty.cuda.deepseek_v4_attention import ( - indexer_topk_prefill, - ) - - indexer_topk_prefill( + torch.ops.trtllm.indexer_topk_prefill( logits, - row_starts, - row_ends.contiguous(), + row_starts_for_kernel, + row_ends_for_kernel, topk, topk_tokens, ) return topk -def _deepseek_v4_indexer_ascending_prefill_topk( - positions: torch.Tensor, - compress_ratio: int, - topk_tokens: int, -) -> torch.Tensor: - num_tokens = positions.numel() - offsets = torch.arange(topk_tokens, device=positions.device, dtype=torch.int32) - lengths = torch.div( - positions.to(torch.int64) + 1, - compress_ratio, - rounding_mode="floor", - ).clamp(min=0, max=topk_tokens) - return torch.where( - offsets[None, :] < lengths[:, None], - offsets[None, :], - torch.full( - (num_tokens, topk_tokens), - -1, - device=positions.device, - dtype=torch.int32, - ), - ) - - def _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill( *, cache_2d: torch.Tensor, @@ -1068,12 +1720,15 @@ def _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill( device=positions.device, dtype=torch.int32, ) - compressed_lens = torch.div( - positions.to(torch.int64) + 1, - compress_ratio, - rounding_mode="floor", - ).clamp_min(0) - max_len = int(compressed_lens.max().item()) + slots, cu_start, cu_end, row_lens, max_len = ( + _deepseek_v4_indexer_prefill_gather_plan( + positions=positions, + token_to_req_indices=token_to_req_indices, + block_table=block_table, + cache_block_size=cache_block_size, + compress_ratio=compress_ratio, + ) + ) if max_len <= 0: return torch.full( (num_tokens, topk_tokens), @@ -1081,27 +1736,12 @@ def _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill( device=positions.device, dtype=torch.int32, ) - - offsets = torch.arange(max_len, device=positions.device, dtype=torch.int64) - local = offsets[None, :].expand(num_tokens, -1) - valid = local < compressed_lens[:, None] - req_idx = token_to_req_indices[:num_tokens].to(torch.int64) - pages = torch.div(local, cache_block_size, rounding_mode="floor") - page_offsets = local % cache_block_size - page_ids = block_table[req_idx[:, None], pages.long()].to(torch.int64) - slots = page_ids * cache_block_size + page_offsets - with deepseek_v4_profile_scope("indexer_topk_prefill_gather_mxfp4"): k_values, k_scales = _deepseek_v4_gather_indexer_mxfp4_cache( cache_2d, - slots[valid], + slots, cache_block_size, ) - row_lens = valid.sum(dim=1, dtype=torch.int32) - cu_end = torch.cumsum(row_lens, dim=0, dtype=torch.int32) - cu_start = torch.empty_like(cu_end) - cu_start[0] = 0 - cu_start[1:] = cu_end[:-1] try: with deepseek_v4_profile_scope("indexer_topk_prefill_deepgemm_logits"): @@ -1127,97 +1767,994 @@ def _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill( ) -def _deepseek_v4_indexer_topk_from_cache_deepgemm_decode( +def _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill_plan( + *, + cache_2d: torch.Tensor, + gather_plan: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int], + index_q: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cache_block_size: int, + topk_tokens: int, + preserve_topk_order: bool, + gathered_k: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + gather_workspace: Optional[tuple[torch.Tensor, torch.Tensor]] = None, +) -> tuple[Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: + q_values, q_scales = index_q + if not _deepseek_v4_deepgemm_fp4_indexer_available(q_values): + return None, gathered_k + + num_tokens = q_values.shape[0] + slots, cu_start, cu_end, row_lens, max_len = gather_plan + if num_tokens == 0: + return ( + torch.empty( + (0, topk_tokens), + device=q_values.device, + dtype=torch.int32, + ), + gathered_k, + ) + if max_len <= 0: + return ( + torch.full( + (num_tokens, topk_tokens), + -1, + device=q_values.device, + dtype=torch.int32, + ), + gathered_k, + ) + + if gathered_k is None: + with deepseek_v4_profile_scope("indexer_topk_prefill_gather_mxfp4"): + gathered_k = _deepseek_v4_gather_indexer_mxfp4_cache( + cache_2d, + slots, + cache_block_size, + out=gather_workspace, + ) + k_values, k_scales = gathered_k + + try: + with deepseek_v4_profile_scope("indexer_topk_prefill_deepgemm_logits"): + logits = deep_gemm.fp8_fp4_mqa_logits( + q=(q_values.contiguous().view(torch.int8), q_scales.contiguous()), + kv=(k_values.contiguous(), k_scales.contiguous()), + weights=weights.contiguous(), + cu_seq_len_k_start=cu_start, + cu_seq_len_k_end=cu_end, + clean_logits=False, + max_seqlen_k=max_len, + logits_dtype=torch.float32, + ) + except RuntimeError: + return None, gathered_k + + with deepseek_v4_profile_scope("indexer_topk_prefill_select"): + return ( + _deepseek_v4_indexer_topk_from_logits( + logits, + row_lens, + topk_tokens, + preserve_topk_order=preserve_topk_order, + ), + gathered_k, + ) + + +def _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill_contract( + *, + cache_2d: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + cu_start: torch.Tensor, + cu_end: torch.Tensor, + row_lens: torch.Tensor, + max_len: int, + index_q: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cache_block_size: int, + topk_tokens: int, + preserve_topk_order: bool, + gathered_k: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + gather_workspace: Optional[tuple[torch.Tensor, torch.Tensor]] = None, +) -> tuple[Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: + q_values, q_scales = index_q + if not _deepseek_v4_deepgemm_fp4_indexer_available(q_values): + return None, gathered_k + + num_tokens = q_values.shape[0] + if num_tokens == 0: + return ( + torch.empty( + (0, topk_tokens), + device=q_values.device, + dtype=torch.int32, + ), + gathered_k, + ) + if max_len <= 0: + return ( + torch.full( + (num_tokens, topk_tokens), + -1, + device=q_values.device, + dtype=torch.int32, + ), + gathered_k, + ) + + if gathered_k is None: + with deepseek_v4_profile_scope("indexer_topk_prefill_gather_paged_mxfp4"): + gathered_k = _deepseek_v4_gather_paged_indexer_mxfp4_cache( + cache_2d, + block_table, + cu_seq_lens, + cache_block_size, + out=gather_workspace, + ) + k_values, k_scales = gathered_k + + try: + with deepseek_v4_profile_scope("indexer_topk_prefill_deepgemm_logits"): + logits = deep_gemm.fp8_fp4_mqa_logits( + q=(q_values.contiguous().view(torch.int8), q_scales.contiguous()), + kv=(k_values.contiguous(), k_scales.contiguous()), + weights=weights.contiguous(), + cu_seq_len_k_start=cu_start, + cu_seq_len_k_end=cu_end, + clean_logits=False, + max_seqlen_k=max_len, + logits_dtype=torch.float32, + ) + except RuntimeError: + return None, gathered_k + + with deepseek_v4_profile_scope("indexer_topk_prefill_select"): + return ( + _deepseek_v4_indexer_topk_from_logits( + logits, + row_lens, + topk_tokens, + preserve_topk_order=preserve_topk_order, + ), + gathered_k, + ) + + +def _deepseek_v4_indexer_decode_metadata( + *, + positions: torch.Tensor, + token_to_req_indices: torch.Tensor, + block_table: torch.Tensor, + cache_block_size: int, + compress_ratio: int, + metadata: Optional[Any] = None, + is_valid_token: Optional[torch.Tensor] = None, +) -> _DeepseekV4IndexerDecodeMetadata: + num_tokens = positions.numel() + key = (int(compress_ratio), int(cache_block_size), int(num_tokens)) + cache = getattr(metadata, "decode_indexer_plan_cache", None) + refreshed_keys = getattr(metadata, "decode_indexer_plan_refreshed_keys", None) + cached = cache.get(key) if cache is not None else None + # Hot path: the attention metadata builder hook + # (_refresh_decode_indexer_plan_cache in backends/deepseek_v4.py) pre-builds + # the plan tensors at metadata setup time and adds the key to + # refreshed_keys. The metadata builder also clears refreshed_keys at the + # start of each refresh so a stale entry from a previous step cannot + # cause an early-return with capture-time data. By returning the cached plan + # here, the per-layer + # `run_indexer` call dispatched into `_deepseek_v4_maybe_execute_in_parallel` + # becomes a pure read, eliminating the cross-stream allocator race + # against `insert_and_compress` on aux_stream. + if cached is not None and refreshed_keys is not None and key in refreshed_keys: + return cached + + if num_tokens == 0: + context_lens = torch.empty((0, 1), dtype=torch.int32, device=positions.device) + block_tables = torch.empty( + (0, 1), + dtype=torch.int32, + device=block_table.device, + ) + plan = _DeepseekV4IndexerDecodeMetadata(context_lens, block_tables, 0) + if cache is not None: + cache[key] = plan + if refreshed_keys is not None: + refreshed_keys.add(key) + return plan + + rows = int(block_table.shape[0]) if block_table.ndim >= 1 else 0 + cols = int(block_table.shape[1]) if block_table.ndim >= 2 else 0 + max_len = _deepseek_v4_indexer_decode_max_len( + block_table, + cache_block_size, + compress_ratio, + ) + max_blocks = max(1, (max_len + cache_block_size - 1) // cache_block_size) + + expected_context_shape = (num_tokens, 1) + expected_block_shape = (num_tokens, max_blocks) + if ( + cached is None + or cached.context_lens.shape != expected_context_shape + or cached.context_lens.device != positions.device + or cached.context_lens.dtype != torch.int32 + or cached.block_table.shape != expected_block_shape + or cached.block_table.device != block_table.device + or cached.block_table.dtype != torch.int32 + ): + context_lens = torch.empty( + expected_context_shape, + dtype=torch.int32, + device=positions.device, + ) + block_tables = torch.empty( + expected_block_shape, + dtype=torch.int32, + device=block_table.device, + ) + plan = _DeepseekV4IndexerDecodeMetadata( + context_lens=context_lens, + block_table=block_tables, + max_context_len=max_len, + ) + if cache is not None: + cache[key] = plan + else: + plan = cached + plan.max_context_len = max_len + + if rows <= 0 or cols <= 0: + plan.context_lens.zero_() + plan.block_table.zero_() + plan.max_context_len = 0 + else: + deepseek_v4_indexer_decode_metadata_compute( + positions=positions, + token_to_req_indices=token_to_req_indices, + block_table=block_table, + cache_block_size=cache_block_size, + compress_ratio=compress_ratio, + max_blocks=max_blocks, + out_context_lens=plan.context_lens, + out_block_tables=plan.block_table, + ) + if is_valid_token is None: + is_valid_token = getattr(metadata, "is_valid_token", None) + if is_valid_token is not None: + valid = is_valid_token[:num_tokens].to( + device=plan.context_lens.device, + dtype=torch.bool, + ) + with torch.inference_mode(): + plan.context_lens.masked_fill_(~valid.view(num_tokens, 1), 0) + plan.block_table.masked_fill_( + ~valid.to(device=plan.block_table.device).view(num_tokens, 1), + 0, + ) + if refreshed_keys is not None: + refreshed_keys.add(key) + return plan + + +def _deepseek_v4_indexer_topk_from_cache_deepgemm_decode( + *, + cache_2d: torch.Tensor, + positions: torch.Tensor, + token_to_req_indices: torch.Tensor, + block_table: torch.Tensor, + cache_block_size: int, + index_q: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + compress_ratio: int, + topk_tokens: int, + metadata: Optional[Any] = None, + schedule_metadata: Optional[torch.Tensor] = None, + decode_context_lens: Optional[torch.Tensor] = None, + decode_block_table: Optional[torch.Tensor] = None, + decode_max_context_len: Optional[int] = None, + is_valid_token: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, +) -> Optional[torch.Tensor]: + q_values, q_scales = index_q + if not _deepseek_v4_deepgemm_fp4_indexer_available(q_values): + return None + + num_tokens = positions.numel() + if num_tokens == 0: + if out is not None: + return out[:0] + return torch.empty((0, topk_tokens), device=positions.device, dtype=torch.int32) + if decode_context_lens is not None and decode_block_table is not None: + context_lens = decode_context_lens + block_tables = decode_block_table + max_len = ( + int(decode_max_context_len) + if decode_max_context_len is not None + else int(context_lens.max().item()) + ) + else: + decode_plan = _deepseek_v4_indexer_decode_metadata( + positions=positions, + token_to_req_indices=token_to_req_indices, + block_table=block_table, + cache_block_size=cache_block_size, + compress_ratio=compress_ratio, + metadata=metadata, + is_valid_token=is_valid_token, + ) + context_lens = decode_plan.context_lens + block_tables = decode_plan.block_table + max_len = decode_plan.max_context_len + topk = ( + torch.empty( + (num_tokens, topk_tokens), + device=positions.device, + dtype=torch.int32, + ) + if out is None + else out[:num_tokens] + ) + if max_len <= 0: + topk.fill_(-1) + return topk + kv_cache = _deepseek_v4_indexer_mxfp4_cache_view(cache_2d, cache_block_size) + schedule_key = (compress_ratio, cache_block_size, num_tokens) + schedule_cache = getattr(metadata, "decode_indexer_schedule_metadata", None) + if schedule_metadata is None: + schedule_metadata = ( + schedule_cache.get(schedule_key) if schedule_cache is not None else None + ) + if schedule_metadata is None: + with deepseek_v4_profile_scope("indexer_decode_schedule_metadata"): + schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata( + context_lens, + cache_block_size, + deep_gemm.get_num_sms(), + ) + if schedule_cache is not None: + schedule_cache[schedule_key] = schedule_metadata + + try: + with deepseek_v4_profile_scope("indexer_decode_deepgemm_logits"): + logits = deep_gemm.fp8_fp4_paged_mqa_logits( + q=( + q_values.contiguous().view(torch.int8).unsqueeze(1), + q_scales.contiguous().unsqueeze(1), + ), + kv_cache=kv_cache, + weights=weights.contiguous(), + context_lens=context_lens, + block_table=block_tables, + schedule_meta=schedule_metadata, + max_context_len=max_len, + clean_logits=False, + logits_dtype=torch.float32, + ) + except RuntimeError: + return None + + with deepseek_v4_profile_scope("indexer_decode_topk"): + return _deepseek_v4_indexer_topk_from_logits( + logits, + context_lens, + topk_tokens, + next_n=1, + out=out, + ) + + +def _deepseek_v4_indexer_decode_schedule_metadata( + *, + positions: torch.Tensor, + cache_block_size: int, + compress_ratio: int, + metadata: Optional[Any], + context_lens: Optional[torch.Tensor] = None, +) -> Optional[torch.Tensor]: + if positions.numel() == 0: + return None + if getattr(deep_gemm, "get_paged_mqa_logits_metadata", None) is None: + return None + + num_tokens = positions.numel() + if context_lens is None: + compressed_lens = torch.div( + positions.to(torch.int64) + 1, + compress_ratio, + rounding_mode="floor", + ).clamp_min(0) + context_lens = compressed_lens.to(torch.int32).view(num_tokens, 1).contiguous() + schedule_key = (compress_ratio, cache_block_size, num_tokens) + schedule_cache = getattr(metadata, "decode_indexer_schedule_metadata", None) + schedule_metadata = ( + schedule_cache.get(schedule_key) if schedule_cache is not None else None + ) + + with deepseek_v4_profile_scope("indexer_decode_schedule_metadata"): + refreshed = deep_gemm.get_paged_mqa_logits_metadata( + context_lens, + cache_block_size, + deep_gemm.get_num_sms(), + ) + if schedule_metadata is not None: + if ( + schedule_metadata.shape == refreshed.shape + and schedule_metadata.device == refreshed.device + and schedule_metadata.dtype == refreshed.dtype + ): + with torch.inference_mode(): + schedule_metadata.copy_(refreshed) + return schedule_metadata + if schedule_cache is not None: + schedule_cache[schedule_key] = refreshed + return refreshed + schedule_metadata = refreshed + if schedule_cache is not None: + schedule_cache[schedule_key] = schedule_metadata + return schedule_metadata + + +def _deepseek_v4_sparse_attn_indexer_native( + *, + cache_2d: torch.Tensor, + positions: torch.Tensor, + token_to_req_indices: torch.Tensor, + block_table: torch.Tensor, + seq_lens_cpu: torch.Tensor, + query_lens_cpu: torch.Tensor, + prefill_chunk_bounds: torch.Tensor, + prefill_chunk_plan: torch.Tensor, + prefill_slots: torch.Tensor, + prefill_cu_seq_lens: torch.Tensor, + prefill_cu_start: torch.Tensor, + prefill_cu_end: torch.Tensor, + prefill_row_lens: torch.Tensor, + packed_q_values: torch.Tensor, + packed_q_scales: torch.Tensor, + packed_weights: torch.Tensor, + fallback_index_q: torch.Tensor, + fallback_weights: torch.Tensor, + decode_schedule_metadata: Optional[torch.Tensor], + decode_context_lens: Optional[torch.Tensor], + decode_block_table: Optional[torch.Tensor], + decode_max_context_len: int, + topk_indices_buffer: torch.Tensor, + prefill_gather_values_workspace: torch.Tensor, + prefill_gather_scales_workspace: torch.Tensor, + cache_block_size: int, + compress_ratio: int, + topk_tokens: int, + num_prefill_tokens: int, + num_decode_tokens: int, + use_fp4_cache: bool, + has_packed_q: bool, +) -> torch.Tensor: + total_tokens = positions.numel() + topk_out = topk_indices_buffer[:total_tokens] + topk_out.fill_(-1) + if total_tokens == 0: + return topk_out + + cache_reader = ( + read_deepseek_v4_indexer_mxfp4_cache + if use_fp4_cache + else read_deepseek_v4_indexer_fp8_cache + ) + + def fill_prefill() -> None: + if num_prefill_tokens <= 0: + return + + prefill_positions = positions[:num_prefill_tokens] + if prefill_chunk_bounds.numel() > 0: + gather_cache_key = None + gathered_k = None + num_chunks = prefill_chunk_bounds.shape[0] + for chunk_idx in range(num_chunks): + bounds = prefill_chunk_bounds[chunk_idx] + plan = prefill_chunk_plan[chunk_idx] + token_start = int(bounds[0].item()) + token_end = int(bounds[1].item()) + req_start = int(bounds[2].item()) + req_end = int(bounds[3].item()) + skip_kv_gather = bool(int(bounds[6].item())) + slot_start = int(plan[0].item()) + slot_end = int(plan[1].item()) + row_start = int(plan[2].item()) + row_end = int(plan[3].item()) + max_len = int(plan[4].item()) + cu_seq_start = int(plan[5].item()) if plan.numel() > 5 else 0 + cu_seq_end = int(plan[6].item()) if plan.numel() > 6 else 0 + gather_rows = max(0, slot_end - slot_start) + gather_plan = ( + prefill_slots[slot_start:slot_end], + prefill_cu_start[row_start:row_end], + prefill_cu_end[row_start:row_end], + prefill_row_lens[row_start:row_end], + max_len, + ) + gather_workspace = None + if ( + prefill_gather_values_workspace.numel() > 0 + and prefill_gather_scales_workspace.numel() > 0 + and gather_rows <= prefill_gather_values_workspace.shape[0] + and gather_rows <= prefill_gather_scales_workspace.shape[0] + ): + gather_workspace = ( + prefill_gather_values_workspace[:gather_rows], + prefill_gather_scales_workspace[:gather_rows], + ) + topk = None + if has_packed_q: + with deepseek_v4_profile_scope("indexer_topk_deepgemm_prefill"): + key = (req_start, req_end) + reuse_k = ( + gathered_k + if skip_kv_gather and gather_cache_key == key + else None + ) + if ( + prefill_cu_seq_lens.numel() > 0 + and cu_seq_end > cu_seq_start + ): + topk, next_gathered_k = ( + _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill_contract( + cache_2d=cache_2d, + block_table=block_table[req_start:req_end], + cu_seq_lens=prefill_cu_seq_lens[ + cu_seq_start:cu_seq_end + ], + cu_start=prefill_cu_start[row_start:row_end], + cu_end=prefill_cu_end[row_start:row_end], + row_lens=prefill_row_lens[row_start:row_end], + max_len=max_len, + cache_block_size=cache_block_size, + index_q=( + packed_q_values[token_start:token_end], + packed_q_scales[token_start:token_end], + ), + weights=packed_weights[token_start:token_end], + topk_tokens=topk_tokens, + preserve_topk_order=True, + gathered_k=reuse_k, + gather_workspace=gather_workspace, + ) + ) + else: + topk, next_gathered_k = ( + _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill_plan( + cache_2d=cache_2d, + gather_plan=gather_plan, + cache_block_size=cache_block_size, + index_q=( + packed_q_values[token_start:token_end], + packed_q_scales[token_start:token_end], + ), + weights=packed_weights[token_start:token_end], + topk_tokens=topk_tokens, + preserve_topk_order=True, + gathered_k=reuse_k, + gather_workspace=gather_workspace, + ) + ) + if topk is not None and next_gathered_k is not None: + gather_cache_key = key + gathered_k = next_gathered_k + if topk is None and fallback_index_q.numel() > 0: + with deepseek_v4_profile_scope("indexer_topk_fallback_prefill"): + topk = _deepseek_v4_indexer_topk_from_cache_batched( + cache_reader=cache_reader, + cache_2d=cache_2d, + positions=prefill_positions[token_start:token_end], + token_to_req_indices=token_to_req_indices[ + token_start:token_end + ], + block_table=block_table, + cache_block_size=cache_block_size, + index_q=fallback_index_q[token_start:token_end], + weights=fallback_weights[token_start:token_end], + compress_ratio=compress_ratio, + topk_tokens=topk_tokens, + preserve_topk_order=True, + ) + if topk is None: + raise RuntimeError( + "DeepSeek V4 sparse indexer prefill DeepGEMM path failed " + "without a prepared fallback." + ) + if topk is not None: + topk_out[token_start:token_end].copy_(topk) + return + + topk_chunks = [] + for start, end in _deepseek_v4_indexer_prefill_topk_chunks( + prefill_positions, + compress_ratio, + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + ): + topk = None + if has_packed_q: + with deepseek_v4_profile_scope("indexer_topk_deepgemm_prefill"): + topk = _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill( + cache_2d=cache_2d, + positions=prefill_positions[start:end], + token_to_req_indices=token_to_req_indices[start:end], + block_table=block_table, + cache_block_size=cache_block_size, + index_q=( + packed_q_values[start:end], + packed_q_scales[start:end], + ), + weights=packed_weights[start:end], + compress_ratio=compress_ratio, + topk_tokens=topk_tokens, + preserve_topk_order=True, + ) + if topk is None and fallback_index_q.numel() > 0: + with deepseek_v4_profile_scope("indexer_topk_fallback_prefill"): + topk = _deepseek_v4_indexer_topk_from_cache_batched( + cache_reader=cache_reader, + cache_2d=cache_2d, + positions=prefill_positions[start:end], + token_to_req_indices=token_to_req_indices[start:end], + block_table=block_table, + cache_block_size=cache_block_size, + index_q=fallback_index_q[start:end], + weights=fallback_weights[start:end], + compress_ratio=compress_ratio, + topk_tokens=topk_tokens, + preserve_topk_order=True, + ) + if topk is None: + raise RuntimeError( + "DeepSeek V4 sparse indexer prefill DeepGEMM path failed " + "without a prepared fallback." + ) + if topk is not None: + topk_chunks.append(topk) + if topk_chunks: + with deepseek_v4_profile_scope("indexer_topk_cat_prefill"): + topk_out[:num_prefill_tokens].copy_(torch.cat(topk_chunks, dim=0)) + + def fill_decode() -> None: + if num_decode_tokens <= 0: + return + + decode_start = num_prefill_tokens + decode_end = decode_start + num_decode_tokens + decode_positions = positions[decode_start:decode_end] + decode_token_to_req = token_to_req_indices[decode_start:decode_end] + decode_out = topk_out[decode_start:decode_end] + topk = None + if has_packed_q: + with deepseek_v4_profile_scope("indexer_topk_deepgemm_decode"): + topk = _deepseek_v4_indexer_topk_from_cache_deepgemm_decode( + cache_2d=cache_2d, + positions=decode_positions, + token_to_req_indices=decode_token_to_req, + block_table=block_table, + cache_block_size=cache_block_size, + index_q=( + packed_q_values[decode_start:decode_end], + packed_q_scales[decode_start:decode_end], + ), + weights=packed_weights[decode_start:decode_end], + compress_ratio=compress_ratio, + topk_tokens=topk_tokens, + schedule_metadata=decode_schedule_metadata, + decode_context_lens=decode_context_lens, + decode_block_table=decode_block_table, + decode_max_context_len=decode_max_context_len, + out=decode_out, + ) + if topk is None and fallback_index_q.shape[0] >= decode_end: + with deepseek_v4_profile_scope("indexer_topk_fallback_decode"): + _deepseek_v4_indexer_topk_from_cache_batched( + cache_reader=cache_reader, + cache_2d=cache_2d, + positions=decode_positions, + token_to_req_indices=decode_token_to_req, + block_table=block_table, + cache_block_size=cache_block_size, + index_q=fallback_index_q[decode_start:decode_end], + weights=fallback_weights[decode_start:decode_end], + compress_ratio=compress_ratio, + topk_tokens=topk_tokens, + out=decode_out, + ) + topk = decode_out + if topk is None: + raise RuntimeError( + "DeepSeek V4 sparse indexer decode DeepGEMM path failed " + "without a prepared fallback." + ) + + fill_prefill() + fill_decode() + return topk_out + + +def _deepseek_v4_sparse_attn_indexer_op( + cache_2d: torch.Tensor, + positions: torch.Tensor, + token_to_req_indices: torch.Tensor, + block_table: torch.Tensor, + seq_lens_cpu: torch.Tensor, + query_lens_cpu: torch.Tensor, + prefill_chunk_bounds: torch.Tensor, + prefill_chunk_plan: torch.Tensor, + prefill_slots: torch.Tensor, + prefill_cu_seq_lens: torch.Tensor, + prefill_cu_start: torch.Tensor, + prefill_cu_end: torch.Tensor, + prefill_row_lens: torch.Tensor, + packed_q_values: torch.Tensor, + packed_q_scales: torch.Tensor, + packed_weights: torch.Tensor, + fallback_index_q: torch.Tensor, + fallback_weights: torch.Tensor, + decode_schedule_metadata: torch.Tensor, + decode_context_lens: torch.Tensor, + decode_block_table: torch.Tensor, + decode_max_context_len: int, + topk_indices_buffer: torch.Tensor, + prefill_gather_values_workspace: torch.Tensor, + prefill_gather_scales_workspace: torch.Tensor, + cache_block_size: int, + compress_ratio: int, + topk_tokens: int, + num_prefill_tokens: int, + num_decode_tokens: int, + use_fp4_cache: bool, + has_packed_q: bool, +) -> torch.Tensor: + schedule_metadata = ( + decode_schedule_metadata if decode_schedule_metadata.numel() > 0 else None + ) + context_lens = decode_context_lens if decode_context_lens.numel() > 0 else None + decode_blocks = decode_block_table if decode_block_table.numel() > 0 else None + return _deepseek_v4_sparse_attn_indexer_native( + cache_2d=cache_2d, + positions=positions, + token_to_req_indices=token_to_req_indices, + block_table=block_table, + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + prefill_chunk_bounds=prefill_chunk_bounds, + prefill_chunk_plan=prefill_chunk_plan, + prefill_slots=prefill_slots, + prefill_cu_seq_lens=prefill_cu_seq_lens, + prefill_cu_start=prefill_cu_start, + prefill_cu_end=prefill_cu_end, + prefill_row_lens=prefill_row_lens, + packed_q_values=packed_q_values, + packed_q_scales=packed_q_scales, + packed_weights=packed_weights, + fallback_index_q=fallback_index_q, + fallback_weights=fallback_weights, + decode_schedule_metadata=schedule_metadata, + decode_context_lens=context_lens, + decode_block_table=decode_blocks, + decode_max_context_len=decode_max_context_len, + topk_indices_buffer=topk_indices_buffer, + prefill_gather_values_workspace=prefill_gather_values_workspace, + prefill_gather_scales_workspace=prefill_gather_scales_workspace, + cache_block_size=cache_block_size, + compress_ratio=compress_ratio, + topk_tokens=topk_tokens, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + use_fp4_cache=use_fp4_cache, + has_packed_q=has_packed_q, + ) + + +def _deepseek_v4_sparse_attn_indexer_fake( + cache_2d: torch.Tensor, + positions: torch.Tensor, + token_to_req_indices: torch.Tensor, + block_table: torch.Tensor, + seq_lens_cpu: torch.Tensor, + query_lens_cpu: torch.Tensor, + prefill_chunk_bounds: torch.Tensor, + prefill_chunk_plan: torch.Tensor, + prefill_slots: torch.Tensor, + prefill_cu_seq_lens: torch.Tensor, + prefill_cu_start: torch.Tensor, + prefill_cu_end: torch.Tensor, + prefill_row_lens: torch.Tensor, + packed_q_values: torch.Tensor, + packed_q_scales: torch.Tensor, + packed_weights: torch.Tensor, + fallback_index_q: torch.Tensor, + fallback_weights: torch.Tensor, + decode_schedule_metadata: torch.Tensor, + decode_context_lens: torch.Tensor, + decode_block_table: torch.Tensor, + decode_max_context_len: int, + topk_indices_buffer: torch.Tensor, + prefill_gather_values_workspace: torch.Tensor, + prefill_gather_scales_workspace: torch.Tensor, + cache_block_size: int, + compress_ratio: int, + topk_tokens: int, + num_prefill_tokens: int, + num_decode_tokens: int, + use_fp4_cache: bool, + has_packed_q: bool, +) -> torch.Tensor: + del ( + cache_2d, + positions, + token_to_req_indices, + block_table, + seq_lens_cpu, + query_lens_cpu, + prefill_chunk_bounds, + prefill_chunk_plan, + prefill_slots, + prefill_cu_seq_lens, + prefill_cu_start, + prefill_cu_end, + prefill_row_lens, + packed_q_values, + packed_q_scales, + packed_weights, + fallback_index_q, + fallback_weights, + decode_schedule_metadata, + decode_context_lens, + decode_block_table, + decode_max_context_len, + cache_block_size, + prefill_gather_values_workspace, + prefill_gather_scales_workspace, + compress_ratio, + topk_tokens, + num_prefill_tokens, + num_decode_tokens, + use_fp4_cache, + has_packed_q, + ) + return topk_indices_buffer + + +direct_register_custom_op( + op_name="deepseek_v4_sparse_attn_indexer", + op_func=_deepseek_v4_sparse_attn_indexer_op, + mutates_args=[ + "topk_indices_buffer", + "prefill_gather_values_workspace", + "prefill_gather_scales_workspace", + ], + fake_impl=_deepseek_v4_sparse_attn_indexer_fake, +) + + +def _deepseek_v4_sparse_attn_indexer( *, cache_2d: torch.Tensor, positions: torch.Tensor, token_to_req_indices: torch.Tensor, block_table: torch.Tensor, + seq_lens_cpu: torch.Tensor, + query_lens_cpu: torch.Tensor, + prefill_chunk_bounds: torch.Tensor, + prefill_chunk_plan: torch.Tensor, + prefill_slots: torch.Tensor, + prefill_cu_seq_lens: torch.Tensor, + prefill_cu_start: torch.Tensor, + prefill_cu_end: torch.Tensor, + prefill_row_lens: torch.Tensor, + packed_q_values: torch.Tensor, + packed_q_scales: torch.Tensor, + packed_weights: torch.Tensor, + fallback_index_q: torch.Tensor, + fallback_weights: torch.Tensor, + decode_schedule_metadata: Optional[torch.Tensor], + decode_context_lens: Optional[torch.Tensor], + decode_block_table: Optional[torch.Tensor], + decode_max_context_len: int, + topk_indices_buffer: torch.Tensor, + prefill_gather_values_workspace: torch.Tensor, + prefill_gather_scales_workspace: torch.Tensor, cache_block_size: int, - index_q: tuple[torch.Tensor, torch.Tensor], - weights: torch.Tensor, compress_ratio: int, topk_tokens: int, - metadata: Any | None = None, - out: torch.Tensor | None = None, -) -> torch.Tensor | None: - q_values, q_scales = index_q - if not _deepseek_v4_deepgemm_fp4_indexer_available(q_values): - return None - - num_tokens = positions.numel() - if num_tokens == 0: - if out is not None: - return out[:0] - return torch.empty((0, topk_tokens), device=positions.device, dtype=torch.int32) - compressed_lens = torch.div( - positions.to(torch.int64) + 1, - compress_ratio, - rounding_mode="floor", - ).clamp_min(0) - if positions.is_cuda and torch.cuda.is_current_stream_capturing(): - max_len = _deepseek_v4_indexer_decode_max_len( + num_prefill_tokens: int, + num_decode_tokens: int, + use_fp4_cache: bool, + has_packed_q: bool, +) -> torch.Tensor: + if decode_schedule_metadata is None: + decode_schedule_metadata = torch.empty( + 0, + dtype=torch.int32, + device=positions.device, + ) + if decode_context_lens is None: + decode_context_lens = torch.empty( + (0, 1), + dtype=torch.int32, + device=positions.device, + ) + if decode_block_table is None: + decode_block_table = torch.empty( + (0, 1), + dtype=block_table.dtype, + device=block_table.device, + ) + if positions.is_cuda: + return torch.ops.tokenspeed.deepseek_v4_sparse_attn_indexer( + cache_2d, + positions, + token_to_req_indices, block_table, + seq_lens_cpu, + query_lens_cpu, + prefill_chunk_bounds, + prefill_chunk_plan, + prefill_slots, + prefill_cu_seq_lens, + prefill_cu_start, + prefill_cu_end, + prefill_row_lens, + packed_q_values, + packed_q_scales, + packed_weights, + fallback_index_q, + fallback_weights, + decode_schedule_metadata, + decode_context_lens, + decode_block_table, + decode_max_context_len, + topk_indices_buffer, + prefill_gather_values_workspace, + prefill_gather_scales_workspace, cache_block_size, compress_ratio, + topk_tokens, + num_prefill_tokens, + num_decode_tokens, + use_fp4_cache, + has_packed_q, ) - else: - max_len = int(compressed_lens.max().item()) - if max_len <= 0: - topk = ( - torch.empty( - (num_tokens, topk_tokens), - device=positions.device, - dtype=torch.int32, - ) - if out is None - else out[:num_tokens] - ) - topk.fill_(-1) - return topk - - max_blocks = max(1, (max_len + cache_block_size - 1) // cache_block_size) - req_idx = token_to_req_indices[:num_tokens].to(torch.int64) - block_tables = block_table[req_idx, :max_blocks].contiguous() - context_lens = compressed_lens.to(torch.int32).view(num_tokens, 1).contiguous() - kv_cache = _deepseek_v4_indexer_mxfp4_cache_view(cache_2d, cache_block_size) - schedule_key = (compress_ratio, cache_block_size, num_tokens) - schedule_cache = getattr(metadata, "decode_indexer_schedule_metadata", None) - schedule_metadata = ( - schedule_cache.get(schedule_key) if schedule_cache is not None else None - ) - if schedule_metadata is None: - schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata( - context_lens, - cache_block_size, - deep_gemm.get_num_sms(), - ) - if schedule_cache is not None: - schedule_cache[schedule_key] = schedule_metadata - - try: - logits = deep_gemm.fp8_fp4_paged_mqa_logits( - q=( - q_values.contiguous().view(torch.int8).unsqueeze(1), - q_scales.contiguous().unsqueeze(1), - ), - kv_cache=kv_cache, - weights=weights.contiguous(), - context_lens=context_lens, - block_table=block_tables, - schedule_meta=schedule_metadata, - max_context_len=max_len, - clean_logits=False, - logits_dtype=torch.float32, - ) - except RuntimeError: - return None - - return _deepseek_v4_indexer_topk_from_logits( - logits, - compressed_lens.to(torch.int32), - topk_tokens, - out=out, + return _deepseek_v4_sparse_attn_indexer_native( + cache_2d=cache_2d, + positions=positions, + token_to_req_indices=token_to_req_indices, + block_table=block_table, + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + prefill_chunk_bounds=prefill_chunk_bounds, + prefill_chunk_plan=prefill_chunk_plan, + prefill_slots=prefill_slots, + prefill_cu_seq_lens=prefill_cu_seq_lens, + prefill_cu_start=prefill_cu_start, + prefill_cu_end=prefill_cu_end, + prefill_row_lens=prefill_row_lens, + packed_q_values=packed_q_values, + packed_q_scales=packed_q_scales, + packed_weights=packed_weights, + fallback_index_q=fallback_index_q, + fallback_weights=fallback_weights, + decode_schedule_metadata=decode_schedule_metadata, + decode_context_lens=decode_context_lens, + decode_block_table=decode_block_table, + decode_max_context_len=decode_max_context_len, + topk_indices_buffer=topk_indices_buffer, + prefill_gather_values_workspace=prefill_gather_values_workspace, + prefill_gather_scales_workspace=prefill_gather_scales_workspace, + cache_block_size=cache_block_size, + compress_ratio=compress_ratio, + topk_tokens=topk_tokens, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + use_fp4_cache=use_fp4_cache, + has_packed_q=has_packed_q, ) @@ -1225,6 +2762,10 @@ def _deepseek_v4_indexer_topk_from_cache_deepgemm_decode( DEEPSEEK_V4_FP8_BLOCK_SIZE = 128 _DEEPSEEK_V4_INDEXER_PREFILL_MAX_LOGITS_MB = 512 _DEEPSEEK_V4_FUSED_ROUTER_AVAILABLE = True +_DEEPSEEK_V4_PREFILL_TOPK_OP_CHECKED = False +_DEEPSEEK_V4_PREFILL_TOPK_OP_AVAILABLE = False +_DEEPSEEK_V4_PAGED_GATHER_CHECKED = False +_DEEPSEEK_V4_PAGED_GATHER_AVAILABLE = False def _deepseek_v4_maybe_execute_in_parallel( @@ -1445,166 +2986,6 @@ def get(self, num_tokens: int, device: torch.device) -> torch.Tensor: return self.buffer[:num_tokens] -@triton.jit -def _deepseek_v4_stage_mega_moe_inputs_kernel( - hidden_states, - x_fp8, - x_sf, - topk_ids, - topk_weights, - topk_idx_out, - topk_weights_out, - hidden_stride_m: tl.constexpr, - hidden_stride_k: tl.constexpr, - x_stride_m: tl.constexpr, - x_stride_k: tl.constexpr, - x_sf_stride_m: tl.constexpr, - x_sf_stride_k: tl.constexpr, - topk_ids_stride_m: tl.constexpr, - topk_ids_stride_k: tl.constexpr, - topk_weights_stride_m: tl.constexpr, - topk_weights_stride_k: tl.constexpr, - topk_idx_stride_m: tl.constexpr, - topk_idx_stride_k: tl.constexpr, - topk_weights_out_stride_m: tl.constexpr, - topk_weights_out_stride_k: tl.constexpr, - hidden_size: tl.constexpr, - top_k: tl.constexpr, - BLOCK_K: tl.constexpr, - GROUP_K: tl.constexpr, - BLOCK_TOPK: tl.constexpr, -) -> None: - token_id = tl.program_id(0) - k_block_id = tl.program_id(1) - - k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K) - k_mask = k_offsets < hidden_size - hidden = tl.load( - hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k, - mask=k_mask, - other=0.0, - ).to(tl.float32) - - num_groups: tl.constexpr = BLOCK_K // GROUP_K - hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K]) - amax = tl.max(hidden_groups, axis=1) - amax = tl.maximum(amax, 1.0e-4) - - scale = amax / 448.0 - scale_bits = scale.to(tl.uint32, bitcast=True) - scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to( - tl.uint32 - ) - scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254) - rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True) - - hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K]) - scaled = hidden_groups * (1.0 / rounded_scale)[:, None] - scaled = tl.reshape(scaled, [BLOCK_K]) - fp8 = scaled.to(tl.float8e4nv) - tl.store( - x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k, - fp8, - mask=k_mask, - ) - - scale_offsets = tl.arange(0, num_groups) - packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32) - tl.store( - x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k, - packed_scale, - ) - - if k_block_id == 0: - topk_offsets = tl.arange(0, BLOCK_TOPK) - topk_mask = topk_offsets < top_k - - ids = tl.load( - topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k, - mask=topk_mask, - other=0, - ).to(tl.int64) - tl.store( - topk_idx_out - + token_id * topk_idx_stride_m - + topk_offsets * topk_idx_stride_k, - ids, - mask=topk_mask, - ) - - weights = tl.load( - topk_weights - + token_id * topk_weights_stride_m - + topk_offsets * topk_weights_stride_k, - mask=topk_mask, - other=0.0, - ) - tl.store( - topk_weights_out - + token_id * topk_weights_out_stride_m - + topk_offsets * topk_weights_out_stride_k, - weights, - mask=topk_mask, - ) - - -def _stage_deepseek_v4_mega_moe_inputs( - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - x_fp8: torch.Tensor, - x_sf: torch.Tensor, - topk_idx_out: torch.Tensor, - topk_weights_out: torch.Tensor, -) -> None: - num_tokens, hidden_size = hidden_states.shape - if num_tokens == 0: - return - if hidden_size % DEEPSEEK_V4_FP8_BLOCK_SIZE != 0: - raise ValueError( - "DeepSeek V4 MegaMoE input staging requires hidden_size to be " - f"a multiple of {DEEPSEEK_V4_FP8_BLOCK_SIZE}." - ) - if topk_weights.shape != topk_ids.shape: - raise ValueError( - "DeepSeek V4 MegaMoE input staging requires topk_weights and " - "topk_ids to have the same shape." - ) - - block_k = DEEPSEEK_V4_FP8_BLOCK_SIZE - grid = (num_tokens, triton.cdiv(hidden_size, block_k)) - block_topk = triton.next_power_of_2(topk_ids.shape[1]) - _deepseek_v4_stage_mega_moe_inputs_kernel[grid]( - hidden_states, - x_fp8, - x_sf, - topk_ids, - topk_weights, - topk_idx_out, - topk_weights_out, - hidden_states.stride(0), - hidden_states.stride(1), - x_fp8.stride(0), - x_fp8.stride(1), - x_sf.stride(0), - x_sf.stride(1), - topk_ids.stride(0), - topk_ids.stride(1), - topk_weights.stride(0), - topk_weights.stride(1), - topk_idx_out.stride(0), - topk_idx_out.stride(1), - topk_weights_out.stride(0), - topk_weights_out.stride(1), - hidden_size, - topk_ids.shape[1], - BLOCK_K=block_k, - GROUP_K=32, - BLOCK_TOPK=block_topk, - num_warps=4, - ) - - DEEPSEEK_V4_MXFP4_BLOCK_SIZE = 32 @@ -2530,6 +3911,277 @@ def __init__( self.topk_tokens = int(config.index_topk) self.topk_buffer = topk_buffer self.softmax_scale = self.head_dim**-0.5 + value_bytes = DEEPSEEK_V4_INDEXER_DIM // 2 + scale_bytes = DEEPSEEK_V4_INDEXER_DIM // DEEPSEEK_V4_MXFP4_BLOCK_SIZE + self.register_buffer( + "_prefill_gather_values_workspace", + torch.empty((0, value_bytes), dtype=torch.uint8), + persistent=False, + ) + self.register_buffer( + "_prefill_gather_scales_workspace", + torch.empty((0, scale_bytes), dtype=torch.uint8), + persistent=False, + ) + + def _prefill_gather_workspace( + self, + rows: int, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + rows = max(0, int(rows)) + value_bytes = DEEPSEEK_V4_INDEXER_DIM // 2 + scale_bytes = DEEPSEEK_V4_INDEXER_DIM // DEEPSEEK_V4_MXFP4_BLOCK_SIZE + if ( + self._prefill_gather_values_workspace.device != device + or self._prefill_gather_values_workspace.shape[0] < rows + ): + self._prefill_gather_values_workspace = torch.empty( + (rows, value_bytes), + dtype=torch.uint8, + device=device, + ) + if ( + self._prefill_gather_scales_workspace.device != device + or self._prefill_gather_scales_workspace.shape[0] < rows + ): + self._prefill_gather_scales_workspace = torch.empty( + (rows, scale_bytes), + dtype=torch.uint8, + device=device, + ) + return ( + self._prefill_gather_values_workspace[:rows], + self._prefill_gather_scales_workspace[:rows], + ) + + def prepare_decode_metadata( + self, + *, + positions: torch.Tensor, + metadata: Any, + indexer_block_size: int, + ) -> None: + if not self.use_fp4_cache or not positions.is_cuda: + return + forward_mode = metadata.forward_mode + if forward_mode is not None and forward_mode.is_mixed(): + num_prefill_tokens = int(metadata.num_prefill_tokens) + num_decode_tokens = metadata.decode_token_count() + elif forward_mode is not None and forward_mode.is_decode(): + num_prefill_tokens = 0 + num_decode_tokens = positions.numel() + else: + return + if num_decode_tokens <= 0: + return + + decode_start = num_prefill_tokens + decode_end = decode_start + num_decode_tokens + decode_positions = positions[decode_start:decode_end] + decode_valid_token = ( + metadata.is_valid_token[decode_start:decode_end] + if getattr(metadata, "is_valid_token", None) is not None + else None + ) + indexer_block_table = metadata.compressed_block_table( + self.compress_ratio, + indexer_block_size, + ) + decode_plan = _deepseek_v4_indexer_decode_metadata( + positions=decode_positions, + token_to_req_indices=metadata.token_to_req_indices[decode_start:decode_end], + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + compress_ratio=self.compress_ratio, + metadata=metadata, + is_valid_token=decode_valid_token, + ) + _deepseek_v4_indexer_decode_schedule_metadata( + positions=decode_positions, + cache_block_size=indexer_block_size, + compress_ratio=self.compress_ratio, + metadata=metadata, + context_lens=decode_plan.context_lens, + ) + + def _forward_sparse_indexer_custom_op( + self, + *, + hidden_states: torch.Tensor, + qr: torch.Tensor, + positions: torch.Tensor, + metadata: Any, + indexer_cache: torch.Tensor, + indexer_block_size: int, + cos_sin_cache: torch.Tensor, + ) -> Optional[torch.Tensor]: + if not self.use_fp4_cache or not positions.is_cuda: + return None + + forward_mode = metadata.forward_mode + total_tokens = positions.numel() + if total_tokens == 0: + return torch.empty( + (0, self.topk_tokens), + device=positions.device, + dtype=torch.int32, + ) + if forward_mode is not None and forward_mode.is_mixed(): + num_prefill_tokens = int(metadata.num_prefill_tokens) + num_decode_tokens = metadata.decode_token_count() + elif forward_mode is not None and forward_mode.is_decode(): + num_prefill_tokens = 0 + num_decode_tokens = total_tokens + else: + num_prefill_tokens = total_tokens + num_decode_tokens = 0 + + with deepseek_v4_profile_scope("indexer_wq_b"): + index_q, _ = self.wq_b(qr) + index_q = index_q.view(-1, self.n_head, self.head_dim) + with deepseek_v4_profile_scope("indexer_weights_proj"): + weights, _ = self.weights_proj(hidden_states) + with deepseek_v4_profile_scope("indexer_prepare_mxfp4"): + packed_index_q, packed_weights = deepseek_v4_prepare_indexer_q_mxfp4( + index_q=index_q, + positions=positions, + cos_sin_cache=cos_sin_cache, + weights=weights, + softmax_scale=self.softmax_scale, + head_scale=self.n_head**-0.5, + ) + + packed_indexer_available = _deepseek_v4_deepgemm_fp4_indexer_available( + packed_index_q[0] + ) + fallback_index_q = index_q.new_empty((0, self.n_head, self.head_dim)) + fallback_weights = weights.new_empty((0, self.n_head)) + if not packed_indexer_available: + with deepseek_v4_profile_scope("indexer_prepare_reference_fallback"): + fallback_index_q, fallback_weights = ( + deepseek_v4_prepare_indexer_q_reference( + index_q=index_q, + positions=positions, + cos_sin_cache=cos_sin_cache, + weights=weights, + softmax_scale=self.softmax_scale, + head_scale=self.n_head**-0.5, + use_fp4=self.use_fp4_cache, + ) + ) + + empty_cpu = torch.empty(0, dtype=torch.int32, device="cpu") + seq_lens_cpu = ( + metadata.seq_lens_cpu[: metadata.num_prefill_reqs] + if metadata.seq_lens_cpu is not None and num_prefill_tokens > 0 + else empty_cpu + ) + query_lens_cpu = ( + metadata.query_lens_cpu[: metadata.num_prefill_reqs] + if metadata.query_lens_cpu is not None and num_prefill_tokens > 0 + else empty_cpu + ) + indexer_block_table = metadata.compressed_block_table( + self.compress_ratio, + indexer_block_size, + ) + prefill_metadata = _deepseek_v4_indexer_prefill_metadata( + metadata=metadata, + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + compress_ratio=self.compress_ratio, + num_prefill_tokens=num_prefill_tokens, + ) + max_prefill_gather_rows = 0 + if prefill_metadata.chunk_plan.numel() > 0: + slot_counts = ( + prefill_metadata.chunk_plan[:, 1] - prefill_metadata.chunk_plan[:, 0] + ) + max_prefill_gather_rows = int(slot_counts.max().item()) + prefill_gather_values, prefill_gather_scales = self._prefill_gather_workspace( + max_prefill_gather_rows, + positions.device, + ) + + decode_schedule_metadata = None + decode_context_lens = None + decode_block_table = None + decode_max_context_len = 0 + if num_decode_tokens > 0: + decode_start = num_prefill_tokens + decode_end = decode_start + num_decode_tokens + decode_positions = positions[decode_start:decode_end] + decode_valid_token = ( + metadata.is_valid_token[decode_start:decode_end] + if getattr(metadata, "is_valid_token", None) is not None + else None + ) + decode_plan = _deepseek_v4_indexer_decode_metadata( + positions=decode_positions, + token_to_req_indices=metadata.token_to_req_indices[ + decode_start:decode_end + ], + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + compress_ratio=self.compress_ratio, + metadata=metadata, + is_valid_token=decode_valid_token, + ) + decode_context_lens = decode_plan.context_lens + decode_block_table = decode_plan.block_table + decode_max_context_len = decode_plan.max_context_len + decode_schedule_metadata = _deepseek_v4_indexer_decode_schedule_metadata( + positions=decode_positions, + cache_block_size=indexer_block_size, + compress_ratio=self.compress_ratio, + metadata=metadata, + context_lens=decode_context_lens, + ) + + topk_out = ( + self.topk_buffer.get(total_tokens, positions.device) + if self.topk_buffer is not None + else torch.empty( + (total_tokens, self.topk_tokens), + device=positions.device, + dtype=torch.int32, + ) + )[:total_tokens] + return _deepseek_v4_sparse_attn_indexer( + cache_2d=indexer_cache, + positions=positions, + token_to_req_indices=metadata.token_to_req_indices[:total_tokens], + block_table=indexer_block_table, + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + prefill_chunk_bounds=prefill_metadata.chunk_bounds, + prefill_chunk_plan=prefill_metadata.chunk_plan, + prefill_slots=prefill_metadata.slots, + prefill_cu_seq_lens=prefill_metadata.cu_seq_lens, + prefill_cu_start=prefill_metadata.cu_start, + prefill_cu_end=prefill_metadata.cu_end, + prefill_row_lens=prefill_metadata.row_lens, + packed_q_values=packed_index_q[0], + packed_q_scales=packed_index_q[1], + packed_weights=packed_weights, + fallback_index_q=fallback_index_q, + fallback_weights=fallback_weights, + decode_schedule_metadata=decode_schedule_metadata, + decode_context_lens=decode_context_lens, + decode_block_table=decode_block_table, + decode_max_context_len=decode_max_context_len, + topk_indices_buffer=topk_out, + prefill_gather_values_workspace=prefill_gather_values, + prefill_gather_scales_workspace=prefill_gather_scales, + cache_block_size=indexer_block_size, + compress_ratio=self.compress_ratio, + topk_tokens=self.topk_tokens, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + use_fp4_cache=self.use_fp4_cache, + has_packed_q=packed_indexer_available, + ) def forward( self, @@ -2547,6 +4199,9 @@ def forward( raise RuntimeError("DeepSeek V4 indexer requires forward metadata") indexer_state = pool.get_indexer_state_buffer(layer_index) indexer_state_block_table = metadata.indexer_state_block_table + indexer_state_base_logical_page = getattr( + metadata, "indexer_state_base_logical_page", None + ) if indexer_state_block_table is not None: indexer_state_block_size = pool.get_indexer_state_block_size(layer_index) indexer_state_slot_mapping = _group_slot_mapping_from_raw( @@ -2554,12 +4209,13 @@ def forward( metadata.token_to_req_indices[: positions.numel()], indexer_state_block_table, indexer_state_block_size, - base_offsets=metadata.indexer_state_base_logical_page, + base_offsets=indexer_state_base_logical_page, ) else: indexer_state_block_table = metadata.block_table indexer_state_block_size = pool.state_block_size indexer_state_slot_mapping = out_cache_loc + indexer_state_base_logical_page = None with deepseek_v4_profile_scope("indexer_compressor_total"): self.compressor( hidden_states=hidden_states, @@ -2571,7 +4227,7 @@ def forward( state_cache=indexer_state, state_block_table=indexer_state_block_table, state_block_size=indexer_state_block_size, - state_base_logical_page=metadata.indexer_state_base_logical_page, + state_base_logical_page=indexer_state_base_logical_page, write_compressed_cache=False, ) with deepseek_v4_profile_scope("indexer_compressed_slot_mapping"): @@ -2592,7 +4248,7 @@ def forward( positions=positions, compressor_slot_mapping=indexer_state_slot_mapping, block_table=indexer_state_block_table, - block_table_base_offsets=metadata.indexer_state_base_logical_page, + block_table_base_offsets=indexer_state_base_logical_page, compressor_block_size=indexer_state_block_size, rms_norm_weight=self.compressor.norm.weight, rms_norm_eps=self.compressor.norm.variance_epsilon, @@ -2603,6 +4259,321 @@ def forward( use_fp4_cache=self.use_fp4_cache, compress_ratio=self.compress_ratio, ) + custom_topk = self._forward_sparse_indexer_custom_op( + hidden_states=hidden_states, + qr=qr, + positions=positions, + metadata=metadata, + indexer_cache=pool.get_indexer_kv_buffer_2d(layer_index), + indexer_block_size=indexer_block_size, + cos_sin_cache=cos_sin_cache, + ) + if custom_topk is not None: + return custom_topk + + if ctx.forward_mode is not None and ctx.forward_mode.is_mixed(): + num_prefill_tokens = metadata.num_prefill_tokens + num_decode_tokens = metadata.decode_token_count() + total_tokens = positions.numel() + topk_out = ( + self.topk_buffer.get(total_tokens, positions.device) + if self.topk_buffer is not None + else torch.empty( + (total_tokens, self.topk_tokens), + device=positions.device, + dtype=torch.int32, + ) + )[:total_tokens] + topk_out.fill_(-1) + + def fill_prefill_topk() -> None: + if num_prefill_tokens <= 0: + return + prefill_positions = positions[:num_prefill_tokens] + + with deepseek_v4_profile_scope("indexer_wq_b_prefill"): + index_q, _ = self.wq_b(qr[:num_prefill_tokens]) + index_q = index_q.view(-1, self.n_head, self.head_dim) + with deepseek_v4_profile_scope("indexer_weights_proj_prefill"): + weights, _ = self.weights_proj(hidden_states[:num_prefill_tokens]) + + packed_index_q = None + packed_weights = None + if self.use_fp4_cache: + with deepseek_v4_profile_scope("indexer_prepare_mxfp4_prefill"): + packed_index_q, packed_weights = ( + deepseek_v4_prepare_indexer_q_mxfp4( + index_q=index_q, + positions=prefill_positions, + cos_sin_cache=cos_sin_cache, + weights=weights, + softmax_scale=self.softmax_scale, + head_scale=self.n_head**-0.5, + ) + ) + + with deepseek_v4_profile_scope("indexer_prepare_reference_prefill"): + index_q_fallback, weights_fallback = ( + deepseek_v4_prepare_indexer_q_reference( + index_q=index_q, + positions=prefill_positions, + cos_sin_cache=cos_sin_cache, + weights=weights, + softmax_scale=self.softmax_scale, + head_scale=self.n_head**-0.5, + use_fp4=self.use_fp4_cache, + ) + ) + cache_reader = ( + read_deepseek_v4_indexer_mxfp4_cache + if self.use_fp4_cache + else read_deepseek_v4_indexer_fp8_cache + ) + indexer_cache = pool.get_indexer_kv_buffer_2d(layer_index) + seq_lens_cpu = ( + metadata.seq_lens_cpu[: metadata.num_prefill_reqs] + if metadata.seq_lens_cpu is not None + else None + ) + query_lens_cpu = ( + metadata.query_lens_cpu[: metadata.num_prefill_reqs] + if metadata.query_lens_cpu is not None + else None + ) + request_chunks = ( + _deepseek_v4_indexer_prefill_request_chunks( + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + compress_ratio=self.compress_ratio, + num_tokens=num_prefill_tokens, + ) + if seq_lens_cpu is not None and query_lens_cpu is not None + else [] + ) + if request_chunks: + gather_cache_key = None + gathered_k = None + for chunk in request_chunks: + topk = None + if packed_index_q is not None and packed_weights is not None: + with deepseek_v4_profile_scope( + "indexer_topk_deepgemm_prefill" + ): + gather_plan = ( + _deepseek_v4_indexer_prefill_request_gather_plan( + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + compress_ratio=self.compress_ratio, + req_start=chunk.req_start, + req_end=chunk.req_end, + query_start=chunk.query_start, + query_end=chunk.query_end, + ) + ) + key = (chunk.req_start, chunk.req_end) + reuse_k = ( + gathered_k + if chunk.skip_kv_gather and gather_cache_key == key + else None + ) + topk, next_gathered_k = ( + _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill_plan( + cache_2d=indexer_cache, + gather_plan=gather_plan, + cache_block_size=indexer_block_size, + index_q=( + packed_index_q[0][ + chunk.token_start : chunk.token_end + ], + packed_index_q[1][ + chunk.token_start : chunk.token_end + ], + ), + weights=packed_weights[ + chunk.token_start : chunk.token_end + ], + topk_tokens=self.topk_tokens, + preserve_topk_order=True, + gathered_k=reuse_k, + ) + ) + if topk is not None and next_gathered_k is not None: + gather_cache_key = key + gathered_k = next_gathered_k + if topk is None: + with deepseek_v4_profile_scope( + "indexer_topk_fallback_prefill" + ): + topk = _deepseek_v4_indexer_topk_from_cache_batched( + cache_reader=cache_reader, + cache_2d=indexer_cache, + positions=prefill_positions[ + chunk.token_start : chunk.token_end + ], + token_to_req_indices=metadata.token_to_req_indices[ + chunk.token_start : chunk.token_end + ], + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + index_q=index_q_fallback[ + chunk.token_start : chunk.token_end + ], + weights=weights_fallback[ + chunk.token_start : chunk.token_end + ], + compress_ratio=self.compress_ratio, + topk_tokens=self.topk_tokens, + preserve_topk_order=True, + ) + topk_out[chunk.token_start : chunk.token_end].copy_(topk) + return + + topk_chunks = [] + for start, end in _deepseek_v4_indexer_prefill_topk_chunks( + prefill_positions, + self.compress_ratio, + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + ): + if packed_index_q is not None and packed_weights is not None: + with deepseek_v4_profile_scope("indexer_topk_deepgemm_prefill"): + topk = ( + _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill( + cache_2d=indexer_cache, + positions=prefill_positions[start:end], + token_to_req_indices=metadata.token_to_req_indices[ + start:end + ], + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + index_q=( + packed_index_q[0][start:end], + packed_index_q[1][start:end], + ), + weights=packed_weights[start:end], + compress_ratio=self.compress_ratio, + topk_tokens=self.topk_tokens, + preserve_topk_order=True, + ) + ) + if topk is not None: + topk_chunks.append(topk) + continue + with deepseek_v4_profile_scope("indexer_topk_fallback_prefill"): + topk_chunks.append( + _deepseek_v4_indexer_topk_from_cache_batched( + cache_reader=cache_reader, + cache_2d=indexer_cache, + positions=prefill_positions[start:end], + token_to_req_indices=metadata.token_to_req_indices[ + start:end + ], + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + index_q=index_q_fallback[start:end], + weights=weights_fallback[start:end], + compress_ratio=self.compress_ratio, + topk_tokens=self.topk_tokens, + preserve_topk_order=True, + ) + ) + if topk_chunks: + with deepseek_v4_profile_scope("indexer_topk_cat_prefill"): + topk_out[:num_prefill_tokens].copy_( + torch.cat(topk_chunks, dim=0) + ) + + def fill_decode_topk() -> None: + if num_decode_tokens <= 0: + return + decode_start = num_prefill_tokens + decode_end = decode_start + num_decode_tokens + decode_positions = positions[decode_start:decode_end] + decode_token_to_req = metadata.token_to_req_indices[ + decode_start:decode_end + ] + decode_valid_token = ( + metadata.is_valid_token[decode_start:decode_end] + if getattr(metadata, "is_valid_token", None) is not None + else None + ) + decode_out = topk_out[decode_start:decode_end] + with deepseek_v4_profile_scope("indexer_wq_b_decode"): + index_q, _ = self.wq_b(qr[decode_start:decode_end]) + index_q = index_q.view(-1, self.n_head, self.head_dim) + with deepseek_v4_profile_scope("indexer_weights_proj_decode"): + weights, _ = self.weights_proj( + hidden_states[decode_start:decode_end] + ) + + packed_index_q = None + packed_weights = None + if self.use_fp4_cache: + with deepseek_v4_profile_scope("indexer_prepare_mxfp4_decode"): + packed_index_q, packed_weights = ( + deepseek_v4_prepare_indexer_q_mxfp4( + index_q=index_q, + positions=decode_positions, + cos_sin_cache=cos_sin_cache, + weights=weights, + softmax_scale=self.softmax_scale, + head_scale=self.n_head**-0.5, + ) + ) + with deepseek_v4_profile_scope("indexer_topk_deepgemm_decode"): + topk = _deepseek_v4_indexer_topk_from_cache_deepgemm_decode( + cache_2d=pool.get_indexer_kv_buffer_2d(layer_index), + positions=decode_positions, + token_to_req_indices=decode_token_to_req, + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + index_q=packed_index_q, + weights=packed_weights, + compress_ratio=self.compress_ratio, + topk_tokens=self.topk_tokens, + metadata=metadata, + is_valid_token=decode_valid_token, + out=decode_out, + ) + if topk is not None: + return + + with deepseek_v4_profile_scope("indexer_prepare_reference_decode"): + index_q_fallback, weights_fallback = ( + deepseek_v4_prepare_indexer_q_reference( + index_q=index_q, + positions=decode_positions, + cos_sin_cache=cos_sin_cache, + weights=weights, + softmax_scale=self.softmax_scale, + head_scale=self.n_head**-0.5, + use_fp4=self.use_fp4_cache, + ) + ) + cache_reader = ( + read_deepseek_v4_indexer_mxfp4_cache + if self.use_fp4_cache + else read_deepseek_v4_indexer_fp8_cache + ) + _deepseek_v4_indexer_topk_from_cache_batched( + cache_reader=cache_reader, + cache_2d=pool.get_indexer_kv_buffer_2d(layer_index), + positions=decode_positions, + token_to_req_indices=decode_token_to_req, + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + index_q=index_q_fallback, + weights=weights_fallback, + compress_ratio=self.compress_ratio, + topk_tokens=self.topk_tokens, + out=decode_out, + ) + + fill_prefill_topk() + fill_decode_topk() + return topk_out with deepseek_v4_profile_scope("indexer_wq_b"): index_q, _ = self.wq_b(qr) index_q = index_q.view(-1, self.n_head, self.head_dim) @@ -2638,6 +4609,11 @@ def forward( compress_ratio=self.compress_ratio, topk_tokens=self.topk_tokens, metadata=metadata, + is_valid_token=( + metadata.is_valid_token[: positions.numel()] + if getattr(metadata, "is_valid_token", None) is not None + else None + ), out=topk_out, ) if topk is not None: @@ -2681,10 +4657,104 @@ def forward( ) indexer_cache = pool.get_indexer_kv_buffer_2d(layer_index) + request_chunks = ( + _deepseek_v4_indexer_prefill_request_chunks( + seq_lens_cpu=metadata.seq_lens_cpu, + query_lens_cpu=metadata.query_lens_cpu, + compress_ratio=self.compress_ratio, + num_tokens=positions.numel(), + ) + if metadata.seq_lens_cpu is not None and metadata.query_lens_cpu is not None + else [] + ) + if request_chunks: + topk_out = ( + self.topk_buffer.get(positions.numel(), positions.device) + if self.topk_buffer is not None + else torch.empty( + (positions.numel(), self.topk_tokens), + device=positions.device, + dtype=torch.int32, + ) + )[: positions.numel()] + topk_out.fill_(-1) + gather_cache_key = None + gathered_k = None + for chunk in request_chunks: + topk = None + if packed_index_q is not None and packed_weights is not None: + with deepseek_v4_profile_scope("indexer_topk_deepgemm_prefill"): + gather_plan = _deepseek_v4_indexer_prefill_request_gather_plan( + seq_lens_cpu=metadata.seq_lens_cpu, + query_lens_cpu=metadata.query_lens_cpu, + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + compress_ratio=self.compress_ratio, + req_start=chunk.req_start, + req_end=chunk.req_end, + query_start=chunk.query_start, + query_end=chunk.query_end, + ) + key = (chunk.req_start, chunk.req_end) + reuse_k = ( + gathered_k + if chunk.skip_kv_gather and gather_cache_key == key + else None + ) + topk, next_gathered_k = ( + _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill_plan( + cache_2d=indexer_cache, + gather_plan=gather_plan, + cache_block_size=indexer_block_size, + index_q=( + packed_index_q[0][ + chunk.token_start : chunk.token_end + ], + packed_index_q[1][ + chunk.token_start : chunk.token_end + ], + ), + weights=packed_weights[ + chunk.token_start : chunk.token_end + ], + topk_tokens=self.topk_tokens, + preserve_topk_order=True, + gathered_k=reuse_k, + ) + ) + if topk is not None and next_gathered_k is not None: + gather_cache_key = key + gathered_k = next_gathered_k + if topk is None: + with deepseek_v4_profile_scope("indexer_topk_fallback_prefill"): + topk = _deepseek_v4_indexer_topk_from_cache_batched( + cache_reader=cache_reader, + cache_2d=indexer_cache, + positions=positions[chunk.token_start : chunk.token_end], + token_to_req_indices=metadata.token_to_req_indices[ + chunk.token_start : chunk.token_end + ], + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + index_q=index_q_fallback[ + chunk.token_start : chunk.token_end + ], + weights=weights_fallback[ + chunk.token_start : chunk.token_end + ], + compress_ratio=self.compress_ratio, + topk_tokens=self.topk_tokens, + preserve_topk_order=True, + ) + topk_out[chunk.token_start : chunk.token_end].copy_(topk) + return topk_out + topk_chunks = [] for start, end in _deepseek_v4_indexer_prefill_topk_chunks( positions, self.compress_ratio, + seq_lens_cpu=metadata.seq_lens_cpu, + query_lens_cpu=metadata.query_lens_cpu, ): if packed_index_q is not None and packed_weights is not None: with deepseek_v4_profile_scope("indexer_topk_deepgemm_prefill"): @@ -3342,6 +5412,14 @@ def run_compressor() -> None: topk_indices = None if self.indexer is not None: assert self.compressor is not None + with deepseek_v4_profile_scope( + f"{profile_prefix}_indexer_prepare_decode_metadata" + ): + self.indexer.prepare_decode_metadata( + positions=positions, + metadata=metadata, + indexer_block_size=pool.get_indexer_block_size(self.layer_index), + ) def run_indexer() -> torch.Tensor: with deepseek_v4_profile_scope(f"{profile_prefix}_indexer"): @@ -3381,12 +5459,38 @@ def insert_and_compress() -> None: "forward_deepseek_v4_decode", None, ) + backend_mixed = getattr( + ctx.attn_backend, + "forward_deepseek_v4_mixed", + None, + ) backend_prefill = getattr( ctx.attn_backend, "forward_deepseek_v4_prefill", None, ) if ( + backend_mixed is not None + and ctx.forward_mode is not None + and ctx.forward_mode.is_mixed() + ): + with deepseek_v4_profile_scope(f"{profile_prefix}_mixed_backend"): + attn_output = backend_mixed( + q=q, + positions=positions, + token_to_kv_pool=pool, + layer_id=self.layer_index, + kind=self.attention_kind, + compress_ratio=self.compress_ratio, + num_local_heads=self.num_local_heads, + padded_heads=self.padded_heads, + head_dim=self.head_dim, + window_size=self.layout.swa_window, + softmax_scale=self.scale, + attn_sink=self.attn_sink, + topk_indices=topk_indices, + ) + elif ( backend_decode is not None and ctx.forward_mode is not None and ctx.forward_mode.is_decode() @@ -3534,6 +5638,9 @@ def _pre_mlp_input_ids_comm( [tokens[:count] for tokens, count in zip(gathered, token_counts)], dim=0 ) + def _mega_moe_token_counts(self, ctx: ForwardContext) -> list[int]: + return self.comm_manager.moe_tp_ep_group_scattered_num_tokens(ctx) + def forward( self, positions: torch.Tensor, @@ -3564,11 +5671,7 @@ def forward( ffn_input_ids = input_ids use_mega_moe = getattr(self.ffn, "use_mega_moe", False) if use_mega_moe: - token_counts = ( - [int(count) for count in ctx.global_num_tokens] - if ctx.global_num_tokens is not None - else [int(hidden_states.shape[0])] - ) + token_counts = self._mega_moe_token_counts(ctx) num_global_tokens = sum(token_counts) max_num_tokens_per_gpu = max(token_counts) if token_counts else 0 else: diff --git a/python/tokenspeed/runtime/utils/custom_ops.py b/python/tokenspeed/runtime/utils/custom_ops.py new file mode 100644 index 000000000..1a9fb2d65 --- /dev/null +++ b/python/tokenspeed/runtime/utils/custom_ops.py @@ -0,0 +1,61 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import annotations + +import importlib +from collections.abc import Callable + +import torch +import torch.library +from torch.library import Library + +tokenspeed_lib = Library("tokenspeed", "FRAGMENT") + + +def direct_register_custom_op( + op_name: str, + op_func: Callable, + mutates_args: list[str], + fake_impl: Callable | None = None, + target_lib: Library | None = None, +) -> None: + """Register a low-overhead torch custom op in the TokenSpeed namespace.""" + + target = target_lib or tokenspeed_lib + lib_name = getattr(getattr(target, "m", None), "name", "tokenspeed") + try: + if hasattr(torch.ops, lib_name) and hasattr( + getattr(torch.ops, lib_name), op_name + ): + return + except (AttributeError, RuntimeError): + pass + + if hasattr(torch.library, "infer_schema"): + schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) + else: + custom_op_impl = importlib.import_module("torch._custom_op.impl") + schema_str = custom_op_impl.infer_schema(op_func, mutates_args) + + target.define(op_name + schema_str) + target.impl(op_name, op_func, "CUDA") + if fake_impl is not None: + target._register_fake(op_name, fake_impl) diff --git a/python/tokenspeed/runtime/utils/server_args.py b/python/tokenspeed/runtime/utils/server_args.py index 3bf737a16..310b864c0 100755 --- a/python/tokenspeed/runtime/utils/server_args.py +++ b/python/tokenspeed/runtime/utils/server_args.py @@ -88,6 +88,7 @@ class ServerArgs: max_total_tokens: int | None = None chunked_prefill_size: int | None = None max_prefill_tokens: int = 8192 + enable_mixed_batch: bool = False block_size: int = 64 # special kv cache mamba_ssm_dtype: str = "float32" @@ -817,6 +818,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.chunked_prefill_size, help="Maximum number of tokens the scheduler may issue in a single iteration. Setting this to -1 disables chunked prefill.", ) + parser.add_argument( + "--enable-mixed-batch", + action="store_true", + dest="enable_mixed_batch", + default=ServerArgs.enable_mixed_batch, + help="Allow the scheduler to issue prefill and decode requests in the same iteration.", + ) parser.add_argument( "--block-size", metavar="BLOCK_SIZE", diff --git a/test/runtime/kernels/test_trtllm_wrapper.py b/test/runtime/kernels/test_trtllm_wrapper.py new file mode 100644 index 000000000..c458a9047 --- /dev/null +++ b/test/runtime/kernels/test_trtllm_wrapper.py @@ -0,0 +1,103 @@ +import unittest +from unittest.mock import patch + +import torch + + +class TRTLLMWrapperTest(unittest.TestCase): + def test_fast_topk_v2_decode_accepts_2d_lens(self): + from tokenspeed_kernel.registry import error_fn + from tokenspeed_kernel.thirdparty import trtllm + + if trtllm.fast_topk_v2 is None or trtllm.fast_topk_v2 is error_fn: + self.skipTest("TRTLLM fast_topk_v2 is unavailable on this platform") + + captured = {} + + def fake_indexer_topk_decode(values, seq_lens, indices, next_n, topk): + del values, indices + captured["seq_lens"] = seq_lens + captured["next_n"] = next_n + captured["topk"] = topk + + with patch.object( + torch.ops.trtllm, + "indexer_topk_decode", + fake_indexer_topk_decode, + create=True, + ): + values = torch.empty((2, 4), dtype=torch.float32) + seq_lens = torch.tensor([[3], [4]], dtype=torch.int64) + indices = torch.empty((2, 2), dtype=torch.int32) + + trtllm.fast_topk_v2( + values, + seq_lens, + indices, + topk=2, + next_n=1, + ) + + self.assertEqual(captured["next_n"], 1) + self.assertEqual(captured["topk"], 2) + self.assertEqual(captured["seq_lens"].dtype, torch.int32) + self.assertEqual(captured["seq_lens"].dim(), 1) + torch.testing.assert_close( + captured["seq_lens"], + torch.tensor([3, 4], dtype=torch.int32), + atol=0, + rtol=0, + ) + + def test_fast_topk_v2_prefill_uses_int32_row_offsets(self): + from tokenspeed_kernel.registry import error_fn + from tokenspeed_kernel.thirdparty import trtllm + + if trtllm.fast_topk_v2 is None or trtllm.fast_topk_v2 is error_fn: + self.skipTest("TRTLLM fast_topk_v2 is unavailable on this platform") + + captured = {} + + def fake_indexer_topk_prefill(values, row_starts, row_ends, indices, topk): + del values, indices + captured["row_starts"] = row_starts + captured["row_ends"] = row_ends + captured["topk"] = topk + + with patch.object( + torch.ops.trtllm, + "indexer_topk_prefill", + fake_indexer_topk_prefill, + create=True, + ): + values = torch.empty((3, 4), dtype=torch.float32) + seq_lens = torch.tensor([[1], [2]], dtype=torch.int64) + indices = torch.empty((2, 2), dtype=torch.int32) + + trtllm.fast_topk_v2( + values, + seq_lens, + indices, + topk=2, + next_n=2, + ) + + self.assertEqual(captured["topk"], 2) + self.assertEqual(captured["row_starts"].dtype, torch.int32) + self.assertEqual(captured["row_ends"].dtype, torch.int32) + torch.testing.assert_close( + captured["row_starts"], + torch.tensor([0, 1], dtype=torch.int32), + atol=0, + rtol=0, + ) + torch.testing.assert_close( + captured["row_ends"], + torch.tensor([1, 3], dtype=torch.int32), + atol=0, + rtol=0, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/runtime/test_cli_config_compat.py b/test/runtime/test_cli_config_compat.py index 3cc53e480..b690c42fb 100644 --- a/test/runtime/test_cli_config_compat.py +++ b/test/runtime/test_cli_config_compat.py @@ -193,6 +193,7 @@ def test_prefill_token_defaults(self): args = self._parse_args(["--model", "test/model"]) self.assertEqual(args.max_prefill_tokens, 8192) self.assertIsNone(args.chunked_prefill_size) + self.assertFalse(args.enable_mixed_batch) sa = self._from_cli_args_no_init(args) sa.mapping = SimpleNamespace(world_size=1) @@ -205,6 +206,11 @@ def test_prefill_token_defaults(self): self.assertEqual(sa.max_prefill_tokens, 8192) self.assertEqual(sa.chunked_prefill_size, 8192) + self.assertFalse(sa.enable_mixed_batch) + + def test_mixed_batch_can_be_enabled(self): + args = self._parse_args(["--model", "test/model", "--enable-mixed-batch"]) + self.assertTrue(args.enable_mixed_batch) def test_distributed_timeout_seconds_arg(self): args = self._parse_args( diff --git a/test/runtime/test_deepseek_v4_attention_ops.py b/test/runtime/test_deepseek_v4_attention_ops.py index 5c92d257c..34549e094 100644 --- a/test/runtime/test_deepseek_v4_attention_ops.py +++ b/test/runtime/test_deepseek_v4_attention_ops.py @@ -628,6 +628,91 @@ def test_indexer_mxfp4_cache_matches_reference(self): ) self.assertEqual(int(flat_cache[64:128].sum()), 0) + def test_indexer_mxfp4_paged_gather_matches_paged_layout(self): + from tokenspeed_kernel.thirdparty.cuda.deepseek_v4_attention import ( + has_indexer_mxfp4_paged_gather, + indexer_mxfp4_paged_gather, + ) + + if not has_indexer_mxfp4_paged_gather(): + self.skipTest("DeepSeek V4 paged MXFP4 gather op is not available") + + device = torch.device("cuda") + block_size = 4 + value_bytes = 64 + scale_bytes = 4 + num_blocks = 3 + kv_cache = torch.zeros( + num_blocks, + block_size * (value_bytes + scale_bytes), + device=device, + dtype=torch.uint8, + ) + + value_rows = {} + scale_rows = {} + for block_idx in range(num_blocks): + for row_idx in range(block_size): + values = ( + ( + torch.arange(value_bytes, device=device, dtype=torch.int16) + + block_idx * 37 + + row_idx * 11 + ) + .remainder(251) + .to(torch.uint8) + ) + scales = torch.tensor( + [block_idx, row_idx, block_idx * 17 + row_idx, 200 + block_idx], + device=device, + dtype=torch.uint8, + ) + value_base = row_idx * value_bytes + scale_base = block_size * value_bytes + row_idx * scale_bytes + kv_cache[block_idx, value_base : value_base + value_bytes].copy_(values) + kv_cache[block_idx, scale_base : scale_base + scale_bytes].copy_(scales) + value_rows[(block_idx, row_idx)] = values + scale_rows[(block_idx, row_idx)] = scales + + block_table = torch.tensor([[2, 0], [1, 0]], device=device, dtype=torch.int32) + cu_seq_lens = torch.tensor([0, 5, 7], device=device, dtype=torch.int32) + values_out = torch.full( + (8, value_bytes), 0xCC, device=device, dtype=torch.uint8 + ) + scales_out = torch.full( + (8, scale_bytes), 0xDD, device=device, dtype=torch.uint8 + ) + + indexer_mxfp4_paged_gather( + kv_cache, + values_out, + scales_out, + block_table, + cu_seq_lens, + block_size, + ) + torch.cuda.synchronize() + + expected_plan = [ + (2, 0), + (2, 1), + (2, 2), + (2, 3), + (0, 0), + (1, 0), + (1, 1), + ] + expected_values = torch.stack([value_rows[item] for item in expected_plan]) + expected_scales = torch.stack([scale_rows[item] for item in expected_plan]) + self.assertTrue(torch.equal(values_out[:7].cpu(), expected_values.cpu())) + self.assertTrue(torch.equal(scales_out[:7].cpu(), expected_scales.cpu())) + self.assertTrue( + torch.equal(values_out[7].cpu(), torch.full((64,), 0xCC, dtype=torch.uint8)) + ) + self.assertTrue( + torch.equal(scales_out[7].cpu(), torch.full((4,), 0xDD, dtype=torch.uint8)) + ) + def test_csa_indexer_cache_insert_matches_reference(self): torch.manual_seed(8901) device = torch.device("cuda") @@ -1082,6 +1167,46 @@ def test_decode_swa_indices_and_lens_matches_reference(self): compact_lens.cpu(), actual_lens.cpu(), atol=0, rtol=0 ) + def test_decode_swa_indices_and_lens_masks_invalid_tokens(self): + device = torch.device("cuda") + query_start_loc = torch.tensor([0, 1, 2], device=device, dtype=torch.int32) + seq_lens = torch.tensor([70, 3], device=device, dtype=torch.int32) + token_to_req_indices = torch.tensor([0, 1], device=device, dtype=torch.int32) + is_valid_token = torch.tensor([True, False], device=device) + block_table = torch.tensor( + [[10, 11], [20, 21]], + device=device, + dtype=torch.int32, + ) + out_indices = torch.full((2, 4), -123, device=device, dtype=torch.int32) + out_lens = torch.empty((2,), device=device, dtype=torch.int32) + + actual, actual_lens = deepseek_v4_decode_swa_indices_and_lens( + query_start_loc=query_start_loc, + seq_lens=seq_lens, + token_to_req_indices=token_to_req_indices, + block_table=block_table, + window_size=4, + block_size=64, + is_valid_token=is_valid_token, + out_indices=out_indices, + out_lens=out_lens, + ) + torch.cuda.synchronize() + + self.assertTrue( + torch.equal(actual_lens.cpu(), torch.tensor([4, 0], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal( + actual[0].cpu(), + torch.tensor([706, 707, 708, 709], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal(actual[1].cpu(), torch.full((4,), -123, dtype=torch.int32)) + ) + def test_compute_global_topk_indices_and_lens_matches_reference(self): device = torch.device("cuda") topk_indices = torch.tensor( @@ -1134,6 +1259,46 @@ def test_compute_global_topk_indices_and_lens_matches_reference(self): actual_lens.cpu(), expected_lens.cpu(), atol=0, rtol=0 ) + def test_compute_global_topk_indices_and_lens_masks_invalid_tokens(self): + device = torch.device("cuda") + topk_indices = torch.tensor( + [ + [0, 1, -1, 5], + [3, -1, -1, -1], + ], + device=device, + dtype=torch.int32, + ) + token_to_req_indices = torch.tensor([0, 1], device=device, dtype=torch.int32) + is_valid_token = torch.tensor([True, False], device=device) + block_table = torch.tensor( + [ + [10, 11], + [20, 21], + ], + device=device, + dtype=torch.int32, + ) + + actual, actual_lens = deepseek_v4_compute_global_topk_indices_and_lens( + topk_indices=topk_indices, + token_to_req_indices=token_to_req_indices, + block_table=block_table, + block_size=4, + is_valid_token=is_valid_token, + ) + torch.cuda.synchronize() + + self.assertTrue( + torch.equal(actual_lens.cpu(), torch.tensor([3, 0], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal( + actual[0].cpu(), + torch.tensor([40, 41, -1, 45], dtype=torch.int32), + ) + ) + def test_compressed_slot_mapping_matches_page_reference(self): device = torch.device("cuda") query_start_loc = torch.tensor([0, 3, 5], device=device, dtype=torch.int32) diff --git a/test/runtime/test_deepseek_v4_config.py b/test/runtime/test_deepseek_v4_config.py index 3bf2c1715..c539a623e 100644 --- a/test/runtime/test_deepseek_v4_config.py +++ b/test/runtime/test_deepseek_v4_config.py @@ -12,12 +12,17 @@ configure_deepseek_v4_attention, is_deepseek_v4, ) +from tokenspeed.runtime.execution.cuda_graph_wrapper import CudaGraphWrapper from tokenspeed.runtime.execution.forward_batch_info import ForwardMode +from tokenspeed.runtime.layers.attention.backends import ( + deepseek_v4 as deepseek_v4_backend, +) from tokenspeed.runtime.layers.attention.backends.deepseek_v4 import ( DeepseekV4AttentionBackend, ) from tokenspeed.runtime.layers.attention.deepseek_v4_ops import ( DeepseekV4AttentionOpUnavailable, + deepseek_v4_compute_global_topk_indices_and_lens, deepseek_v4_indexer_topk_reference, fused_qnorm_rope_kv_insert, has_fused_qnorm_rope_kv_insert, @@ -33,21 +38,26 @@ _get_flashinfer_mxfp4_device_permute_indices, _reorder_w1w3_to_w3w1, ) -from tokenspeed.runtime.layers.moe.backends.mxfp4.triton_kernel import ( - _mxfp4_scale_for_layout, -) -from tokenspeed.runtime.layers.moe.backends.mxfp4.weights import MXFP4_SCALE_DTYPE from tokenspeed.runtime.layers.quantization import QUANTIZATION_METHODS +from tokenspeed.runtime.models import deepseek_v4 as deepseek_v4_model from tokenspeed.runtime.models.deepseek_v4 import ( DeepseekV4Attention, + DeepseekV4Indexer, + DeepseekV4MLP, DeepseekV4MoEGate, _deepseek_v4_fused_select_experts, _deepseek_v4_gather_indexer_mxfp4_cache, _deepseek_v4_get_fp8_linear_deep_gemm, _deepseek_v4_indexer_decode_max_len, + _deepseek_v4_indexer_prefill_gather_plan, + _deepseek_v4_indexer_prefill_max_logits_bytes, + _deepseek_v4_indexer_prefill_metadata, + _deepseek_v4_indexer_prefill_request_chunks, + _deepseek_v4_indexer_prefill_request_gather_plan, _deepseek_v4_indexer_prefill_topk_chunks, _deepseek_v4_indexer_topk_from_cache_batched, _deepseek_v4_indexer_topk_from_logits, + _deepseek_v4_prefill_topk_op_available, _deepseek_v4_reorder_c4_ape_2604, _DeepseekV4TopKBuffer, _fp8_act_quant_dequant, @@ -80,6 +90,62 @@ def test_config_registry(self): self.assertEqual(DeepseekV4Config.model_type, "deepseek_v4") self.assertIs(_CONFIG_REGISTRY["deepseek_v4"], DeepseekV4Config) + def test_forward_mode_mixed_predicate(self): + self.assertTrue(ForwardMode.MIXED.is_mixed()) + self.assertFalse(ForwardMode.EXTEND.is_mixed()) + self.assertFalse(ForwardMode.DECODE.is_mixed()) + self.assertEqual(ForwardMode.from_num_extends(0, 0), ForwardMode.IDLE) + self.assertEqual(ForwardMode.from_num_extends(0, 2), ForwardMode.DECODE) + self.assertEqual( + ForwardMode.from_num_extends(0, 2, has_drafter=True), + ForwardMode.TARGET_VERIFY, + ) + self.assertEqual(ForwardMode.from_num_extends(2, 2), ForwardMode.EXTEND) + self.assertEqual(ForwardMode.from_num_extends(1, 2), ForwardMode.MIXED) + + def test_cuda_graph_group_table_padding_uses_dummy_page_rows(self): + table = torch.tensor([[5, -1]], dtype=torch.int32) + padded = CudaGraphWrapper._pad_block_tables_to_padded_bs( + {"v4.swa": table}, + actual_bs=1, + padded_bs=3, + ) + + self.assertEqual(padded["v4.swa"].tolist(), [[5, -1], [0, 0], [0, 0]]) + + def test_cuda_graph_replay_keeps_idle_actual_bs_with_padded_group_tables(self): + captured = {} + + class FakeBackend: + uses_paged_cache_groups = True + uses_padded_decode_token_mask = True + + def init_forward_metadata_replay_cuda_graph(self, *args, **kwargs): + captured["args"] = args + captured["kwargs"] = kwargs + + wrapper = object.__new__(CudaGraphWrapper) + wrapper.attn_backend = FakeBackend() + wrapper.draft_attn_backend = None + + wrapper._init_replay_metadata( + padded_bs=4, + actual_bs=0, + req_pool_indices=torch.zeros(4, dtype=torch.int32), + seq_lens=torch.ones(4, dtype=torch.int32), + req_to_page=torch.zeros((1, 1), dtype=torch.int32), + forward_mode=ForwardMode.DECODE, + paged_cache_block_tables={ + "v4.swa": torch.zeros((4, 1), dtype=torch.int32), + }, + ) + + self.assertEqual(captured["kwargs"]["actual_bs"], 0) + self.assertEqual( + captured["kwargs"]["paged_cache_block_tables"]["v4.swa"].shape, + (4, 1), + ) + def test_deepseek_v4_tokenizer_wrapper_uses_model_encoder(self): calls = [] @@ -219,6 +285,19 @@ def test_deepseek_v4_server_args_cli_flags_round_trip(self): global_server_args_dict.clear() global_server_args_dict.update(snapshot) + def test_deepseek_v4_indexer_prefill_max_logits_uses_server_arg(self): + snapshot = dict(global_server_args_dict) + try: + global_server_args_dict["deepseek_v4_indexer_prefill_max_logits_mb"] = 7 + + self.assertEqual( + _deepseek_v4_indexer_prefill_max_logits_bytes(), + 7 * 1024 * 1024, + ) + finally: + global_server_args_dict.clear() + global_server_args_dict.update(snapshot) + def test_fp8_quantization_config(self): quantization = QUANTIZATION_METHODS["fp8"] @@ -426,6 +505,27 @@ def test_deepseek_v4_attention_op_boundary_fails_loudly_when_missing(self): q, kv, cache, slots, positions, cos_sin, 1e-6, 256 ) + def test_deepseek_v4_flashmla_wrapper_exposes_required_api(self): + try: + from tokenspeed_kernel.ops.attention.flash_mla import ( + flash_mla_sparse_fwd, + flash_mla_with_kvcache, + get_mla_metadata, + ) + from tokenspeed_kernel.registry import error_fn + except Exception as exc: + self.skipTest(f"FlashMLA wrapper unavailable: {exc}") + if ( + flash_mla_with_kvcache is error_fn + or flash_mla_sparse_fwd is error_fn + or get_mla_metadata is error_fn + ): + self.skipTest("FlashMLA wrapper unavailable on this platform") + + self.assertTrue(callable(flash_mla_with_kvcache)) + self.assertTrue(callable(flash_mla_sparse_fwd)) + self.assertTrue(callable(get_mla_metadata)) + def test_deepseek_v4_model_config_uses_mla_runtime_metadata(self): model_config = object.__new__(ModelConfig) model_config.hf_config = SimpleNamespace( @@ -539,7 +639,7 @@ def test_deepseek_v4_kv_pool_allocates_v4_cache_families(self): use_fp4_indexer_cache=True, ) - self.assertEqual(layout.cache_cell_size(3), 17329) + self.assertEqual(layout.cache_cell_size(3), 16771) pool = DeepseekV4TokenToKVPool( size=128, @@ -675,6 +775,45 @@ def test_deepseek_v4_backend_preserves_compact_paged_cache_contract(self): self.assertTrue(torch.equal(metadata.swa_block_table, compact)) self.assertTrue(torch.equal(metadata.swa_base_logical_page, base)) + def test_deepseek_v4_mixed_metadata_keeps_decode_rows_single_token(self): + backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cpu", + num_attention_heads=64, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=512, + context_len=4096, + ) + ) + + backend.init_forward_metadata( + bs=3, + num_tokens=10, + req_pool_indices=torch.tensor([0, 1, 2], dtype=torch.int64), + seq_lens=torch.tensor([7, 10, 4], dtype=torch.int32), + forward_mode=ForwardMode.MIXED, + req_to_page=torch.zeros((3, 1), dtype=torch.int32), + extend_seq_lens_cpu=torch.tensor([7], dtype=torch.int32), + num_extends=1, + ) + + metadata = backend.forward_metadata + self.assertIsNotNone(metadata) + assert metadata is not None + self.assertEqual(metadata.query_lens.tolist(), [7, 1, 1]) + self.assertEqual(metadata.query_lens_cpu.tolist(), [7, 1, 1]) + self.assertEqual(metadata.num_prefill_reqs, 1) + self.assertEqual(metadata.num_prefill_tokens, 7) + self.assertEqual(metadata.decode_req_count(), 2) + self.assertEqual(metadata.decode_token_count(), 2) + self.assertEqual( + metadata.token_to_req_indices.tolist(), + [0, 0, 0, 0, 0, 0, 0, 1, 2], + ) + def test_deepseek_v4_cuda_graph_refresh_keeps_compact_table_columns(self): backend = DeepseekV4AttentionBackend( SimpleNamespace( @@ -918,7 +1057,7 @@ def test_deepseek_v4_metadata_maps_compressed_slots(self): torch.tensor([3, 7, 127], dtype=torch.int64), compress_ratio=4, ) - self.assertTrue(torch.equal(slots, torch.tensor([0, 1, 31]))) + self.assertTrue(torch.equal(slots, torch.tensor([640, 641, 671]))) page256_metadata = DeepseekV4ForwardMetadata( page_size=256, @@ -936,6 +1075,271 @@ def test_deepseek_v4_metadata_maps_compressed_slots(self): ) self.assertTrue(torch.equal(slots, torch.tensor([383, 384, 447]))) + grouped_metadata = DeepseekV4ForwardMetadata( + page_size=256, + req_pool_indices=torch.tensor([0, 1], dtype=torch.int32), + block_table=torch.tensor([[5, 6], [7, 8]], dtype=torch.int32), + seq_lens=torch.tensor([300, 10], dtype=torch.int32), + query_lens=torch.tensor([3, 2], dtype=torch.int32), + query_start_loc=torch.tensor([0, 3, 5], dtype=torch.int32), + token_to_req_indices=torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32), + paged_cache_block_tables={ + "v4.c4a.compressed_kv": torch.tensor( + [[20, 21], [30, -1]], dtype=torch.int32 + ) + }, + ) + slots = grouped_metadata.compressed_slot_mapping( + torch.tensor([255, 256, 511, 2560, 4], dtype=torch.int64), + compress_ratio=4, + kv_cache_block_size=64, + ) + self.assertTrue(torch.equal(slots, torch.tensor([1343, 1344, 1407, -1, 1921]))) + + def test_deepseek_v4_group_slot_mapping_from_raw(self): + block_table = torch.tensor([[10, 11], [20, -1]], dtype=torch.int32) + slots = _group_slot_mapping_from_raw( + positions=torch.tensor([0, 63, 64, 9, 10], dtype=torch.int64), + req_indices=torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32), + block_table=block_table, + rows_per_page=64, + entry_stride_tokens=1, + ) + self.assertTrue(torch.equal(slots, torch.tensor([640, 703, 704, 1289, 1290]))) + + compressed_slots = _group_slot_mapping_from_raw( + positions=torch.tensor([0, 255, 256, 511], dtype=torch.int64), + req_indices=torch.tensor([0, 0, 0, 1], dtype=torch.int32), + block_table=block_table, + rows_per_page=64, + entry_stride_tokens=4, + ) + self.assertTrue( + torch.equal(compressed_slots, torch.tensor([640, 703, 704, -1])) + ) + + def test_deepseek_v4_mixed_metadata_splits_prefill_and_decode(self): + backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cpu", + num_attention_heads=8, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=576, + context_len=256, + ) + ) + backend.init_forward_metadata( + bs=3, + num_tokens=5, + req_pool_indices=torch.tensor([0, 1, 2], dtype=torch.int32), + seq_lens=torch.tensor([5, 9, 12], dtype=torch.int32), + forward_mode=ForwardMode.MIXED, + req_to_page=torch.tensor([[10], [20], [30]], dtype=torch.int32), + extend_seq_lens_cpu=torch.tensor([3, 1, 1], dtype=torch.int32), + extend_prefix_lens_cpu=torch.tensor([2, 8, 11], dtype=torch.int32), + num_extends=1, + ) + metadata = backend.forward_metadata + self.assertIsNotNone(metadata) + self.assertEqual(metadata.num_prefill_reqs, 1) + self.assertEqual(metadata.num_prefill_tokens, 3) + self.assertEqual(metadata.decode_req_count(), 2) + self.assertEqual(metadata.decode_token_count(), 2) + self.assertTrue( + torch.equal( + metadata.token_to_req_indices, + torch.tensor([0, 0, 0, 1, 2], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + metadata.seq_lens_cpu, + torch.tensor([5, 9, 12], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + metadata.query_lens_cpu, + torch.tensor([3, 1, 1], dtype=torch.int32), + ) + ) + + prefill = backend._metadata_slice( + metadata, + req_start=0, + req_end=1, + token_start=0, + token_end=3, + forward_mode=ForwardMode.EXTEND, + ) + decode = backend._metadata_slice( + metadata, + req_start=1, + req_end=3, + token_start=3, + token_end=5, + forward_mode=ForwardMode.DECODE, + ) + + self.assertTrue(prefill.forward_mode.is_extend()) + self.assertTrue(decode.forward_mode.is_decode()) + self.assertTrue( + torch.equal(prefill.token_to_req_indices, torch.tensor([0, 0, 0])) + ) + self.assertTrue(torch.equal(decode.token_to_req_indices, torch.tensor([0, 1]))) + self.assertTrue( + torch.equal( + decode.query_start_loc, torch.tensor([0, 1, 2], dtype=torch.int32) + ) + ) + self.assertTrue(torch.equal(decode.block_table[:, 0], torch.tensor([20, 30]))) + self.assertTrue( + torch.equal(prefill.seq_lens_cpu, torch.tensor([5], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(decode.query_lens_cpu, torch.tensor([1, 1], dtype=torch.int32)) + ) + + def test_deepseek_v4_mixed_metadata_accepts_prefill_prefix_lens_only(self): + backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cpu", + num_attention_heads=8, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=576, + context_len=256, + ) + ) + backend.init_forward_metadata( + bs=4, + num_tokens=8, + req_pool_indices=torch.tensor([0, 1, 2, 3], dtype=torch.int32), + seq_lens=torch.tensor([5, 9, 12, 6], dtype=torch.int32), + forward_mode=ForwardMode.MIXED, + req_to_page=torch.tensor([[10], [20], [30], [40]], dtype=torch.int32), + extend_seq_lens_cpu=torch.tensor([3, 4, 1, 1], dtype=torch.int32), + extend_prefix_lens_cpu=torch.tensor([2, 5, 11], dtype=torch.int32), + num_extends=3, + ) + + metadata = backend.forward_metadata + self.assertIsNotNone(metadata) + self.assertEqual(metadata.num_prefill_reqs, 3) + self.assertEqual(metadata.num_prefill_tokens, 8) + self.assertEqual(metadata.decode_req_count(), 1) + self.assertEqual(metadata.decode_token_count(), 1) + self.assertTrue( + torch.equal( + metadata.seq_lens_cpu, + torch.tensor([5, 9, 12, 6], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + metadata.query_lens_cpu, + torch.tensor([3, 4, 1, 1], dtype=torch.int32), + ) + ) + + def test_deepseek_v4_mixed_backend_slices_prefill_and_decode(self): + backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cpu", + num_attention_heads=8, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=576, + context_len=256, + ) + ) + backend.init_forward_metadata( + bs=3, + num_tokens=5, + req_pool_indices=torch.tensor([0, 1, 2], dtype=torch.int32), + seq_lens=torch.tensor([5, 9, 12], dtype=torch.int32), + forward_mode=ForwardMode.MIXED, + req_to_page=torch.tensor([[10], [20], [30]], dtype=torch.int32), + extend_seq_lens_cpu=torch.tensor([3, 1, 1], dtype=torch.int32), + num_extends=1, + ) + calls = [] + + def fake_prefill(**kwargs): + metadata = backend.forward_metadata + calls.append( + ( + "prefill", + kwargs["q"].shape[0], + kwargs["positions"].tolist(), + kwargs["topk_indices"].tolist(), + metadata.req_pool_indices.tolist(), + metadata.token_to_req_indices.tolist(), + metadata.forward_mode, + ) + ) + return kwargs["q"].new_full((3, 2, 4), 1.0) + + def fake_decode(**kwargs): + metadata = backend.forward_metadata + calls.append( + ( + "decode", + kwargs["q"].shape[0], + kwargs["positions"].tolist(), + kwargs["topk_indices"].tolist(), + metadata.req_pool_indices.tolist(), + metadata.token_to_req_indices.tolist(), + metadata.forward_mode, + ) + ) + return kwargs["q"].new_full((2, 2, 4), 2.0) + + backend.forward_deepseek_v4_prefill = fake_prefill + backend.forward_deepseek_v4_decode = fake_decode + q = torch.zeros((5, 2, 4), dtype=torch.float32) + topk = torch.arange(10, dtype=torch.int32).view(5, 2) + out = backend.forward_deepseek_v4_mixed( + q=q, + positions=torch.arange(5, dtype=torch.int32), + token_to_kv_pool=SimpleNamespace(), + layer_id=0, + kind="mla", + compress_ratio=4, + num_local_heads=2, + padded_heads=2, + head_dim=4, + window_size=4, + softmax_scale=1.0, + attn_sink=torch.zeros(2), + topk_indices=topk, + ) + + self.assertEqual(len(calls), 2) + self.assertEqual(calls[0][0], "prefill") + self.assertEqual(calls[0][1], 3) + self.assertEqual(calls[0][2], [0, 1, 2]) + self.assertEqual(calls[0][3], [[0, 1], [2, 3], [4, 5]]) + self.assertEqual(calls[0][4], [0]) + self.assertEqual(calls[0][5], [0, 0, 0]) + self.assertTrue(calls[0][6].is_extend()) + self.assertEqual(calls[1][0], "decode") + self.assertEqual(calls[1][1], 2) + self.assertEqual(calls[1][2], [3, 4]) + self.assertEqual(calls[1][3], [[6, 7], [8, 9]]) + self.assertEqual(calls[1][4], [1, 2]) + self.assertEqual(calls[1][5], [0, 1]) + self.assertTrue(calls[1][6].is_decode()) + self.assertTrue(torch.equal(out[:3], torch.ones((3, 2, 4)))) + self.assertTrue(torch.equal(out[3:], torch.full((2, 2, 4), 2.0))) + def test_deepseek_v4_decode_backend_maps_compressed_slots_batched(self): backend = DeepseekV4AttentionBackend( SimpleNamespace( @@ -946,7 +1350,7 @@ def test_deepseek_v4_decode_backend_maps_compressed_slots_batched(self): attn_tp_size=1, dtype=torch.bfloat16, head_dim=512, - context_len=4096, + context_len=128, ) ) seq_lens = torch.tensor([70, 3], dtype=torch.int32) @@ -964,7 +1368,7 @@ def test_deepseek_v4_decode_backend_maps_compressed_slots_batched(self): [[1, 65, 3, -1], [0, -1, -1, -1]], dtype=torch.int32, ) - indices, lens = backend._decode_compressed_indices_and_lens( + indices, lens = backend._decode_compressed_attention_indices_and_lens( positions, compress_ratio=4, block_size=64, @@ -993,8 +1397,9 @@ def test_deepseek_v4_decode_backend_maps_compressed_slots_batched(self): dtype=torch.int32, ), ) - indices, lens = backend._decode_compressed_indices_and_lens( - seq_lens.to(torch.int64) - 1, + hca_positions = seq_lens.to(torch.int64) - 1 + indices, lens = backend._decode_compressed_attention_indices_and_lens( + hca_positions, compress_ratio=128, block_size=64, topk_indices=None, @@ -1006,6 +1411,79 @@ def test_deepseek_v4_decode_backend_maps_compressed_slots_batched(self): torch.tensor([[640, 641], [1280, -1]], dtype=torch.int32), ) ) + cached_indices, cached_lens = ( + backend._decode_compressed_attention_indices_and_lens( + hca_positions, + compress_ratio=128, + block_size=64, + topk_indices=None, + ) + ) + self.assertEqual(cached_indices.data_ptr(), indices.data_ptr()) + self.assertEqual(cached_lens.data_ptr(), lens.data_ptr()) + + def test_deepseek_v4_decode_backend_capture_ignores_warmup_cache(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA is required for capture cache semantics") + device = torch.device("cuda") + backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cuda", + num_attention_heads=64, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=512, + context_len=128, + ) + ) + seq_lens = torch.tensor([128, 64], device=device, dtype=torch.int32) + backend.init_forward_metadata( + bs=2, + num_tokens=2, + req_pool_indices=torch.tensor([0, 1], device=device, dtype=torch.int64), + seq_lens=seq_lens, + forward_mode=ForwardMode.DECODE, + req_to_page=torch.tensor( + [[10, 11], [20, 21]], + device=device, + dtype=torch.int32, + ), + ) + positions = seq_lens.to(torch.int64) - 1 + + warmup_indices, _ = backend._decode_compressed_attention_indices_and_lens( + positions, + compress_ratio=128, + block_size=64, + topk_indices=None, + ) + metadata = backend.forward_metadata + key = next(iter(metadata.decode_dense_compressed_indices_cache.keys())) + metadata.decode_dense_compressed_indices_capture_safe_keys.clear() + + original_capturing = torch.cuda.is_current_stream_capturing + torch.cuda.is_current_stream_capturing = lambda: True + try: + capture_indices, _ = backend._decode_compressed_attention_indices_and_lens( + positions, + compress_ratio=128, + block_size=64, + topk_indices=None, + ) + reused_indices, _ = backend._decode_compressed_attention_indices_and_lens( + positions, + compress_ratio=128, + block_size=64, + topk_indices=None, + ) + finally: + torch.cuda.is_current_stream_capturing = original_capturing + + self.assertNotEqual(capture_indices.data_ptr(), warmup_indices.data_ptr()) + self.assertEqual(reused_indices.data_ptr(), capture_indices.data_ptr()) + self.assertIn(key, metadata.decode_dense_compressed_indices_capture_safe_keys) def test_deepseek_v4_c128a_prefill_local_compressed_indices_contract(self): backend = DeepseekV4AttentionBackend( @@ -1130,6 +1608,264 @@ def test_deepseek_v4_indexer_mxfp4_gather_reuses_workspace(self): self.assertTrue(torch.equal(values, expected_values)) self.assertTrue(torch.equal(scales, expected_scales)) + def test_deepseek_v4_decode_backend_masks_padding_tokens(self): + backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cpu", + num_attention_heads=64, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=512, + context_len=128, + ) + ) + seq_lens = torch.tensor([70, 3], dtype=torch.int32) + backend.init_forward_metadata( + bs=2, + num_tokens=2, + req_pool_indices=torch.tensor([0, 1], dtype=torch.int64), + seq_lens=seq_lens, + forward_mode=ForwardMode.DECODE, + req_to_page=torch.tensor([[10, 11], [20, 21]], dtype=torch.int32), + ) + metadata = backend.forward_metadata + metadata.is_valid_token = torch.tensor([True, False]) + positions = seq_lens.to(torch.int64) - 1 + + topk_indices = torch.tensor( + [[1, 65, 3, -1], [0, -1, -1, -1]], + dtype=torch.int32, + ) + _, csa_lens = backend._decode_compressed_attention_indices_and_lens( + positions, + compress_ratio=4, + block_size=64, + topk_indices=topk_indices, + ) + _, hca_lens = backend._decode_compressed_attention_indices_and_lens( + torch.tensor([255, 128], dtype=torch.int64), + compress_ratio=128, + block_size=64, + topk_indices=None, + ) + + self.assertTrue(torch.equal(csa_lens, torch.tensor([3, 0], dtype=torch.int32))) + self.assertTrue(torch.equal(hca_lens, torch.tensor([2, 0], dtype=torch.int32))) + + def test_deepseek_v4_global_topk_cpu_masks_invalid_req_before_indexing(self): + indices, lens = deepseek_v4_compute_global_topk_indices_and_lens( + topk_indices=torch.tensor([[0, 4], [0, 1]], dtype=torch.int32), + token_to_req_indices=torch.tensor([0, 99], dtype=torch.int32), + block_table=torch.tensor([[10]], dtype=torch.int32), + block_size=4, + is_valid_token=torch.tensor([True, False]), + ) + + self.assertTrue( + torch.equal( + indices, + torch.tensor([[40, -1], [-1, -1]], dtype=torch.int32), + ) + ) + self.assertTrue(torch.equal(lens, torch.tensor([1, 0], dtype=torch.int32))) + + def test_deepseek_v4_cuda_graph_replay_marks_padding_tokens_invalid(self): + backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cpu", + num_attention_heads=64, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=512, + context_len=128, + ) + ) + backend.init_cuda_graph_state(max_bs=4) + backend.init_forward_metadata_capture_cuda_graph( + bs=4, + num_tokens=4, + req_pool_indices=torch.arange(4, dtype=torch.int32), + seq_lens=torch.ones(4, dtype=torch.int32), + forward_mode=ForwardMode.DECODE, + ) + + backend.init_forward_metadata_replay_cuda_graph( + bs=4, + actual_bs=2, + req_pool_indices=torch.arange(4, dtype=torch.int32), + seq_lens=torch.tensor([70, 3, 1, 1], dtype=torch.int32), + forward_mode=ForwardMode.DECODE, + req_to_page=torch.tensor( + [ + [10, 11], + [20, 21], + [30, 31], + [40, 41], + ], + dtype=torch.int32, + ), + ) + + metadata = backend.forward_metadata + self.assertTrue( + torch.equal( + metadata.is_valid_token, + torch.tensor([True, True, False, False]), + ) + ) + self.assertEqual(metadata.decode_token_count(), 4) + + def test_deepseek_v4_indexer_metadata_refresh_masks_padding_tokens(self): + key = (4, 4, 3) + metadata = DeepseekV4ForwardMetadata( + page_size=64, + req_pool_indices=torch.tensor([0, 1, 2], dtype=torch.int32), + block_table=torch.tensor([[10, 11], [20, 21], [30, 31]], dtype=torch.int32), + seq_lens=torch.tensor([9, 5, 3], dtype=torch.int32), + query_lens=torch.ones(3, dtype=torch.int32), + query_start_loc=torch.tensor([0, 1, 2, 3], dtype=torch.int32), + token_to_req_indices=torch.tensor([0, 1, 2], dtype=torch.int32), + is_valid_token=torch.tensor([True, False, True]), + forward_mode=ForwardMode.DECODE, + ) + plan = SimpleNamespace( + context_lens=torch.empty((3, 1), dtype=torch.int32), + block_table=torch.empty((3, 2), dtype=torch.int32), + max_context_len=0, + ) + metadata.decode_indexer_plan_cache[key] = plan + + def fake_compute(**kwargs): + kwargs["out_context_lens"].copy_( + torch.tensor([[2], [2], [1]], dtype=torch.int32) + ) + kwargs["out_block_tables"].copy_( + torch.tensor([[10, 11], [20, 21], [30, 31]], dtype=torch.int32) + ) + + with patch.object( + deepseek_v4_backend, + "deepseek_v4_indexer_decode_metadata_compute", + side_effect=fake_compute, + ): + deepseek_v4_backend._refresh_decode_indexer_plan_cache( + metadata, + max_context_len=256, + ) + + self.assertTrue( + torch.equal( + plan.context_lens, + torch.tensor([[2], [0], [1]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + plan.block_table, + torch.tensor([[10, 11], [0, 0], [30, 31]], dtype=torch.int32), + ) + ) + + def test_deepseek_v4_indexer_decode_metadata_accepts_sliced_valid_mask(self): + metadata = SimpleNamespace( + decode_indexer_plan_cache={}, + decode_indexer_plan_refreshed_keys=set(), + ) + + def fake_compute(**kwargs): + kwargs["out_context_lens"].copy_( + torch.tensor([[2], [2]], dtype=torch.int32) + ) + kwargs["out_block_tables"].copy_( + torch.tensor([[10], [20]], dtype=torch.int32) + ) + + with patch.object( + deepseek_v4_model, + "deepseek_v4_indexer_decode_metadata_compute", + side_effect=fake_compute, + ): + plan = deepseek_v4_model._deepseek_v4_indexer_decode_metadata( + positions=torch.tensor([8, 4], dtype=torch.int64), + token_to_req_indices=torch.tensor([0, 1], dtype=torch.int32), + block_table=torch.tensor([[10, 11], [20, 21]], dtype=torch.int32), + cache_block_size=4, + compress_ratio=4, + metadata=metadata, + is_valid_token=torch.tensor([False, True]), + ) + + self.assertTrue( + torch.equal( + plan.context_lens, + torch.tensor([[0], [2]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + plan.block_table, + torch.tensor([[0], [20]], dtype=torch.int32), + ) + ) + + def test_deepseek_v4_indexer_schedule_refresh_uses_decode_plan_lens(self): + captured = {} + + def fake_get_metadata(context_lens, cache_block_size, num_sms): + captured["context_lens"] = context_lens.clone() + captured["cache_block_size"] = cache_block_size + captured["num_sms"] = num_sms + return torch.full((2, 1), 9, dtype=torch.int32) + + fake_deep_gemm = SimpleNamespace( + get_paged_mqa_logits_metadata=fake_get_metadata, + get_num_sms=lambda: 123, + ) + key = (4, 4, 2) + metadata = DeepseekV4ForwardMetadata( + page_size=64, + req_pool_indices=torch.tensor([0, 1], dtype=torch.int32), + block_table=torch.tensor([[0], [0]], dtype=torch.int32), + seq_lens=torch.tensor([5, 1], dtype=torch.int32), + query_lens=torch.tensor([1, 1], dtype=torch.int32), + query_start_loc=torch.tensor([0, 1, 2], dtype=torch.int32), + token_to_req_indices=torch.tensor([0, 1], dtype=torch.int32), + is_valid_token=torch.tensor([True, False]), + forward_mode=ForwardMode.DECODE, + ) + metadata.decode_indexer_plan_cache[key] = SimpleNamespace( + context_lens=torch.zeros((2, 1), dtype=torch.int32), + ) + metadata.decode_indexer_schedule_metadata[key] = torch.zeros( + (2, 1), + dtype=torch.int32, + ) + + with patch( + "tokenspeed_kernel.thirdparty.deep_gemm", + fake_deep_gemm, + create=True, + ): + deepseek_v4_backend._refresh_decode_indexer_schedule_metadata(metadata) + + self.assertTrue( + torch.equal( + captured["context_lens"], torch.zeros((2, 1), dtype=torch.int32) + ) + ) + self.assertEqual(captured["cache_block_size"], 4) + self.assertEqual(captured["num_sms"], 123) + self.assertTrue( + torch.equal( + metadata.decode_indexer_schedule_metadata[key], + torch.full((2, 1), 9, dtype=torch.int32), + ) + ) + def test_deepseek_v4_indexer_decode_batches_cache_reads(self): torch.manual_seed(0) positions = torch.tensor([15, 7, 3], dtype=torch.int64) @@ -1220,6 +1956,104 @@ def test_deepseek_v4_indexer_topk_reuses_output_buffer(self): self.assertTrue(torch.equal(actual[0].sort().values, torch.tensor([1, 2]))) self.assertTrue(torch.equal(actual[1].sort().values, torch.tensor([0, 3]))) + def test_deepseek_v4_indexer_topk_accepts_decode_lens_shape(self): + logits = torch.tensor( + [ + [0.0, 3.0, 1.0, -float("inf")], + [4.0, 1.0, 2.0, 3.0], + ], + dtype=torch.float32, + ) + lengths = torch.tensor([[3], [4]], dtype=torch.int32) + + actual = _deepseek_v4_indexer_topk_from_logits( + logits, + lengths, + topk_tokens=2, + next_n=1, + ) + + self.assertEqual(actual.shape, (2, 2)) + self.assertTrue(torch.equal(actual[0].sort().values, torch.tensor([1, 2]))) + self.assertTrue(torch.equal(actual[1].sort().values, torch.tensor([0, 3]))) + + def test_deepseek_v4_indexer_topk_can_sort_preserved_order(self): + logits = torch.tensor( + [ + [0.0, 3.0, 1.0, -float("inf")], + [4.0, 1.0, 2.0, 3.0], + ], + dtype=torch.float32, + ) + lengths = torch.tensor([3, 4], dtype=torch.int32) + + actual = _deepseek_v4_indexer_topk_from_logits( + logits, + lengths, + topk_tokens=4, + preserve_topk_order=True, + sort_preserved_topk=True, + ) + + self.assertTrue(torch.equal(actual[0], torch.tensor([1, 2, 0, -1]))) + self.assertTrue(torch.equal(actual[1], torch.tensor([0, 3, 2, 1]))) + + def test_deepseek_v4_indexer_topk_handles_shifted_prefill_rows(self): + logits = torch.tensor( + [ + [0.0, 3.0, 1.0, -float("inf"), -float("inf"), -float("inf")], + [-float("inf"), -float("inf"), -float("inf"), 2.0, 8.0, 5.0], + ], + dtype=torch.float32, + ) + row_starts = torch.tensor([0, 3], dtype=torch.int32) + row_ends = torch.tensor([3, 6], dtype=torch.int32) + lengths = row_ends - row_starts + + actual = _deepseek_v4_indexer_topk_from_logits( + logits, + lengths, + topk_tokens=3, + preserve_topk_order=True, + sort_preserved_topk=True, + row_starts=row_starts, + row_ends=row_ends, + ) + + self.assertTrue(torch.equal(actual[0], torch.tensor([1, 2, 0]))) + self.assertTrue(torch.equal(actual[1], torch.tensor([1, 2, 0]))) + + @unittest.skipUnless(torch.cuda.is_available(), "CUDA is required") + def test_deepseek_v4_indexer_topk_uses_local_prefill_op(self): + if not _deepseek_v4_prefill_topk_op_available(): + self.skipTest("TRT-LLM indexer_topk_prefill is unavailable") + + logits = torch.tensor( + [ + [0.0, 3.0, 1.0, -float("inf"), -float("inf"), -float("inf")], + [-float("inf"), -float("inf"), -float("inf"), 2.0, 8.0, 5.0], + ], + device="cuda", + dtype=torch.float32, + ) + row_starts = torch.tensor([0, 3], device="cuda", dtype=torch.int32) + row_ends = torch.tensor([3, 6], device="cuda", dtype=torch.int32) + + actual = _deepseek_v4_indexer_topk_from_logits( + logits, + row_ends - row_starts, + topk_tokens=4, + preserve_topk_order=True, + row_starts=row_starts, + row_ends=row_ends, + ) + + expected = torch.tensor( + [[0, 1, 2, -1], [0, 1, 2, -1]], + dtype=torch.int32, + ) + self.assertTrue(torch.equal(actual.cpu(), expected)) + def test_deepseek_v4_topk_buffer_grows_and_reuses(self): buffer = _DeepseekV4TopKBuffer(topk_tokens=3) @@ -1233,6 +2067,162 @@ def test_deepseek_v4_topk_buffer_grows_and_reuses(self): self.assertEqual(third.shape, (4, 3)) self.assertGreaterEqual(buffer.buffer.shape[0], 4) + def test_deepseek_v4_sparse_indexer_custom_op_registered(self): + self.assertTrue( + hasattr(torch.ops.tokenspeed, "deepseek_v4_sparse_attn_indexer") + ) + + @unittest.skipUnless(torch.cuda.is_available(), "CUDA is required") + def test_deepseek_v4_sparse_indexer_custom_op_fallback_covers_decode_tokens(self): + device = torch.device("cuda") + n_head = 2 + head_dim = 4 + total_tokens = 3 + + class FakeLinear: + def __init__(self, out_features): + self.out_features = out_features + + def __call__(self, x): + return ( + torch.zeros( + (x.shape[0], self.out_features), + device=x.device, + dtype=x.dtype, + ), + None, + ) + + self_obj = SimpleNamespace( + use_fp4_cache=True, + wq_b=FakeLinear(n_head * head_dim), + weights_proj=FakeLinear(n_head), + n_head=n_head, + head_dim=head_dim, + softmax_scale=1.0, + compress_ratio=4, + topk_tokens=2, + topk_buffer=None, + _prefill_gather_workspace=lambda rows, device: ( + torch.empty((0, 0), dtype=torch.uint8, device=device), + torch.empty((0, 0), dtype=torch.uint8, device=device), + ), + ) + metadata = SimpleNamespace( + forward_mode=ForwardMode.MIXED, + num_prefill_tokens=1, + num_prefill_reqs=1, + seq_lens_cpu=torch.tensor([4], dtype=torch.int32), + query_lens_cpu=torch.tensor([1], dtype=torch.int32), + token_to_req_indices=torch.tensor( + [0, 0, 0], dtype=torch.int32, device=device + ), + compressed_block_table=lambda compress_ratio, block_size: torch.zeros( + (1, 1), + dtype=torch.int32, + device=device, + ), + decode_token_count=lambda: 2, + ) + captured = {} + + def fake_prepare_mxfp4(**kwargs): + index_q = kwargs["index_q"] + rows = index_q.shape[0] + return ( + ( + torch.empty( + (rows, n_head, head_dim // 2), dtype=torch.uint8, device=device + ), + torch.empty((rows, n_head, 1), dtype=torch.uint8, device=device), + ), + torch.empty((rows, n_head), dtype=torch.float32, device=device), + ) + + def fake_prepare_reference(**kwargs): + captured["reference_rows"] = kwargs["positions"].numel() + rows = kwargs["positions"].numel() + return ( + torch.empty( + (rows, n_head, head_dim), dtype=torch.float32, device=device + ), + torch.empty((rows, n_head), dtype=torch.float32, device=device), + ) + + def fake_sparse_indexer(**kwargs): + captured["fallback_rows"] = kwargs["fallback_index_q"].shape[0] + captured["has_packed_q"] = kwargs["has_packed_q"] + captured["num_prefill_tokens"] = kwargs["num_prefill_tokens"] + captured["num_decode_tokens"] = kwargs["num_decode_tokens"] + return torch.full( + (total_tokens, self_obj.topk_tokens), + 7, + dtype=torch.int32, + device=device, + ) + + empty_prefill_metadata = SimpleNamespace( + chunk_bounds=torch.empty((0, 7), dtype=torch.int64, device="cpu"), + chunk_plan=torch.empty((0, 7), dtype=torch.int64, device="cpu"), + slots=torch.empty(0, dtype=torch.int64, device=device), + cu_seq_lens=torch.empty(0, dtype=torch.int32, device=device), + cu_start=torch.empty(0, dtype=torch.int32, device=device), + cu_end=torch.empty(0, dtype=torch.int32, device=device), + row_lens=torch.empty(0, dtype=torch.int32, device=device), + ) + decode_metadata = SimpleNamespace( + context_lens=torch.ones((2, 1), dtype=torch.int32, device=device), + block_table=torch.zeros((2, 1), dtype=torch.int32, device=device), + max_context_len=1, + ) + + with patch.object( + deepseek_v4_model, + "deepseek_v4_prepare_indexer_q_mxfp4", + side_effect=fake_prepare_mxfp4, + ), patch.object( + deepseek_v4_model, + "_deepseek_v4_deepgemm_fp4_indexer_available", + return_value=False, + ), patch.object( + deepseek_v4_model, + "deepseek_v4_prepare_indexer_q_reference", + side_effect=fake_prepare_reference, + ), patch.object( + deepseek_v4_model, + "_deepseek_v4_indexer_prefill_metadata", + return_value=empty_prefill_metadata, + ), patch.object( + deepseek_v4_model, + "_deepseek_v4_indexer_decode_metadata", + return_value=decode_metadata, + ), patch.object( + deepseek_v4_model, + "_deepseek_v4_indexer_decode_schedule_metadata", + return_value=None, + ), patch.object( + deepseek_v4_model, + "_deepseek_v4_sparse_attn_indexer", + side_effect=fake_sparse_indexer, + ): + actual = DeepseekV4Indexer._forward_sparse_indexer_custom_op( + self_obj, + hidden_states=torch.zeros((total_tokens, 8), device=device), + qr=torch.zeros((total_tokens, 8), device=device), + positions=torch.arange(total_tokens, dtype=torch.int64, device=device), + metadata=metadata, + indexer_cache=torch.empty((1, 1), dtype=torch.uint8, device=device), + indexer_block_size=1, + cos_sin_cache=torch.empty((1, 1), device=device), + ) + + self.assertEqual(tuple(actual.shape), (total_tokens, self_obj.topk_tokens)) + self.assertEqual(captured["reference_rows"], total_tokens) + self.assertEqual(captured["fallback_rows"], total_tokens) + self.assertFalse(captured["has_packed_q"]) + self.assertEqual(captured["num_prefill_tokens"], 1) + self.assertEqual(captured["num_decode_tokens"], 2) + def test_deepseek_v4_indexer_prefill_topk_chunks_cap_logits_bytes(self): positions = torch.tensor([3, 7, 11, 15], dtype=torch.int64) @@ -1261,6 +2251,257 @@ def test_deepseek_v4_indexer_prefill_topk_chunks_cap_logits_bytes(self): [(0, 1)], ) + def test_deepseek_v4_indexer_prefill_topk_chunks_use_cpu_lengths(self): + positions = torch.zeros(6, dtype=torch.int64) + + self.assertEqual( + _deepseek_v4_indexer_prefill_topk_chunks( + positions, + compress_ratio=4, + max_logits_bytes=16, + seq_lens_cpu=torch.tensor([12, 8], dtype=torch.int32), + query_lens_cpu=torch.tensor([4, 2], dtype=torch.int32), + ), + [(0, 2), (2, 3), (3, 4), (4, 6)], + ) + + def test_deepseek_v4_mixed_indexer_fallback_uses_compressed_block_table(self): + base_block_table = torch.tensor([[1]], dtype=torch.int32) + indexer_block_table = torch.tensor([[7]], dtype=torch.int32) + captured = {} + + class FakeLinear: + def __init__(self, out_features): + self.out_features = out_features + + def __call__(self, x): + return ( + torch.zeros( + (x.shape[0], self.out_features), + dtype=torch.float32, + device=x.device, + ), + None, + ) + + class FakeCompressor: + def __init__(self): + self.norm = SimpleNamespace( + weight=torch.ones(1), + variance_epsilon=1e-6, + ) + + def __call__(self, **kwargs): + return None + + pool = SimpleNamespace( + state_block_size=4, + get_indexer_state_buffer=lambda layer_id: torch.empty((1, 1)), + get_indexer_block_size=lambda layer_id: 4, + get_indexer_kv_buffer_2d=lambda layer_id: torch.empty((8, 128)), + ) + metadata = SimpleNamespace( + forward_mode=ForwardMode.MIXED, + indexer_state_block_table=None, + block_table=base_block_table, + token_to_req_indices=torch.tensor([0, 0], dtype=torch.int32), + compressed_block_table=( + lambda compress_ratio, block_size: indexer_block_table + ), + compressed_slot_mapping=lambda *args, **kwargs: torch.zeros( + 2, dtype=torch.int64 + ), + decode_token_count=lambda: 0, + num_prefill_tokens=2, + num_prefill_reqs=1, + seq_lens_cpu=torch.tensor([8], dtype=torch.int32), + query_lens_cpu=torch.tensor([2], dtype=torch.int32), + ) + ctx = SimpleNamespace( + token_to_kv_pool=pool, + attn_backend=SimpleNamespace(forward_metadata=metadata), + forward_mode=ForwardMode.MIXED, + ) + self_obj = SimpleNamespace( + use_fp4_cache=False, + compressor=FakeCompressor(), + compress_ratio=4, + n_head=1, + head_dim=4, + softmax_scale=1.0, + topk_tokens=2, + topk_buffer=None, + wq_b=FakeLinear(4), + weights_proj=FakeLinear(1), + _forward_sparse_indexer_custom_op=lambda **kwargs: None, + ) + + def fake_prepare_reference(**kwargs): + rows = kwargs["positions"].numel() + return ( + torch.zeros((rows, 1, 4), dtype=torch.float32), + torch.zeros((rows, 1), dtype=torch.float32), + ) + + def fake_topk_from_cache(**kwargs): + captured["block_table"] = kwargs["block_table"] + rows = kwargs["positions"].numel() + return torch.full((rows, 2), 3, dtype=torch.int32) + + with patch.object( + deepseek_v4_model, + "deepseek_v4_csa_indexer_cache_insert", + return_value=None, + ), patch.object( + deepseek_v4_model, + "deepseek_v4_prepare_indexer_q_reference", + side_effect=fake_prepare_reference, + ), patch.object( + deepseek_v4_model, + "_deepseek_v4_indexer_topk_from_cache_batched", + side_effect=fake_topk_from_cache, + ): + topk = DeepseekV4Indexer.forward( + self_obj, + hidden_states=torch.zeros((2, 8)), + qr=torch.zeros((2, 8)), + positions=torch.tensor([6, 7], dtype=torch.int64), + ctx=ctx, + out_cache_loc=torch.zeros(2, dtype=torch.int64), + layer_index=0, + cos_sin_cache=torch.empty((1, 1)), + ) + + self.assertTrue(torch.equal(captured["block_table"], indexer_block_table)) + self.assertTrue(torch.equal(topk, torch.full((2, 2), 3, dtype=torch.int32))) + + def test_deepseek_v4_indexer_prefill_gather_plan_reuses_request_k(self): + slots, cu_start, cu_end, row_lens, max_len = ( + _deepseek_v4_indexer_prefill_gather_plan( + positions=torch.tensor([0, 1, 5, 0, 3], dtype=torch.int64), + token_to_req_indices=torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32), + block_table=torch.tensor([[10], [20]], dtype=torch.int32), + cache_block_size=4, + compress_ratio=2, + ) + ) + + self.assertTrue(torch.equal(slots, torch.tensor([40, 41, 42, 80, 81]))) + self.assertTrue(torch.equal(cu_start, torch.tensor([0, 0, 0, 3, 3]))) + self.assertTrue(torch.equal(cu_end, torch.tensor([0, 1, 3, 3, 5]))) + self.assertTrue(torch.equal(row_lens, torch.tensor([0, 1, 3, 0, 2]))) + self.assertEqual(max_len, 3) + + def test_deepseek_v4_indexer_prefill_request_chunks_match_reference(self): + chunks = _deepseek_v4_indexer_prefill_request_chunks( + seq_lens_cpu=torch.tensor([16], dtype=torch.int32), + query_lens_cpu=torch.tensor([6], dtype=torch.int32), + compress_ratio=4, + num_tokens=6, + max_logits_bytes=32, + workspace_size=100, + ) + + self.assertEqual( + [ + ( + c.req_start, + c.req_end, + c.query_start, + c.query_end, + c.token_start, + c.token_end, + c.skip_kv_gather, + ) + for c in chunks + ], + [ + (0, 1, 0, 2, 0, 2, False), + (0, 1, 2, 4, 2, 4, True), + (0, 1, 4, 6, 4, 6, True), + ], + ) + + chunks = _deepseek_v4_indexer_prefill_request_chunks( + seq_lens_cpu=torch.tensor([16, 8], dtype=torch.int32), + query_lens_cpu=torch.tensor([2, 2], dtype=torch.int32), + compress_ratio=4, + num_tokens=4, + max_logits_bytes=128, + workspace_size=100, + ) + + self.assertEqual(len(chunks), 1) + self.assertEqual((chunks[0].req_start, chunks[0].req_end), (0, 2)) + self.assertEqual((chunks[0].token_start, chunks[0].token_end), (0, 4)) + self.assertFalse(chunks[0].skip_kv_gather) + + def test_deepseek_v4_indexer_prefill_request_gather_plan_matches_reference(self): + slots, cu_start, cu_end, row_lens, max_len = ( + _deepseek_v4_indexer_prefill_request_gather_plan( + seq_lens_cpu=torch.tensor([16, 8], dtype=torch.int32), + query_lens_cpu=torch.tensor([4, 2], dtype=torch.int32), + block_table=torch.tensor([[10], [20]], dtype=torch.int32), + cache_block_size=4, + compress_ratio=4, + req_start=0, + req_end=2, + query_start=1, + query_end=5, + ) + ) + + self.assertTrue(torch.equal(slots, torch.tensor([40, 41, 42, 43, 80, 81]))) + self.assertTrue(torch.equal(cu_start, torch.tensor([0, 0, 0, 4]))) + self.assertTrue(torch.equal(cu_end, torch.tensor([3, 3, 4, 5]))) + self.assertTrue(torch.equal(row_lens, torch.tensor([3, 3, 4, 1]))) + self.assertEqual(max_len, 4) + + def test_deepseek_v4_indexer_prefill_metadata_packs_and_caches_plan(self): + metadata = SimpleNamespace( + seq_lens_cpu=torch.tensor([16, 8], dtype=torch.int32), + query_lens_cpu=torch.tensor([4, 2], dtype=torch.int32), + num_prefill_reqs=2, + prefill_indexer_plan_cache={}, + ) + block_table = torch.tensor([[10], [20]], dtype=torch.int32) + + actual = _deepseek_v4_indexer_prefill_metadata( + metadata=metadata, + block_table=block_table, + cache_block_size=4, + compress_ratio=4, + num_prefill_tokens=6, + ) + cached = _deepseek_v4_indexer_prefill_metadata( + metadata=metadata, + block_table=block_table, + cache_block_size=4, + compress_ratio=4, + num_prefill_tokens=6, + ) + + self.assertIs(actual, cached) + self.assertTrue( + torch.equal( + actual.chunk_bounds, + torch.tensor([[0, 6, 0, 2, 0, 6, 0]], dtype=torch.int64), + ) + ) + self.assertTrue( + torch.equal( + actual.chunk_plan, + torch.tensor([[0, 6, 0, 6, 4, 0, 3]], dtype=torch.int64), + ) + ) + self.assertEqual(actual.slots.numel(), 0) + self.assertTrue( + torch.equal(actual.cu_seq_lens, torch.tensor([0, 4, 6], dtype=torch.int32)) + ) + self.assertTrue(torch.equal(actual.cu_start, torch.tensor([0, 0, 0, 0, 4, 4]))) + self.assertTrue(torch.equal(actual.cu_end, torch.tensor([3, 3, 3, 4, 5, 6]))) + self.assertTrue(torch.equal(actual.row_lens, torch.tensor([3, 3, 3, 4, 1, 2]))) + def test_hidden_compression_helpers_preserve_expected_shapes(self): import torch @@ -1456,6 +2697,8 @@ def test_deepseek_v4_gate_fallback_returns_fp32_logits(self): topk_method=None, ) gate = DeepseekV4MoEGate(config, layer_index=1) + with torch.no_grad(): + gate.weight.copy_(torch.randn_like(gate.weight)) hidden_states = torch.randn(3, config.hidden_size) logits = gate(hidden_states) @@ -1634,38 +2877,6 @@ def test_packed_topk_router_logits_recover_weights_after_softmax(self): self.assertTrue(torch.allclose(recovered, topk_weights)) - def test_mxfp4_scale_dtype_preserves_e8m0_checkpoint_bits(self): - import torch - - if not hasattr(torch, "float8_e8m0fnu"): - self.skipTest("float8_e8m0fnu is unavailable") - - loaded = torch.tensor( - [[0.0078125, 0.015625], [0.03125, 0.0625]], dtype=torch.float32 - ).to(torch.float8_e8m0fnu) - param = torch.empty_like(loaded, dtype=MXFP4_SCALE_DTYPE) - param.copy_(loaded) - - self.assertEqual(MXFP4_SCALE_DTYPE, torch.float8_e8m0fnu) - self.assertTrue(torch.equal(param.view(torch.uint8), loaded.view(torch.uint8))) - - def test_mxfp4_triton_scale_layout_uses_uint8_view_for_e8m0(self): - import torch - - if not hasattr(torch, "float8_e8m0fnu"): - self.skipTest("float8_e8m0fnu is unavailable") - - scale = torch.tensor( - [[0.0078125, 0.015625], [0.03125, 0.0625]], dtype=torch.float32 - ).to(torch.float8_e8m0fnu) - - layout_scale = _mxfp4_scale_for_layout(scale) - self.assertEqual(layout_scale.dtype, torch.uint8) - self.assertTrue(torch.equal(layout_scale, scale.view(torch.uint8))) - - uint8_scale = scale.view(torch.uint8) - self.assertIs(_mxfp4_scale_for_layout(uint8_scale), uint8_scale) - def test_mxfp4_flashinfer_reorders_w1w3_halves_for_trtllm(self): import torch diff --git a/test/runtime/test_generation_output_processor.py b/test/runtime/test_generation_output_processor.py new file mode 100644 index 000000000..91c1724cc --- /dev/null +++ b/test/runtime/test_generation_output_processor.py @@ -0,0 +1,100 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import annotations + +import torch + +from tokenspeed.runtime.engine.generation_output_processor import ( + OutputProcesser, + RequestState, +) +from tokenspeed.runtime.sampling.sampling_params import SamplingParams + + +class _Sender: + def __init__(self): + self.items = [] + + def send_pyobj(self, obj): + self.items.append(obj) + + +class _Tokenizer: + eos_token_id = None + additional_stop_token_ids = None + + def decode(self, ids): + return "".join(str(i) for i in ids) + + +class _Metrics: + enabled = False + + +class _ForwardOp: + request_ids = ["prefill", "decode"] + request_pool_indices = [0, 1] + input_lengths = [4, 1] + extend_prefix_lens = [0] + + def num_extends(self): + return 1 + + +class _ExecutionResult: + output_tokens = torch.tensor([11, 22], dtype=torch.int32) + output_lengths = torch.tensor([1, 1], dtype=torch.int32) + output_logprobs = None + grammar_completion = None + + def sync(self): + return None + + +def _state(input_ids: list[int], *, computed_length: int = 0) -> RequestState: + state = RequestState( + prompt_input_ids=input_ids, + sampling_params=SamplingParams(max_new_tokens=8, stop=[], ignore_eos=True), + stream=False, + tokenizer=_Tokenizer(), + ) + state.computed_length = computed_length + return state + + +def test_mixed_forward_updates_reserve_for_decode_slots_only(): + sender = _Sender() + processor = OutputProcesser( + sender, + global_rank=0, + metrics=_Metrics(), + ) + processor.rid_to_state["prefill"] = _state([1, 2, 3, 4]) + processor.rid_to_state["decode"] = _state([5, 6, 7], computed_length=3) + + events = processor.post_process_forward_op(_ForwardOp(), _ExecutionResult()) + + reserve_events = [ + event for event in events if type(event).__name__ == "UpdateReserveNumTokens" + ] + assert len(reserve_events) == 1 + assert reserve_events[0].request_id == "decode" + assert reserve_events[0].reserve_num_tokens_in_next_schedule_event == 1 diff --git a/test/runtime/test_inline_detokenizer_receiver.py b/test/runtime/test_inline_detokenizer_receiver.py index de7295c78..7346a3c91 100644 --- a/test/runtime/test_inline_detokenizer_receiver.py +++ b/test/runtime/test_inline_detokenizer_receiver.py @@ -252,7 +252,12 @@ def test_flag_off_receiver_does_not_take_inline_branch(self): # Populate output_ids so the raw-token fallback path doesn't crash; # we only care that the inline branch is NOT taken. tokens = tok.encode("hello world") - recv = _batch_token_id_out(["r1"], decode_ids=[tokens], output_ids=[tokens]) + recv = _batch_token_id_out( + ["r1"], + decode_ids=[tokens], + output_ids=[tokens], + batch_accept_draft_tokens=[1.5], + ) mgr.output_processor.handle_batch_output(recv) out = state.collector.take() @@ -264,6 +269,7 @@ def test_flag_off_receiver_does_not_take_inline_branch(self): # conversion used to guarantee. The state machine that would have # populated ``state.text`` never ran, so the value is "". self.assertEqual(out["text"], "") + self.assertEqual(out["meta_info"]["accept_draft_tokens"], 1.5) self.assertIsNone(state.inline_detokenizer) self.assertEqual(state.text, "") diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_mla/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_mla/__init__.py index ee4962c37..8f44aa6a4 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_mla/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_mla/__init__.py @@ -24,11 +24,13 @@ platform = current_platform() flash_mla_with_kvcache = error_fn +flash_mla_sparse_fwd = error_fn get_mla_metadata = error_fn if platform.is_nvidia and platform.is_hopper: try: from flash_mla import ( + flash_mla_sparse_fwd, flash_mla_with_kvcache, get_mla_metadata, ) @@ -39,4 +41,4 @@ # Direct export # ------------------------------------------------------------------------------ -__all__ = ["flash_mla_with_kvcache", "get_mla_metadata"] +__all__ = ["flash_mla_sparse_fwd", "flash_mla_with_kvcache", "get_mla_metadata"] diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/triton/deepseek_v4.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/triton/deepseek_v4.py new file mode 100644 index 000000000..8b0773ab2 --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/triton/deepseek_v4.py @@ -0,0 +1,115 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import annotations + +import torch +from tokenspeed_kernel._triton import tl, triton + +__all__ = ["deepseek_v4_indexer_decode_metadata_compute"] + + +@triton.jit +def _deepseek_v4_indexer_decode_metadata_kernel( + out_block_tables_ptr, + out_block_tables_stride, + out_context_lens_ptr, + positions_ptr, + token_to_req_indices_ptr, + block_table_ptr, + block_table_stride, + rows: tl.constexpr, + cols: tl.constexpr, + compress_ratio: tl.constexpr, + cache_block_size: tl.constexpr, + max_blocks: tl.constexpr, + candidate_block: tl.constexpr, +): + token_idx = tl.program_id(0) + pos = tl.load(positions_ptr + token_idx).to(tl.int64) + compressed_lens = tl.maximum((pos + 1) // compress_ratio, 0) + req = tl.load(token_to_req_indices_ptr + token_idx).to(tl.int32) + req_valid = (req >= 0) & (req < rows) + safe_req = tl.maximum(0, tl.minimum(req, rows - 1)) + num_valid_pages = tl.zeros((), dtype=tl.int64) + for col_start in range(0, max_blocks, candidate_block): + col_offsets = col_start + tl.arange(0, candidate_block) + col_mask = col_offsets < max_blocks + in_cols = col_offsets < cols + safe_col = tl.where(in_cols, col_offsets, 0) + bt_load_mask = col_mask & in_cols & req_valid + bt_vals = tl.load( + block_table_ptr + safe_req * block_table_stride + safe_col, + mask=bt_load_mask, + other=0, + ) + page_valid = (bt_vals >= 0) & in_cols + final_mask = page_valid & req_valid & col_mask + masked_bt = tl.where(final_mask, bt_vals, 0) + tl.store( + out_block_tables_ptr + token_idx * out_block_tables_stride + col_offsets, + masked_bt, + mask=col_mask, + ) + num_valid_pages += tl.sum(final_mask.to(tl.int64), axis=0) + available_lens = num_valid_pages * cache_block_size + context_len_val = tl.minimum(compressed_lens, available_lens) + context_len_val = tl.where(req_valid, context_len_val, 0) + tl.store(out_context_lens_ptr + token_idx, context_len_val.to(tl.int32)) + + +def deepseek_v4_indexer_decode_metadata_compute( + *, + positions: torch.Tensor, + token_to_req_indices: torch.Tensor, + block_table: torch.Tensor, + cache_block_size: int, + compress_ratio: int, + max_blocks: int, + out_context_lens: torch.Tensor, + out_block_tables: torch.Tensor, +) -> None: + """Build decode-indexer context lengths and block tables in one Triton pass.""" + num_tokens = int(positions.shape[0]) if positions.ndim >= 1 else 0 + if num_tokens == 0: + return + if out_context_lens.dtype != torch.int32 or out_block_tables.dtype != torch.int32: + raise TypeError("output buffers must be int32") + positions_i64 = positions.to(torch.int64) + token_to_req_indices_i32 = token_to_req_indices.to(torch.int32) + block_table_i32 = block_table.to(torch.int32) + rows = int(block_table.shape[0]) if block_table.ndim >= 1 else 0 + cols = int(block_table.shape[1]) if block_table.ndim >= 2 else 0 + candidate_block = min(1024, max(16, triton.next_power_of_2(max_blocks))) + _deepseek_v4_indexer_decode_metadata_kernel[(num_tokens,)]( + out_block_tables, + out_block_tables.stride(0), + out_context_lens, + positions_i64, + token_to_req_indices_i32, + block_table_i32, + block_table_i32.stride(0), + rows=rows, + cols=cols, + compress_ratio=int(compress_ratio), + cache_block_size=int(cache_block_size), + max_blocks=int(max_blocks), + candidate_block=candidate_block, + ) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py index a010c16a7..92e09a04f 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py @@ -44,11 +44,180 @@ "moe_align_block_size", "moe_sum_reduce_torch_compile", "moe_sum_reduce_triton", + "stage_deepseek_v4_mega_moe_inputs", ] padding_size = 128 if bool(int(os.getenv("TOKENSPEED_MOE_PADDING", "0"))) else 0 +# --------------------------------------------------------------------------- +# DeepSeek V4 MegaMoE staging +# --------------------------------------------------------------------------- + + +_DEEPSEEK_V4_MEGAMOE_FP8_BLOCK_SIZE = 128 + + +@triton.jit +def _deepseek_v4_stage_mega_moe_inputs_kernel( + hidden_states, + x_fp8, + x_sf, + topk_ids, + topk_weights, + topk_idx_out, + topk_weights_out, + hidden_stride_m: tl.constexpr, + hidden_stride_k: tl.constexpr, + x_stride_m: tl.constexpr, + x_stride_k: tl.constexpr, + x_sf_stride_m: tl.constexpr, + x_sf_stride_k: tl.constexpr, + topk_ids_stride_m: tl.constexpr, + topk_ids_stride_k: tl.constexpr, + topk_weights_stride_m: tl.constexpr, + topk_weights_stride_k: tl.constexpr, + topk_idx_stride_m: tl.constexpr, + topk_idx_stride_k: tl.constexpr, + topk_weights_out_stride_m: tl.constexpr, + topk_weights_out_stride_k: tl.constexpr, + hidden_size: tl.constexpr, + top_k: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_K: tl.constexpr, + BLOCK_TOPK: tl.constexpr, +) -> None: + token_id = tl.program_id(0) + k_block_id = tl.program_id(1) + + k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K) + k_mask = k_offsets < hidden_size + hidden = tl.load( + hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k, + mask=k_mask, + other=0.0, + ).to(tl.float32) + + num_groups: tl.constexpr = BLOCK_K // GROUP_K + hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K]) + amax = tl.max(hidden_groups, axis=1) + amax = tl.maximum(amax, 1.0e-4) + + scale = amax / 448.0 + scale_bits = scale.to(tl.uint32, bitcast=True) + scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to( + tl.uint32 + ) + scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254) + rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True) + + hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K]) + scaled = hidden_groups * (1.0 / rounded_scale)[:, None] + scaled = tl.reshape(scaled, [BLOCK_K]) + fp8 = scaled.to(tl.float8e4nv) + tl.store( + x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k, + fp8, + mask=k_mask, + ) + + scale_offsets = tl.arange(0, num_groups) + packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32) + tl.store( + x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k, + packed_scale, + ) + + if k_block_id == 0: + topk_offsets = tl.arange(0, BLOCK_TOPK) + topk_mask = topk_offsets < top_k + + ids = tl.load( + topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k, + mask=topk_mask, + other=0, + ).to(tl.int64) + tl.store( + topk_idx_out + + token_id * topk_idx_stride_m + + topk_offsets * topk_idx_stride_k, + ids, + mask=topk_mask, + ) + + weights = tl.load( + topk_weights + + token_id * topk_weights_stride_m + + topk_offsets * topk_weights_stride_k, + mask=topk_mask, + other=0.0, + ) + tl.store( + topk_weights_out + + token_id * topk_weights_out_stride_m + + topk_offsets * topk_weights_out_stride_k, + weights, + mask=topk_mask, + ) + + +def stage_deepseek_v4_mega_moe_inputs( + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + x_fp8: torch.Tensor, + x_sf: torch.Tensor, + topk_idx_out: torch.Tensor, + topk_weights_out: torch.Tensor, +) -> None: + num_tokens, hidden_size = hidden_states.shape + if num_tokens == 0: + return + if hidden_size % _DEEPSEEK_V4_MEGAMOE_FP8_BLOCK_SIZE != 0: + raise ValueError( + "DeepSeek V4 MegaMoE input staging requires hidden_size to be " + f"a multiple of {_DEEPSEEK_V4_MEGAMOE_FP8_BLOCK_SIZE}." + ) + if topk_weights.shape != topk_ids.shape: + raise ValueError( + "DeepSeek V4 MegaMoE input staging requires topk_weights and " + "topk_ids to have the same shape." + ) + + block_k = _DEEPSEEK_V4_MEGAMOE_FP8_BLOCK_SIZE + grid = (num_tokens, triton.cdiv(hidden_size, block_k)) + block_topk = triton.next_power_of_2(topk_ids.shape[1]) + _deepseek_v4_stage_mega_moe_inputs_kernel[grid]( + hidden_states, + x_fp8, + x_sf, + topk_ids, + topk_weights, + topk_idx_out, + topk_weights_out, + hidden_states.stride(0), + hidden_states.stride(1), + x_fp8.stride(0), + x_fp8.stride(1), + x_sf.stride(0), + x_sf.stride(1), + topk_ids.stride(0), + topk_ids.stride(1), + topk_weights.stride(0), + topk_weights.stride(1), + topk_idx_out.stride(0), + topk_idx_out.stride(1), + topk_weights_out.stride(0), + topk_weights_out.stride(1), + hidden_size, + topk_ids.shape[1], + BLOCK_K=block_k, + GROUP_K=32, + BLOCK_TOPK=block_topk, + num_warps=4, + ) + + # --------------------------------------------------------------------------- # Routing (top-k) # --------------------------------------------------------------------------- diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/csrc/deepseek_v4_attention.cu b/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/csrc/deepseek_v4_attention.cu index d51c978d0..3fe93c4e6 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/csrc/deepseek_v4_attention.cu +++ b/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/csrc/deepseek_v4_attention.cu @@ -18,7 +18,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. // -// DeepSeek V4 fused SWA cache insert. +// DeepSeek V4 fused SWA cache insert and sparse attention/indexer helpers. // // Cache layout per paged block: // [0, block_size * 576): token data, each token [448 fp8 bytes | 64 bf16/fp16] @@ -49,6 +49,76 @@ constexpr int kTokenDataBytes = kNopeDim + kRopeDim * 2; constexpr int kThreads = 256; constexpr float kFp8Max = 448.0f; +template +__global__ void gather_paged_indexer_mxfp4_cache_kernel( + const uint8_t* __restrict__ kv_cache, + uint8_t* __restrict__ values_out, + uint8_t* __restrict__ scales_out, + const int32_t* __restrict__ block_table, + const int32_t* __restrict__ cu_seq_lens, + int batch_size, + int num_tokens, + int value_bytes, + int scale_bytes, + int cache_block_size, + int64_t cache_block_stride, + int64_t value_stride, + int64_t scale_stride, + int64_t block_table_stride) { + constexpr int kVecBytes = sizeof(uint4); + const int token_idx = blockIdx.x * blockDim.y + threadIdx.y; + const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * kVecBytes; + + __shared__ int batch_idx[BlockYSize]; + if (threadIdx.x == 0) { + batch_idx[threadIdx.y] = -1; + } + __syncthreads(); + + for (int iter = 0; iter < (batch_size + blockDim.x - 1) / blockDim.x; + ++iter) { + const int req = iter * blockDim.x + threadIdx.x; + if (req < batch_size) { + const int seq_start = cu_seq_lens[req]; + const int seq_end = cu_seq_lens[req + 1]; + if (token_idx >= seq_start && token_idx < seq_end) { + batch_idx[threadIdx.y] = req; + } + } + } + __syncthreads(); + + const int req = batch_idx[threadIdx.y]; + if (token_idx >= num_tokens || req < 0) { + return; + } + + const int in_req_token_idx = token_idx - cu_seq_lens[req]; + const int block_idx = + block_table[static_cast(req) * block_table_stride + + in_req_token_idx / cache_block_size]; + const int block_offset = in_req_token_idx % cache_block_size; + const int64_t block_base = static_cast(block_idx) * cache_block_stride; + + if (head_idx < value_bytes) { + const int64_t value_src = + block_base + static_cast(block_offset) * value_bytes + head_idx; + const int64_t value_dst = + static_cast(token_idx) * value_stride + head_idx; + *reinterpret_cast(values_out + value_dst) = + *reinterpret_cast(kv_cache + value_src); + } + + if (blockIdx.y == 0 && threadIdx.x == 0) { + const int64_t scale_src = + block_base + static_cast(cache_block_size) * value_bytes + + static_cast(block_offset) * scale_bytes; + const int64_t scale_dst = static_cast(token_idx) * scale_stride; + *reinterpret_cast(scales_out + scale_dst) = + *reinterpret_cast(kv_cache + scale_src); + } +} + template __device__ __forceinline__ float scalar_to_float(scalar_t value); @@ -179,8 +249,8 @@ __global__ void fused_qnorm_rope_kv_insert_kernel( } __syncthreads(); - // Match vLLM's numeric contract: materialize K at activation dtype before - // the UE8M0 absmax and final cache write. + // Match the reference cache writer by materializing K at activation dtype + // before the UE8M0 absmax and final cache write. for (int dim = tid; dim < kHeadDim; dim += blockDim.x) { values[dim] = scalar_to_float(float_to_scalar(values[dim])); } @@ -249,6 +319,102 @@ void launch_fused_qnorm_rope_kv_insert( } // namespace +void deepseek_v4_gather_paged_indexer_mxfp4_cache(TensorView kv_cache, + TensorView values_out, + TensorView scales_out, + TensorView block_table, + TensorView cu_seq_lens, + int64_t cache_block_size) { + CHECK_CUDA(kv_cache); + CHECK_CUDA(values_out); + CHECK_CUDA(scales_out); + CHECK_CUDA(block_table); + CHECK_CUDA(cu_seq_lens); + CHECK_DIM(2, kv_cache); + CHECK_DIM(2, values_out); + CHECK_DIM(2, scales_out); + CHECK_DIM(2, block_table); + CHECK_DIM(1, cu_seq_lens); + + TVM_FFI_ICHECK(kv_cache.dtype() == dl_uint8) << "kv_cache must be uint8"; + TVM_FFI_ICHECK(values_out.dtype() == dl_uint8) << "values_out must be uint8"; + TVM_FFI_ICHECK(scales_out.dtype() == dl_uint8) << "scales_out must be uint8"; + TVM_FFI_ICHECK(block_table.dtype() == dl_int32) + << "block_table must be int32"; + TVM_FFI_ICHECK(cu_seq_lens.dtype() == dl_int32) + << "cu_seq_lens must be int32"; + TVM_FFI_ICHECK(kv_cache.stride(1) == 1) << "kv_cache last dim must be contiguous"; + TVM_FFI_ICHECK(values_out.stride(1) == 1) + << "values_out last dim must be contiguous"; + TVM_FFI_ICHECK(scales_out.stride(1) == 1) + << "scales_out last dim must be contiguous"; + TVM_FFI_ICHECK(cache_block_size > 0) << "cache_block_size must be positive"; + TVM_FFI_ICHECK(cu_seq_lens.size(0) == block_table.size(0) + 1) + << "cu_seq_lens must have batch_size + 1 entries"; + + const int batch_size = static_cast(block_table.size(0)); + const int num_tokens = static_cast(values_out.size(0)); + TVM_FFI_ICHECK(scales_out.size(0) >= num_tokens) + << "scales_out must cover values_out rows"; + // Output rows may be an exact length or a conservative upper bound, so do + // not read cu_seq_lens[-1] on host here. The kernel only writes rows covered + // by device-side cu_seq_lens. + if (batch_size == 0 || num_tokens == 0) { + return; + } + const int value_bytes = static_cast(values_out.size(1)); + const int scale_bytes = static_cast(scales_out.size(1)); + TVM_FFI_ICHECK(value_bytes > 0 && value_bytes % static_cast(sizeof(uint4)) == 0) + << "values_out width must be a positive multiple of 16 bytes"; + TVM_FFI_ICHECK(scale_bytes > 0) << "scales_out width must be positive"; + TVM_FFI_ICHECK(scale_bytes == static_cast(sizeof(uint32_t))) + << "paged indexer MXFP4 gather expects 4 scale bytes per row"; + TVM_FFI_ICHECK(kv_cache.size(1) >= cache_block_size * (value_bytes + scale_bytes)) + << "kv_cache block stride is too small for indexer MXFP4 rows"; + + cudaSetDevice(kv_cache.device().device_id); + const cudaStream_t stream = get_stream(kv_cache.device()); + constexpr int kBlockX = 8; + constexpr int kVecBytes = sizeof(uint4); + const int grid_y = (value_bytes + kBlockX * kVecBytes - 1) / (kBlockX * kVecBytes); + +#define LAUNCH_PAGED_GATHER(BLOCK_Y) \ + do { \ + const dim3 grid((num_tokens + (BLOCK_Y)-1) / (BLOCK_Y), grid_y); \ + const dim3 block(kBlockX, (BLOCK_Y)); \ + gather_paged_indexer_mxfp4_cache_kernel<(BLOCK_Y)> \ + <<>>( \ + static_cast(kv_cache.data_ptr()), \ + static_cast(values_out.data_ptr()), \ + static_cast(scales_out.data_ptr()), \ + static_cast(block_table.data_ptr()), \ + static_cast(cu_seq_lens.data_ptr()), batch_size, \ + num_tokens, value_bytes, scale_bytes, \ + static_cast(cache_block_size), kv_cache.stride(0), \ + values_out.stride(0), scales_out.stride(0), block_table.stride(0)); \ + } while (0) + + if (num_tokens < 32) { + LAUNCH_PAGED_GATHER(1); + } else if (num_tokens < 64) { + LAUNCH_PAGED_GATHER(2); + } else if (num_tokens < 128) { + LAUNCH_PAGED_GATHER(4); + } else if (num_tokens < 256) { + LAUNCH_PAGED_GATHER(8); + } else if (num_tokens < 512) { + LAUNCH_PAGED_GATHER(16); + } else { + LAUNCH_PAGED_GATHER(32); + } +#undef LAUNCH_PAGED_GATHER + + cudaError_t status = cudaGetLastError(); + TVM_FFI_ICHECK(status == cudaSuccess) + << "deepseek_v4_gather_paged_indexer_mxfp4_cache failed: " + << cudaGetErrorString(status); +} + void fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( TensorView q, TensorView kv, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/csrc/deepseek_v4_attention_binding.cu b/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/csrc/deepseek_v4_attention_binding.cu index c17011b00..058d30bab 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/csrc/deepseek_v4_attention_binding.cu +++ b/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/csrc/deepseek_v4_attention_binding.cu @@ -38,7 +38,16 @@ void deepseek_v4_indexer_topk_prefill(TensorView logits, TensorView output, int64_t k); +void deepseek_v4_gather_paged_indexer_mxfp4_cache(TensorView kv_cache, + TensorView values_out, + TensorView scales_out, + TensorView block_table, + TensorView cu_seq_lens, + int64_t cache_block_size); + TVM_FFI_DLL_EXPORT_TYPED_FUNC(fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert, fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert); TVM_FFI_DLL_EXPORT_TYPED_FUNC(deepseek_v4_indexer_topk_prefill, deepseek_v4_indexer_topk_prefill); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(deepseek_v4_gather_paged_indexer_mxfp4_cache, + deepseek_v4_gather_paged_indexer_mxfp4_cache); diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/deepseek_v4_attention.py b/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/deepseek_v4_attention.py index 2962ca068..40f9c6bf9 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/deepseek_v4_attention.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/cuda/deepseek_v4_attention.py @@ -40,6 +40,14 @@ def has_indexer_topk_prefill() -> bool: return hasattr(module, "deepseek_v4_indexer_topk_prefill") +def has_indexer_mxfp4_paged_gather() -> bool: + try: + module = _load_deepseek_v4_attention_module() + except Exception: + return False + return hasattr(module, "deepseek_v4_gather_paged_indexer_mxfp4_cache") + + def fused_qnorm_rope_kv_insert( q: torch.Tensor, kv: torch.Tensor, @@ -97,3 +105,36 @@ def indexer_topk_prefill( output, int(k), ) + + +def indexer_mxfp4_paged_gather( + kv_cache: torch.Tensor, + values_out: torch.Tensor, + scales_out: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + cache_block_size: int, +) -> None: + if kv_cache.dtype != torch.uint8: + raise TypeError(f"kv_cache must be uint8, got {kv_cache.dtype}") + if values_out.dtype != torch.uint8: + raise TypeError(f"values_out must be uint8, got {values_out.dtype}") + if scales_out.dtype != torch.uint8: + raise TypeError(f"scales_out must be uint8, got {scales_out.dtype}") + if block_table.dtype != torch.int32: + block_table = block_table.to(torch.int32) + if cu_seq_lens.dtype != torch.int32: + cu_seq_lens = cu_seq_lens.to(torch.int32) + if values_out.shape[0] != scales_out.shape[0]: + raise ValueError( + "DeepSeek V4 paged gather output value/scale rows must match, " + f"got values={values_out.shape[0]}, scales={scales_out.shape[0]}" + ) + _load_deepseek_v4_attention_module().deepseek_v4_gather_paged_indexer_mxfp4_cache( + kv_cache, + values_out, + scales_out, + block_table.contiguous(), + cu_seq_lens.contiguous(), + int(cache_block_size), + ) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/trtllm/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/trtllm/__init__.py index a334cec51..cbf8103c0 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/trtllm/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/trtllm/__init__.py @@ -141,13 +141,13 @@ def fast_topk_v2( topk: int, next_n: int = 1, ): - seq_lens = seq_lens.to(torch.int32).contiguous() + seq_lens = seq_lens.to(torch.int32).reshape(-1).contiguous() if next_n == 1: torch.ops.trtllm.indexer_topk_decode( values, seq_lens, indices, next_n, topk ) else: - row_ends = seq_lens.cumsum(0) + row_ends = torch.cumsum(seq_lens, dim=0, dtype=torch.int32) row_starts = row_ends - seq_lens torch.ops.trtllm.indexer_topk_prefill( values, row_starts, row_ends, indices, topk diff --git a/tokenspeed-scheduler/bindings/python_module.cpp b/tokenspeed-scheduler/bindings/python_module.cpp index 428d36faa..12f5e7dc6 100644 --- a/tokenspeed-scheduler/bindings/python_module.cpp +++ b/tokenspeed-scheduler/bindings/python_module.cpp @@ -218,6 +218,7 @@ NB_MODULE(tokenspeed_scheduler_ext, m) { .def_rw("enable_l3_storage", &tokenspeed::SchedulerConfig::enable_l3_storage) .def_rw("prefetch_threshold", &tokenspeed::SchedulerConfig::prefetch_threshold) .def_rw("enable_kv_cache_events", &tokenspeed::SchedulerConfig::enable_kv_cache_events) + .def_rw("enable_mixed_prefill_decode", &tokenspeed::SchedulerConfig::enable_mixed_prefill_decode) .def_rw("disable_prefix_cache", &tokenspeed::SchedulerConfig::disable_prefix_cache) .def_rw("enable_mamba", &tokenspeed::SchedulerConfig::enable_mamba) .def_rw("mamba_cache_chunk_size", &tokenspeed::SchedulerConfig::mamba_cache_chunk_size) diff --git a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp index de978d82e..c4f3d4a73 100644 --- a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp @@ -509,6 +509,10 @@ Scheduler::newForwardOperation(std::vector candidates) { } ops.push_back(std::move(op)); }; + auto has_prefill_op = [&]() { + return std::any_of(ops.begin(), ops.end(), + [](const ForwardOperation& op) { return std::holds_alternative(op); }); + }; std::vector loadback_ops; auto simulated_free = initialPagedCacheGroupSimulatedFree(); for (Request* request : candidates) { @@ -535,13 +539,13 @@ Scheduler::newForwardOperation(std::vector candidates) { } } else if (request->Is() || (request->Is() && config_.role != Role::kP)) { // Prefill-first: skip ALL decode if any prefill was scheduled this round. - if (!ops.empty() && std::holds_alternative(ops.back())) break; + if (!config_.enable_mixed_prefill_decode && has_prefill_op()) break; if (auto ev = scheduleDecode(request, simulated_free)) { push_op(applyEventAndGenerateOp(request, *ev)); } } else if (request->Is() && config_.role != Role::kP) { - if (!ops.empty() && std::holds_alternative(ops.back())) break; + if (!config_.enable_mixed_prefill_decode && has_prefill_op()) break; if (auto ev = scheduleDecodeFromRetracted(request, simulated_free)) { std::vector loadback_diff = ev->GetLoadbackDiff(); diff --git a/tokenspeed-scheduler/csrc/scheduler/types.h b/tokenspeed-scheduler/csrc/scheduler/types.h index d1b2173ba..e48d042bf 100644 --- a/tokenspeed-scheduler/csrc/scheduler/types.h +++ b/tokenspeed-scheduler/csrc/scheduler/types.h @@ -83,6 +83,7 @@ struct SchedulerConfig { bool enable_l3_storage{false}; std::int32_t prefetch_threshold{4}; // num pages bool enable_kv_cache_events{false}; + bool enable_mixed_prefill_decode{false}; std::int32_t num_pages_reserved_for_retracted_or_running{}; Role role{Role::kFused}; diff --git a/tokenspeed-scheduler/python/tests/test_fsm_and_scheduling.py b/tokenspeed-scheduler/python/tests/test_fsm_and_scheduling.py index 17175ad76..ab831137f 100644 --- a/tokenspeed-scheduler/python/tests/test_fsm_and_scheduling.py +++ b/tokenspeed-scheduler/python/tests/test_fsm_and_scheduling.py @@ -280,6 +280,26 @@ def test_decode_batch_only_when_no_prefill_work(self): assert plan.forward[0].num_extends() > 0 assert plan.forward[0].request_ids == ["r1"] + def test_mixed_prefill_decode_can_schedule_decode_with_new_prefill(self): + cfg = make_config(max_scheduled_tokens=512, max_batch_size=8) + cfg.enable_mixed_prefill_decode = True + s = Scheduler(cfg) + + submit(s, "r0", list(range(8))) + s.next_execution_plan() # r0 → PrefillDone + s.next_execution_plan() # r0 → Decoding + advance_forward(s, "r0", tokens=[99]) + + submit(s, "r1", list(range(8))) + plan = s.next_execution_plan() + op = plan.forward[0] + + assert op.request_ids == ["r1", "r0"] + assert op.num_extends() == 1 + assert len(op.input_ids) == sum(op.input_lengths[: op.num_extends()]) + assert len(op.input_ids) + len(op.decode_input_ids) == sum(op.input_lengths) + assert op.sizes == [1, 0] + def test_max_batch_size_limits_scheduled_requests(self): """max_batch_size caps the number of requests per plan.""" s = Scheduler(make_config(max_scheduled_tokens=512, max_batch_size=2)) @@ -728,9 +748,9 @@ def test_retract_recovered_carries_last_prefill_token(self): def test_mixed_batch_decode_input_ids_length(self): """decode_input_ids has one entry per decode request; all -1 for normal decodes.""" - s = Scheduler( - make_config(page_size=16, num_device_pages=1024, max_batch_size=8) - ) + cfg = make_config(page_size=16, num_device_pages=1024, max_batch_size=8) + cfg.enable_mixed_prefill_decode = True + s = Scheduler(cfg) # Bring r0 to Decoding. submit(s, "r0", list(range(8))) s.next_execution_plan() # r0 → PrefillDone