feat(sampling): batch greedy decode token selection#307
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces indexed batched argmax sampling to optimize token selection for greedy rows in a decode batch. It implements a new CUDA kernel (argmax_batch_bf16_indexed_kernel) and introduces select_batch_tokens_into to group greedy rows together for batched argmax while falling back to per-row sampling for non-greedy rows. The Qwen3 and Qwen3.5 executors and buffers have been updated to utilize this new batched sampling logic. Feedback on these changes suggests adding a defensive size check for row_indices_scratch before performing the host-to-device copy to prevent potential panics, as well as correcting a typo and an incorrect file path in the KERNELS.md documentation.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if !greedy_rows.is_empty() { | ||
| // Batch sampling for greedy rows. | ||
| ctx.stream | ||
| .memcpy_htod(&greedy_rows, row_indices_scratch) | ||
| .map_err(|e| anyhow!("H2D indexed argmax rows failed: {}", e))?; |
There was a problem hiding this comment.
Add a defensive check to ensure row_indices_scratch is large enough to hold greedy_rows before calling memcpy_htod. Currently, the size check is performed inside argmax_batch_bf16_indexed_into, which is too late because memcpy_htod will already have failed or panicked if the host slice is larger than the device slice.
if !greedy_rows.is_empty() {
if row_indices_scratch.len() < greedy_rows.len() {
return Err(anyhow!(
"row_indices_scratch too small: have {}, need {}",
row_indices_scratch.len(),
greedy_rows.len()
));
}
// Batch sampling for greedy rows.
ctx.stream
.memcpy_htod(&greedy_rows, row_indices_scratch)
.map_err(|e| anyhow!("H2D indexed argmax rows failed: {}", e))?;| | --- | --- | --- | --- | --- | --- | --- | | ||
| | `shared.linear.gemm_per_token` | model-specific decode accuracy gates | `ops::gemm_per_token` / `ops::gemm_per_token_into_checked` | `gemm_per_token_cuda` | `csrc/shared/linear.cu` | cuBLAS | computes each row through the N=1 decode GEMM boundary; used when row-wise parity is required before performance optimization | | ||
| | `shared.sampling.argmax_batch_bf16` | batched greedy gates | `ops::argmax_batch_bf16_into` | `argmax_batch_bf16_cuda` | `csrc/shared/argmax.cu` | CUDA | one greedy top-1 result per row over contiguous `HiddenStates` logits | | ||
| | `shared.sampling.argmax_batch_bf16_indexed` | seleted batched greedy gates | `ops::argmax_batch_bf16_indexed_into` | `argmax_batch_bf16_indexed_cuda` | `csrc/argmax.cu` | CUDA | compact greedy top-1 results for selected source rows over `HiddenStates` logits | |
There was a problem hiding this comment.
There is a typo in the description ("seleted" should be "selected") and the source file path is incorrect (csrc/argmax.cu should be csrc/shared/argmax.cu). Please update this row in the table to:
| shared.sampling.argmax_batch_bf16_indexed | selected batched greedy gates | ops::argmax_batch_bf16_indexed_into | argmax_batch_bf16_indexed_cuda | csrc/shared/argmax.cu | CUDA | compact greedy top-1 results for selected source rows over HiddenStates logits |
|
Thanks for the PR. I did a local A/B against origin/main on Qwen3-4B with CUDA Graph, prompt_len=1, output_len=256, warmup=5, iters=20. The optimization looks real: on my local run, steady decode TPOT improves from 12.861ms to 12.377ms at bs=16, and from 13.812ms to 12.569ms at bs=32. The generated token hash stayed identical across main and this PR, so this does not look like dead code. There is one merge blocker though: the PR currently does not compile as-is. Could you add the missing extern declaration and also clean up the now-unused |
6c33967 to
a38f4c7
Compare
a38f4c7 to
1444810
Compare
|
Yep, all set. Tests passed via |
xiaguan
left a comment
There was a problem hiding this comment.
LGTM. Verified release check plus Qwen3 hf_golden_gate and Qwen3.5 e2e_scheduler locally.
Description
Qwen decode sampling was per-row: bs=32 issued 32 kernel launches + 32 D2H syncs per step, scaling linearly with batch size. This PR switches to batched argmax: all-greedy batches get 1 launch + 1 copy per step; mixed batches use the batched path for greedy rows and fall back to per-row only for rows that need randomness.
Fixes: #242
Type of Change
Changes
pegainfer-kernels/csrc/argmax.cu— new indexed argmax path (argmax_batch_bf16_indexed_cuda/argmax_batch_bf16_indexed_into) that compacts greedy rows for mixed batchespegainfer-core/src/ops/sampling.rs+pegainfer-kernels/src/ops/sampling.rs— batched greedy sampling, shared by both Qwen cratespegainfer-qwen3-4b/src/executor.rs— batched path integrationpegainfer-qwen35-4b/src/batch_decode.rs— remove per-row loop, use unified pathTest Results
Correctness Gate
hf_golden_gatepasses — 816 positions, 6528 head deltas, within bf16 tolerance:Performance (Qwen3-4B, sm_89, BF16, CUDA Graph)
prompt_len=1, output_len=256, warmup=5, iters=20, seed=42.
Acceptance Criteria
Checklist