Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pegainfer-core/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ pub use pegainfer_kernels::ops::{
silu_mul_fused_batch_into,
};
pub use sampling::{
argmax, argmax_batch_bf16_into, flashinfer_topk_row_states_bytes, gpu_sample, gpu_sample_into,
argmax, argmax_batch_bf16_indexed_into, argmax_batch_bf16_into,
flashinfer_topk_row_states_bytes, gpu_sample, gpu_sample_into, select_batch_tokens_into,
};
#[cfg(feature = "kernel-call-trace")]
pub use traced::{
Expand Down
89 changes: 86 additions & 3 deletions pegainfer-core/src/ops/sampling.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use anyhow::Result;
use anyhow::{Result, anyhow};
use cudarc::driver::CudaSlice;

use crate::sampler::SamplingParams;
use crate::tensor::{DeviceContext, DeviceVec};
use crate::tensor::{DeviceContext, DeviceVec, HiddenStates};

pub use pegainfer_kernels::ops::{
argmax, argmax_batch_bf16_into, flashinfer_topk_row_states_bytes,
argmax, argmax_batch_bf16_indexed_into, argmax_batch_bf16_into,
flashinfer_topk_row_states_bytes,
};

/// GPU sampling: temperature -> softmax -> top-k -> top-p -> multinomial.
Expand Down Expand Up @@ -60,3 +61,85 @@ pub fn gpu_sample_into(
random_val,
)
}

/// Pick the next token for each row in a decode batch.
///
/// Greedy rows are selected together with indexed batched argmax. Non-greedy
/// rows still use the existing per-row sampler because each row may have its
/// own random value and sampling parameters.
///
/// The returned vector has one entry per element of `params`, in the same
/// order; entry `i` is the token chosen for batch position `i`, which is also
/// the row index used to address `logits`.
///
/// `out` and `top1_value_scratch` are shared by both paths: the greedy path
/// writes its compacted results into them first, and the per-row sampler is
/// handed the same buffers afterwards.
///
/// # Errors
///
/// Returns an error if `row_indices_scratch` holds fewer elements than there
/// are greedy rows, if the host-to-device or device-to-host copy fails, or if
/// `argmax_batch_bf16_indexed_into`, `ctx.sync`, `extract_vec` or
/// `gpu_sample_into` reports an error.
///
/// # Panics
///
/// `random_vals` is indexed with the batch position of every non-greedy row,
/// so it panics if `random_vals` is shorter than that position requires.
#[allow(clippy::too_many_arguments)]
pub fn select_batch_tokens_into(
ctx: &DeviceContext,
logits: &HiddenStates,
params: &[&SamplingParams],
random_vals: &[f32],
row_indices_scratch: &mut CudaSlice<i32>,
probs_scratch: &mut CudaSlice<f32>,
top1_value_scratch: &mut CudaSlice<half::bf16>,
row_states_scratch: &mut CudaSlice<u8>,
valid_scratch: &mut CudaSlice<u8>,
out: &mut CudaSlice<i32>,
) -> Result<Vec<u32>> {
let batch_size = params.len();
let mut tokens = vec![0; batch_size];
// Batch positions of the greedy requests, kept as `i32` because that is the
// element type of `row_indices_scratch`, which they are uploaded into below.
let greedy_rows = params
.iter()
.enumerate()
.filter_map(|(i, params_i)| params_i.is_greedy().then_some(i as i32))
.collect::<Vec<_>>();

if !greedy_rows.is_empty() {
// Batch sampling for greedy rows.
// The capacity check runs before the upload so an undersized scratch
// buffer is reported by this function as a regular error.
if row_indices_scratch.len() < greedy_rows.len() {
return Err(anyhow!(
"row_indices_scratch too small: have {}, need {}",
row_indices_scratch.len(),
greedy_rows.len()
));
}

ctx.stream
.memcpy_htod(&greedy_rows, row_indices_scratch)
.map_err(|e| anyhow!("H2D indexed argmax rows failed: {}", e))?;
Comment on lines +91 to +103

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Add a defensive check to ensure row_indices_scratch is large enough to hold greedy_rows before calling memcpy_htod. Currently, the size check is performed inside argmax_batch_bf16_indexed_into, which is too late because memcpy_htod will already have failed or panicked if the host slice is larger than the device slice.

    if !greedy_rows.is_empty() {
        if row_indices_scratch.len() < greedy_rows.len() {
            return Err(anyhow!(
                "row_indices_scratch too small: have {}, need {}",
                row_indices_scratch.len(),
                greedy_rows.len()
            ));
        }
        // Batch sampling for greedy rows.
        ctx.stream
            .memcpy_htod(&greedy_rows, row_indices_scratch)
            .map_err(|e| anyhow!("H2D indexed argmax rows failed: {}", e))?;


// Only the first `greedy_rows.len()` entries of the index scratch are
// meaningful, so that count is passed as the number of rows to process.
argmax_batch_bf16_indexed_into(
ctx,
logits,
row_indices_scratch,
greedy_rows.len(),
top1_value_scratch,
out,
)?;

// NOTE(review): `out` is read back as a whole here, not just its first
// `greedy_rows.len()` entries — confirm that is intended.
let out_host = ctx
.stream
.clone_dtoh(out)
.map_err(|e| anyhow!("D2H indexed batch argmax read failed: {}", e))?;
// Synchronize before `out_host` is read on the host.
ctx.sync()?;

// Results are compacted: `out_host[k]` belongs to the k-th greedy row, so
// scatter each one back to its original batch position.
for (i, row) in greedy_rows.iter().enumerate() {
tokens[*row as usize] = out_host[i] as u32;
}
}

// Per-row sampling for non-greedy rows.
for (i, params_i) in params.iter().enumerate() {
if params_i.is_greedy() {
// Already filled in by the batched argmax above.
continue;
}
let logits_i = pegainfer_kernels::ops::extract_vec(ctx, logits, i)?;
tokens[i] = gpu_sample_into(
ctx,
&logits_i,
probs_scratch,
top1_value_scratch,
row_states_scratch,
valid_scratch,
out,
params_i,
random_vals[i],
)?;
}

Ok(tokens)
}
11 changes: 3 additions & 8 deletions pegainfer-deepseek-v2-lite/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ use std::{
};

use anyhow::{Context, Result};
use pegainfer_engine::{
engine::{EngineHandle, EngineLoadOptions, FinishReason, GenerateRequest, TokenEvent},
sampler::SamplingParams,
use pegainfer_engine::engine::{
EngineHandle, EngineLoadOptions, FinishReason, GenerateRequest, TokenEvent,
};
use tokio::sync::mpsc;

Expand Down Expand Up @@ -42,7 +41,7 @@ fn handle_request(generator: &mut DeepSeekV2LiteEp2Generator, req: &GenerateRequ
logprobs: vec![None; prompt_tokens],
});
}
if !is_greedy(req.params) {
if !req.params.is_greedy() {
reject_request(
req,
prompt_tokens,
Expand Down Expand Up @@ -110,10 +109,6 @@ fn emit_generation_result(
});
}

/// Reports whether `params` describe greedy decoding.
///
/// A request counts as greedy when the choice is forced to the top candidate
/// (non-positive temperature, or `top_k` of exactly 1) and `top_p` is at
/// least 1.0.
fn is_greedy(params: SamplingParams) -> bool {
    let picks_top_candidate = params.temperature <= 0.0 || params.top_k == 1;
    let top_p_at_least_one = params.top_p >= 1.0;
    picks_top_candidate && top_p_at_least_one
}

fn unix_time_secs() -> f64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
Expand Down
1 change: 1 addition & 0 deletions pegainfer-kernels/KERNELS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Use this file as the LLM entrypoint before editing kernels. Start from `op_id`,
| `shared.linear.gemm_per_token` | model-specific decode accuracy gates | `ops::gemm_per_token` / `ops::gemm_per_token_into_checked` | `gemm_per_token_cuda` | `csrc/shared/linear.cu` | cuBLAS | computes each row through the N=1 decode GEMM boundary; used when row-wise parity is required before performance optimization |
| `shared.sampling.argmax_batch_bf16` | batched greedy gates | `ops::argmax_batch_bf16_into` | `argmax_batch_bf16_cuda` | `csrc/shared/argmax.cu` | CUDA | one greedy top-1 result per row over contiguous `HiddenStates` logits |
| `shared.elementwise.accumulate_bf16_token_scaled_to_f32` | DeepSeek-V2-Lite NCCL device combine | `ops::accumulate_bf16_token_scaled_to_f32_into` | `accumulate_bf16_token_scaled_to_f32_cuda` | `csrc/shared/elementwise.cu` | CUDA | accumulates one bf16 expert-output token into a selected row of reusable f32 device scratch before the NCCL combine all-reduce |
| `shared.sampling.argmax_batch_bf16_indexed` | selected batched greedy gates | `ops::argmax_batch_bf16_indexed_into` | `argmax_batch_bf16_indexed_cuda` | `csrc/shared/argmax.cu` | CUDA | compact greedy top-1 results for selected source rows over `HiddenStates` logits |

## Qwen3-4B Dense Full-Attention Path

Expand Down
59 changes: 59 additions & 0 deletions pegainfer-kernels/csrc/shared/argmax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,55 @@ __global__ void argmax_batch_bf16_kernel(
}
}

// Batched bf16 argmax over a selected subset of rows.
//
// Block `blockIdx.x` handles output slot `k = blockIdx.x`: it reads source row
// `row_indices[k]` of `x` (rows are `n` elements long), then writes the winning
// column to `indices[k]` and the winning value (converted back to bf16) to
// `values[k]`. Outputs are therefore compacted in the order of `row_indices`.
//
// Dynamic shared memory layout: `blockDim.x` floats followed by `blockDim.x`
// ints, so the launch must provide `blockDim.x * (sizeof(float) + sizeof(int))`
// bytes. The values stored in `row_indices` are not range-checked here.
// Candidate ordering, including tie-breaking, is decided by `argmax_better`.
__global__ void argmax_batch_bf16_indexed_kernel(
    const __nv_bfloat16* __restrict__ x,
    const int* __restrict__ row_indices,
    __nv_bfloat16* __restrict__ values,
    int* __restrict__ indices,
    int rows,
    int n) {
  extern __shared__ char shared_mem[];
  float* best_vals = reinterpret_cast<float*>(shared_mem);
  int* best_idxs =
      reinterpret_cast<int*>(shared_mem + blockDim.x * sizeof(float));

  const int out_slot = blockIdx.x;
  // The whole block exits together, before any barrier is reached.
  if (out_slot >= rows) return;

  const int lane = threadIdx.x;
  const int src_row = row_indices[out_slot];
  // Widen before multiplying so the element offset is computed in size_t.
  const __nv_bfloat16* src = x + static_cast<size_t>(src_row) * n;

  // Phase 1: each thread scans columns lane, lane + blockDim.x, ...
  float thread_best = -INFINITY;
  int thread_best_idx = 0;
  int col = lane;
  while (col < n) {
    const float candidate = __bfloat162float(src[col]);
    if (argmax_better(candidate, col, thread_best, thread_best_idx)) {
      thread_best = candidate;
      thread_best_idx = col;
    }
    col += blockDim.x;
  }
  best_vals[lane] = thread_best;
  best_idxs[lane] = thread_best_idx;
  __syncthreads();

  // Phase 2: pairwise reduction in shared memory, halving the active span
  // each step. NOTE(review): the halving pattern appears to assume blockDim.x
  // is a power of two — confirm against SAMPLE_BLOCK.
  int span = blockDim.x / 2;
  while (span > 0) {
    if (lane < span) {
      const float other_val = best_vals[lane + span];
      const int other_idx = best_idxs[lane + span];
      if (argmax_better(other_val, other_idx, best_vals[lane], best_idxs[lane])) {
        best_vals[lane] = other_val;
        best_idxs[lane] = other_idx;
      }
    }
    __syncthreads();
    span >>= 1;
  }

  // Thread 0 holds the block-wide winner and publishes it.
  if (lane == 0) {
    indices[out_slot] = best_idxs[0];
    values[out_slot] = __float2bfloat16(best_vals[0]);
  }
}

__global__ void argmax_batch_bf16_partial_kernel(
const __nv_bfloat16* __restrict__ x,
float* __restrict__ partial_values,
Expand Down Expand Up @@ -208,6 +257,16 @@ void argmax_batch_bf16_cuda(const __nv_bfloat16* x, __nv_bfloat16* values,
stream>>>(x, values, indices, rows, n);
}

// Host-side launcher for argmax_batch_bf16_indexed_kernel.
//
// Launches one block of SAMPLE_BLOCK threads per selected row on `stream`.
// Each thread needs one float and one int slot of dynamic shared memory for
// the in-block reduction.
void argmax_batch_bf16_indexed_cuda(const __nv_bfloat16* x,
                                    const int* row_indices,
                                    __nv_bfloat16* values, int* indices,
                                    int rows, int n,
                                    cudaStream_t stream) {
  const size_t reduction_smem_bytes =
      SAMPLE_BLOCK * (sizeof(float) + sizeof(int));
  const dim3 grid(rows);
  const dim3 block(SAMPLE_BLOCK);
  argmax_batch_bf16_indexed_kernel<<<grid, block, reduction_smem_bytes,
                                     stream>>>(x, row_indices, values, indices,
                                               rows, n);
}

void argmax_batch_bf16_split_cuda(const __nv_bfloat16* x, __nv_bfloat16* values,
int* indices, float* partial_values,
int* partial_indices, int rows, int n,
Expand Down
10 changes: 10 additions & 0 deletions pegainfer-kernels/src/ffi/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,16 @@ unsafe extern "C" {
stream: CUstream,
);

/// Launches the indexed batched bf16 argmax kernel
/// (`csrc/shared/argmax.cu`) on `stream`.
///
/// For each `k` in `0..rows`, the kernel reads source row `row_indices[k]`
/// of `x` (rows are `n` bf16 elements long) and writes the winning column to
/// `indices[k]` and the winning value to `values[k]`.
///
/// # Safety
///
/// - `row_indices`, `values` and `indices` must each be readable/writable by
///   the kernel for at least `rows` elements.
/// - `x` must be readable for `n` elements starting at `row_indices[k] * n`
///   for every `k`; the kernel does not range-check the row indices.
/// - The call only launches the kernel on `stream`; the buffers must remain
///   valid until that work has completed.
pub fn argmax_batch_bf16_indexed_cuda(
x: *const Half,
row_indices: *const i32,
values: *mut Half,
indices: *mut i32,
rows: i32,
n: i32,
stream: CUstream,
);

pub fn bf16_to_f32_cuda(
input: *const Half,
output: *mut f32,
Expand Down
4 changes: 2 additions & 2 deletions pegainfer-kernels/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub use norm::{
rms_norm_into, rms_norm_offset_into,
};
pub use sampling::{
BatchSamplingRow, BatchSamplingScratch, argmax, argmax_batch_bf16_into,
argmax_batch_bf16_split_partials_len, flashinfer_top1_batch_into,
BatchSamplingRow, BatchSamplingScratch, argmax, argmax_batch_bf16_indexed_into,
argmax_batch_bf16_into, argmax_batch_bf16_split_partials_len, flashinfer_top1_batch_into,
flashinfer_topk_row_states_bytes, gpu_sample, gpu_sample_batch_into, gpu_sample_into,
};
53 changes: 53 additions & 0 deletions pegainfer-kernels/src/ops/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,59 @@ pub fn argmax_batch_bf16_into(
Ok(())
}

/// Runs batched greedy argmax over a selected subset of `logits` rows.
///
/// The first `rows` entries of `row_indices` name the source rows of `logits`
/// to process. Results are compacted: entry `k` of `out` receives the winning
/// column for source row `row_indices[k]`, and entry `k` of `values` receives
/// the corresponding maximum as bf16. Each row is treated as
/// `logits.hidden_dim` elements long.
///
/// The kernel is only launched on `ctx.stream`; this function does not
/// synchronize, so callers must do so before reading the results on the host.
///
/// The row indices live in device memory and cannot be validated here. The
/// caller must ensure every one of them is a valid row of `logits`.
///
/// # Errors
///
/// Returns an error if `rows` is zero, if `row_indices`, `values` or `out`
/// holds fewer than `rows` elements, or if `rows` or `logits.hidden_dim` does
/// not fit in the `i32` the kernel interface expects.
pub fn argmax_batch_bf16_indexed_into(
    ctx: &DeviceContext,
    logits: &HiddenStates,
    row_indices: &CudaSlice<i32>,
    rows: usize,
    values: &mut CudaSlice<half::bf16>,
    out: &mut CudaSlice<i32>,
) -> Result<()> {
    if rows == 0 {
        return Err(anyhow!("argmax indexed batch requires at least one row"));
    }

    // Every per-row buffer must be able to hold one element per selected row.
    let capacities = [
        ("row scratch", row_indices.len()),
        ("values scratch", values.len()),
        ("output", out.len()),
    ];
    for (label, have) in capacities {
        if have < rows {
            return Err(anyhow!(
                "argmax indexed {} too small: have {}, need {}",
                label,
                have,
                rows
            ));
        }
    }

    // The kernel takes its dimensions as `i32`; reject values that would be
    // silently truncated by a plain `as` cast.
    let rows_i32 = i32::try_from(rows)
        .map_err(|_| anyhow!("argmax indexed row count {} does not fit in i32", rows))?;
    let row_len_i32 = i32::try_from(logits.hidden_dim).map_err(|_| {
        anyhow!(
            "argmax indexed hidden_dim {} does not fit in i32",
            logits.hidden_dim
        )
    })?;

    // The guards returned next to each pointer are bound to named locals so
    // they stay alive until the end of the function, i.e. across the launch.
    let (logits_ptr, _gl) = logits.data.device_ptr(&ctx.stream);
    let (row_indices_ptr, _gr) = row_indices.device_ptr(&ctx.stream);
    let (values_ptr, _gv) = values.device_ptr_mut(&ctx.stream);
    let (out_ptr, _go) = out.device_ptr_mut(&ctx.stream);

    // SAFETY: all four pointers come from live `CudaSlice` borrows whose
    // guards outlive the call. `row_indices`, `values` and `out` were checked
    // above to hold at least `rows` elements, and both dimensions were
    // verified to fit in `i32`. The row index values themselves are device
    // data and are the caller's responsibility (see the function docs).
    unsafe {
        ffi::argmax_batch_bf16_indexed_cuda(
            logits_ptr as *const ffi::Half,
            row_indices_ptr as *const i32,
            values_ptr as *mut ffi::Half,
            out_ptr as *mut i32,
            rows_i32,
            row_len_i32,
            ctx.stream.cu_stream(),
        );
    }

    Ok(())
}

pub fn argmax_batch_bf16_split_partials_len(rows: usize, vocab: usize) -> usize {
const TILE_ELEMS: usize = 4096;
rows * vocab.div_ceil(TILE_ELEMS)
Expand Down
Loading