feat(sampling): batch greedy decode token selection #307

Merged
xiaguan merged 3 commits into openinfer-project:main from Ke-Wng:feat/batched-greedy-token-selection on Jun 9, 2026

Ke-Wng commented Jun 8, 2026

Contributor

Description

Qwen decode sampling was per-row: bs=32 issued 32 kernel launches + 32 D2H syncs per step, scaling linearly with batch size. This PR switches to batched argmax: all-greedy batches get 1 launch + 1 copy per step; mixed batches use the batched path for greedy rows and fall back to per-row only for rows that need randomness.
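The row split described above can be sketched as follows. This is a minimal illustration, not the PR's code: `RowSampling`, its fields, the greedy test, `partition_rows`, and `launches_per_step` are all hypothetical names chosen for the example.

```rust
/// Hypothetical per-row sampling settings; field names are illustrative only.
#[derive(Clone, Copy, Debug)]
pub struct RowSampling {
    pub temperature: f32,
    pub top_k: usize,
}

impl RowSampling {
    /// Assumption: a row is greedy when it needs no randomness.
    pub fn is_greedy(&self) -> bool {
        self.temperature == 0.0 || self.top_k == 1
    }
}

/// Split batch row indices into (greedy, non_greedy), preserving row order.
pub fn partition_rows(params: &[RowSampling]) -> (Vec<u32>, Vec<u32>) {
    let mut greedy = Vec::new();
    let mut non_greedy = Vec::new();
    for (row, p) in params.iter().enumerate() {
        if p.is_greedy() {
            greedy.push(row as u32);
        } else {
            non_greedy.push(row as u32);
        }
    }
    (greedy, non_greedy)
}

/// Sampling launches per decode step under the batched scheme:
/// one batched argmax for all greedy rows (if any) plus one per non-greedy row.
pub fn launches_per_step(params: &[RowSampling]) -> usize {
    let (greedy, non_greedy) = partition_rows(params);
    usize::from(!greedy.is_empty()) + non_greedy.len()
}
```

Under these assumptions an all-greedy batch of 32 rows needs one launch per step instead of 32, and a mixed batch pays the per-row cost only for the rows that sample randomly.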

Fixes: #242

Type of Change

  • New feature (non-breaking change which adds functionality)

Changes

  • pegainfer-kernels/csrc/argmax.cu — new indexed argmax path (argmax_batch_bf16_indexed_cuda / argmax_batch_bf16_indexed_into) that compacts greedy rows for mixed batches
  • pegainfer-core/src/ops/sampling.rs + pegainfer-kernels/src/ops/sampling.rs — batched greedy sampling, shared by both Qwen crates
  • pegainfer-qwen3-4b/src/executor.rs — batched path integration
  • pegainfer-qwen35-4b/src/batch_decode.rs — remove per-row loop, use unified path
  • Scratch buffer size derived from kernel requirements
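
As a reading aid for the indexed argmax path, here is a CPU reference sketch of what "compact greedy top-1 results for selected rows" could mean. It is an assumption-laden illustration: it uses `f32` instead of bf16, the function name `argmax_indexed_ref` is hypothetical, and the lowest-index tie-break is a choice made for this example rather than something the PR states.

```rust
/// CPU reference sketch of an indexed, compacting argmax.
/// `logits` is row-major with `vocab` columns; `rows` lists the source rows
/// to process. The output is compact: `out[i]` is the argmax of row `rows[i]`.
/// Assumption: ties resolve to the lowest token index.
pub fn argmax_indexed_ref(logits: &[f32], vocab: usize, rows: &[u32]) -> Vec<u32> {
    assert!(vocab > 0, "vocab must be non-zero");
    rows.iter()
        .map(|&r| {
            let start = r as usize * vocab;
            let row = &logits[start..start + vocab];
            let mut best = 0usize;
            for (i, &v) in row.iter().enumerate() {
                // Strict comparison keeps the first maximum on ties.
                if v > row[best] {
                    best = i;
                }
            }
            best as u32
        })
        .collect()
}
```

A reference like this is mainly useful for checking a GPU kernel's output on small inputs; it says nothing about how the CUDA kernel partitions work across threads.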

Test Results

Correctness Gate

hf_golden_gate passes — 816 positions, 6528 head deltas, within bf16 tolerance:

| Mode | Mean | P99 | Max |
| --- | --- | --- | --- |
| Sequential bs=1 eager | 0.0319 | 0.1192 | 0.2638 |
| Batched eager (9) | 0.0311 | 0.1259 | 0.2500 |
| CUDA-graph (9 padded) | 0.0311 | 0.1259 | 0.2500 |
| CUDA-graph (5 padded) | 0.0295 | 0.1201 | 0.1440 |

Performance (Qwen3-4B, sm_89, BF16, CUDA Graph)

| Run | Baseline TPOT p50 | Current TPOT p50 | Delta | Baseline tok/s | Current tok/s |
| --- | --- | --- | --- | --- | --- |
| bs=1 | 9.49ms | 9.50ms | +0.1% | 105.04 | 105.19 |
| bs=16 | 12.73ms | 11.39ms | -10.5% | 78.56 | 87.92 |
| bs=32 | 14.54ms | 11.69ms | -19.6% | 68.72 | 85.54 |

prompt_len=1, output_len=256, warmup=5, iters=20, seed=42.
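
The Delta column appears to be the relative change in TPOT p50 against the baseline. A small sketch of that arithmetic, with hypothetical helper names `tpot_delta_pct` and `round1`, reproduces the reported figures from the table's own numbers:

```rust
/// Relative change of `current_ms` versus `baseline_ms`, in percent.
/// Negative values mean the current build is faster.
pub fn tpot_delta_pct(baseline_ms: f64, current_ms: f64) -> f64 {
    (current_ms - baseline_ms) / baseline_ms * 100.0
}

/// Round to one decimal place, matching the precision used in the table.
pub fn round1(x: f64) -> f64 {
    (x * 10.0).round() / 10.0
}
```

For example, bs=16 gives (11.39 - 12.73) / 12.73, which rounds to -10.5%.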

Acceptance Criteria

  • Qwen3-4B HF logits golden gate stays green
  • bs≥16 TPOT improves measurably (bs=16: -10.5%, bs=32: -19.6%)
  • bs=1 unchanged (+0.1%)

Checklist

  • Code follows style guidelines
  • Self-review done
  • Commitizen conventions
  • Tests pass

gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces indexed batched argmax sampling to optimize token selection for greedy rows in a decode batch. It implements a new CUDA kernel (argmax_batch_bf16_indexed_kernel) and introduces select_batch_tokens_into to group greedy rows together for batched argmax while falling back to per-row sampling for non-greedy rows. The Qwen3 and Qwen3.5 executors and buffers have been updated to utilize this new batched sampling logic. Feedback on these changes suggests adding a defensive size check for row_indices_scratch before performing the host-to-device copy to prevent potential panics, as well as correcting a typo and an incorrect file path in the KERNELS.md documentation.

Comment on lines +91 to +95
    if !greedy_rows.is_empty() {
        // Batch sampling for greedy rows.
        ctx.stream
            .memcpy_htod(&greedy_rows, row_indices_scratch)
            .map_err(|e| anyhow!("H2D indexed argmax rows failed: {}", e))?;

Contributor

medium

Add a defensive check to ensure row_indices_scratch is large enough to hold greedy_rows before calling memcpy_htod. Currently, the size check is performed inside argmax_batch_bf16_indexed_into, which is too late because memcpy_htod will already have failed or panicked if the host slice is larger than the device slice.

    if !greedy_rows.is_empty() {
        if row_indices_scratch.len() < greedy_rows.len() {
            return Err(anyhow!(
                "row_indices_scratch too small: have {}, need {}",
                row_indices_scratch.len(),
                greedy_rows.len()
            ));
        }
        // Batch sampling for greedy rows.
        ctx.stream
            .memcpy_htod(&greedy_rows, row_indices_scratch)
            .map_err(|e| anyhow!("H2D indexed argmax rows failed: {}", e))?;
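
The check-before-copy pattern suggested above can be illustrated on plain host slices. This is only an analogy under stated assumptions: `copy_rows_checked` is a hypothetical helper, and the standard library's `copy_from_slice` stands in for the device copy because it also panics when the two lengths differ.

```rust
/// Copy `src` into the front of `dst`, returning an error instead of
/// panicking when `dst` is too small. The explicit length check runs first,
/// so the copy below can never see mismatched lengths.
pub fn copy_rows_checked(src: &[u32], dst: &mut [u32]) -> Result<(), String> {
    if dst.len() < src.len() {
        return Err(format!(
            "row_indices_scratch too small: have {}, need {}",
            dst.len(),
            src.len()
        ));
    }
    // Slice `dst` to exactly `src.len()` elements before copying.
    dst[..src.len()].copy_from_slice(src);
    Ok(())
}
```

Validating at the call site turns a potential panic into a recoverable error that names both sizes, which is the behavior the review comment asks for.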

Comment thread pegainfer-kernels/KERNELS.md Outdated
| --- | --- | --- | --- | --- | --- | --- |
| `shared.linear.gemm_per_token` | model-specific decode accuracy gates | `ops::gemm_per_token` / `ops::gemm_per_token_into_checked` | `gemm_per_token_cuda` | `csrc/shared/linear.cu` | cuBLAS | computes each row through the N=1 decode GEMM boundary; used when row-wise parity is required before performance optimization |
| `shared.sampling.argmax_batch_bf16` | batched greedy gates | `ops::argmax_batch_bf16_into` | `argmax_batch_bf16_cuda` | `csrc/shared/argmax.cu` | CUDA | one greedy top-1 result per row over contiguous `HiddenStates` logits |
| `shared.sampling.argmax_batch_bf16_indexed` | seleted batched greedy gates | `ops::argmax_batch_bf16_indexed_into` | `argmax_batch_bf16_indexed_cuda` | `csrc/argmax.cu` | CUDA | compact greedy top-1 results for selected source rows over `HiddenStates` logits |

Contributor

medium

There is a typo in the description ("seleted" should be "selected") and the source file path is incorrect (csrc/argmax.cu should be csrc/shared/argmax.cu). Please update this row in the table to:

| shared.sampling.argmax_batch_bf16_indexed | selected batched greedy gates | ops::argmax_batch_bf16_indexed_into | argmax_batch_bf16_indexed_cuda | csrc/shared/argmax.cu | CUDA | compact greedy top-1 results for selected source rows over HiddenStates logits |

xiaguan commented Jun 8, 2026

Collaborator

Thanks for the PR. I did a local A/B against origin/main on Qwen3-4B with CUDA Graph, prompt_len=1, output_len=256, warmup=5, iters=20.

The optimization looks real: on my local run, steady decode TPOT improves from 12.861ms to 12.377ms at bs=16, and from 13.812ms to 12.569ms at bs=32. The generated token hash stayed identical across main and this PR, so this does not look like dead code.
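
One way to run a token-hash comparison like the one mentioned above is to hash every generated sequence from each build and compare the digests. The sketch below is an assumption about how such a check could look, not the tooling actually used; `token_hash` is a hypothetical name, and it relies only on the standard library hasher.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Hash all generated token sequences into a single 64-bit digest.
/// Two runs that produce identical tokens in identical order yield the
/// same digest within the same build of the program.
pub fn token_hash(sequences: &[Vec<u32>]) -> u64 {
    let mut hasher = DefaultHasher::new();
    sequences.hash(&mut hasher);
    hasher.finish()
}
```

Matching digests across the baseline and the PR indicate the greedy outputs did not change, while the timing difference shows the new path is actually exercised.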

There is one merge blocker though: the PR currently does not compile as-is. pegainfer-kernels/src/ops/sampling.rs calls ffi::argmax_batch_bf16_indexed_cuda, and the CUDA symbol is defined, but the Rust FFI declaration is missing from pegainfer-kernels/src/ffi/shared.rs.

Could you add the missing extern declaration and also clean up the now-unused SamplingParams import in the DeepSeek-V2-Lite engine? After that, I think this is a very promising direction, especially for reducing per-row sampling overhead in larger greedy decode batches.

Ke-Wng force-pushed the feat/batched-greedy-token-selection branch from 6c33967 to a38f4c7 on June 9, 2026 09:28
Ke-Wng force-pushed the feat/batched-greedy-token-selection branch from a38f4c7 to 1444810 on June 9, 2026 09:33

Ke-Wng commented Jun 9, 2026

Contributor Author

Yep, all set. Tests passed via cargo test --release -p pegainfer-qwen3-4b --test hf_golden_gate.

xiaguan left a comment

Collaborator

LGTM. Verified release check plus Qwen3 hf_golden_gate and Qwen3.5 e2e_scheduler locally.

xiaguan merged commit 8b791d7 into openinfer-project:main on Jun 9, 2026
xiaguan added a commit that referenced this pull request Jun 9, 2026
Roadmap-doc drift fixes for already-landed work (#307 batched greedy sampling, #316 in-process pegaflow KV offload). Doc-only: offload integration doc CLI status + Kimi/Qwen first-launch order, execution.md data-plane entry, qwen3 roadmap rows, index.md.

Development

Successfully merging this pull request may close these issues.

sampling: batch decode runs one GPU launch + one sync per row per step
