Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/serving/deepseek-v4.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ DP=4 + expert parallel + mega_moe + FP8 KV cache (B200, 4× SM100):

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 tokenspeed serve deepseek-ai/DeepSeek-V4-Flash \
--host localhost --port 30100 \
--host localhost --port 8000 \
--dist-init-addr 127.0.0.1:4013 \
--trust-remote-code \
--data-parallel-size 4 \
Expand Down Expand Up @@ -68,7 +68,7 @@ GSM8K 5-shot, 50 samples is the standard quick-validation harness for V4:
```bash
HF_DATASETS_TRUST_REMOTE_CODE=1 lm_eval run \
--model local-completions \
--model_args "model=deepseek-ai/DeepSeek-V4-Flash,base_url=http://127.0.0.1:30100/v1/completions,tokenized_requests=False,tokenizer_backend=None,num_concurrent=4,max_retries=1,timeout=600,max_gen_toks=256" \
--model_args "model=deepseek-ai/DeepSeek-V4-Flash,base_url=http://127.0.0.1:8000/v1/completions,tokenized_requests=False,tokenizer_backend=None,num_concurrent=4,max_retries=1,timeout=600,max_gen_toks=256" \
--tasks gsm8k --num_fewshot 5 --limit 50 --batch_size 1 \
--gen_kwargs temperature=0
```
Expand Down
6 changes: 4 additions & 2 deletions python/tokenspeed/runtime/engine/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,12 @@ def prealloc_for_draft_decode(self, is_disaggregation_decode: bool = False):
return
out_cache_loc = torch.concat(out_cache_loc_list)
out_cache_loc = out_cache_loc.to(self.device, non_blocking=True)
req_indices = torch.tensor(req_indices, dtype=torch.int32).to(
req_indices = torch.tensor(req_indices, dtype=torch.int64).to(
self.device, non_blocking=True
)
start_offsets = self.req_to_token_pool.alloced_lens[req_indices]
start_offsets = torch.index_select(
self.req_to_token_pool.alloced_lens, 0, req_indices
)
end_offsets = start_offsets + num_tokens_pre_alloc
assign_req_to_token_pool[(bs,)](
req_indices,
Expand Down
85 changes: 65 additions & 20 deletions python/tokenspeed/runtime/execution/cache_loc_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,14 @@ def update_req_to_page(
def compute_out_cache_loc_kernel(
# Input pointers
req_pool_indices_ptr, # [batch_size]
input_lengths_ptr, # [batch_size]
valid_cache_lengths_ptr, # [req_pool_size+1]
input_lengths_ptr, # [batch_size] or None for uniform mode
cache_start_ptr, # [batch_size]
req_to_pages_ptr, # [req_pool_size+1, max_pages]
cumsum_lengths_ptr, # [batch_size] - cumulative sum of input_lengths for output offset
cumsum_lengths_ptr, # [batch_size] or None for uniform mode
# Output pointer
out_cache_loc_ptr, # [total_tokens]
# Scalars
uniform_input_length, # used when input_lengths_ptr is None
page_size: tl.constexpr,
max_pages: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
Expand All @@ -143,38 +144,47 @@ def compute_out_cache_loc_kernel(
Unified kernel to compute out_cache_loc for both prefill and decode.

For each token in each request, compute:
position = valid_cache_length + token_offset_in_seq
position = cache_start[req_idx] + token_offset_in_seq
page_idx = position // page_size
offset_in_page = position % page_size
page_id = req_to_pages[req_pool_idx, page_idx]
out_cache_loc = page_id * page_size + offset_in_page

For decode, input_lengths are all 1.
For prefill, input_lengths vary.

When all requests share the same input_length (the multi-step drafter
case), callers pass ``input_lengths_ptr=None`` (and ``cumsum_lengths_ptr=None``)
together with ``uniform_input_length`` set to the shared length. Triton
specializes the kernel on the None-ness of the pointers at JIT time and
dead-code-eliminates the corresponding GMEM reads.
"""
# Program ID represents which request we're processing
req_idx = tl.program_id(0)

# Load request metadata
# Load request metadata.
req_pool_idx = tl.load(req_pool_indices_ptr + req_idx)
seq_len = tl.load(input_lengths_ptr + req_idx)
valid_cache_len = tl.load(valid_cache_lengths_ptr + req_pool_idx)

# Get output offset from cumsum
# Always load from cumsum, use 0 index for first request to ensure type consistency
offset_idx = tl.where(req_idx > 0, req_idx - 1, 0)
output_offset = tl.load(cumsum_lengths_ptr + offset_idx)
# Zero out offset for first request
output_offset = tl.where(req_idx > 0, output_offset, 0)
valid_cache_len = tl.load(cache_start_ptr + req_idx)

if input_lengths_ptr is not None:
input_length = tl.load(input_lengths_ptr + req_idx)
# Always load from cumsum, use 0 index for first request to ensure type consistency
offset_idx = tl.where(req_idx > 0, req_idx - 1, 0)
output_offset = tl.load(cumsum_lengths_ptr + offset_idx)
# Zero out offset for first request
output_offset = tl.where(req_idx > 0, output_offset, 0)
else:
input_length = uniform_input_length
output_offset = req_idx * uniform_input_length

# Process tokens in blocks
num_blocks = tl.cdiv(seq_len, BLOCK_SIZE)
num_blocks = tl.cdiv(input_length, BLOCK_SIZE)
for block_idx in range(num_blocks):
block_start = block_idx * BLOCK_SIZE

# Compute token offsets within this block
token_offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = token_offsets < seq_len
mask = token_offsets < input_length

# Compute logical positions
positions = valid_cache_len + token_offsets
Expand All @@ -200,10 +210,10 @@ def compute_out_cache_loc(
out_cache_loc_ptr,
req_pool_indices: torch.Tensor, # [batch_size]
input_lengths: torch.Tensor, # [batch_size]
valid_cache_lengths: torch.Tensor, # [req_pool_size+1]
cache_start: torch.Tensor, # [batch_size]
req_to_pages: torch.Tensor, # [req_pool_size+1, max_pages]
page_size: int,
) -> torch.Tensor:
) -> None:
batch_size = req_pool_indices.shape[0]
max_pages = req_to_pages.shape[1]

Expand All @@ -215,10 +225,45 @@ def compute_out_cache_loc(
compute_out_cache_loc_kernel[grid](
req_pool_indices,
input_lengths,
valid_cache_lengths,
cache_start,
req_to_pages,
cumsum_lengths,
out_cache_loc_ptr,
0, # uniform_input_length unused when input_lengths_ptr is not None
page_size=page_size,
max_pages=max_pages,
BLOCK_SIZE=BLOCK_SIZE,
)


def compute_out_cache_loc_uniform(
out_cache_loc_ptr,
req_pool_indices: torch.Tensor, # [batch_size]
uniform_input_length: int,
cache_start: torch.Tensor, # [batch_size]
req_to_pages: torch.Tensor, # [req_pool_size+1, max_pages]
page_size: int,
) -> None:
"""Specialized entry point when every request has the same ``input_length``.

Skips the per-call ``torch.full`` + ``cumsum`` host-side work and the
corresponding GMEM reads inside the kernel. Used by the multi-step drafter
where each request decodes exactly ``spec_num_steps - 1`` tokens.
"""
batch_size = req_pool_indices.shape[0]
max_pages = req_to_pages.shape[1]

BLOCK_SIZE = 128
grid = (batch_size,)

compute_out_cache_loc_kernel[grid](
req_pool_indices,
None, # input_lengths_ptr is None → kernel uses uniform_input_length
cache_start,
req_to_pages,
None, # cumsum_lengths_ptr is None → kernel computes offset analytically
out_cache_loc_ptr,
uniform_input_length,
page_size=page_size,
max_pages=max_pages,
BLOCK_SIZE=BLOCK_SIZE,
Expand Down Expand Up @@ -261,7 +306,7 @@ def flatten_and_to_device(data, dtype=torch.int32):
forward_op.new_occupied_pages, dtype=torch.int32
)
request_pool_indices = flatten_and_to_device(
forward_op.request_pool_indices, dtype=torch.int32
forward_op.request_pool_indices, dtype=torch.int64
)
update_req_to_page(
req_to_page=req_to_page,
Expand Down
1 change: 1 addition & 0 deletions python/tokenspeed/runtime/execution/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,4 @@ class ForwardContext:

# --- logits processor ---
keep_full_logits: bool = False
last_index_offsets: torch.Tensor | None = None
4 changes: 2 additions & 2 deletions python/tokenspeed/runtime/execution/cuda_graph_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,13 +246,13 @@ def __init__(
try:
draft_attn_backend.init_cuda_graph_state(
self.max_bs,
self.drafter.draft_seq_lens,
self.drafter.draft_seq_lens_buf,
paged_cache_group_specs=draft_paged_cache_group_specs,
max_tokens_per_req=self.max_tokens_per_req,
)
except TypeError:
draft_attn_backend.init_cuda_graph_state(
self.max_bs, self.drafter.draft_seq_lens
self.max_bs, self.drafter.draft_seq_lens_buf
)

self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
Expand Down
Loading
Loading