diff --git a/pegainfer-core/src/ops.rs b/pegainfer-core/src/ops.rs index 86b587f8..79780de7 100644 --- a/pegainfer-core/src/ops.rs +++ b/pegainfer-core/src/ops.rs @@ -32,7 +32,8 @@ pub use pegainfer_kernels::ops::{ silu_mul_fused_batch_into, }; pub use sampling::{ - argmax, argmax_batch_bf16_into, flashinfer_topk_row_states_bytes, gpu_sample, gpu_sample_into, + argmax, argmax_batch_bf16_indexed_into, argmax_batch_bf16_into, + flashinfer_topk_row_states_bytes, gpu_sample, gpu_sample_into, select_batch_tokens_into, }; #[cfg(feature = "kernel-call-trace")] pub use traced::{ diff --git a/pegainfer-core/src/ops/sampling.rs b/pegainfer-core/src/ops/sampling.rs index e0e89cd2..31231120 100644 --- a/pegainfer-core/src/ops/sampling.rs +++ b/pegainfer-core/src/ops/sampling.rs @@ -1,11 +1,12 @@ -use anyhow::Result; +use anyhow::{Result, anyhow}; use cudarc::driver::CudaSlice; use crate::sampler::SamplingParams; -use crate::tensor::{DeviceContext, DeviceVec}; +use crate::tensor::{DeviceContext, DeviceVec, HiddenStates}; pub use pegainfer_kernels::ops::{ - argmax, argmax_batch_bf16_into, flashinfer_topk_row_states_bytes, + argmax, argmax_batch_bf16_indexed_into, argmax_batch_bf16_into, + flashinfer_topk_row_states_bytes, }; /// GPU sampling: temperature -> softmax -> top-k -> top-p -> multinomial. @@ -60,3 +61,85 @@ pub fn gpu_sample_into( random_val, ) } + +/// Pick the next token for each row in a decode batch. +/// +/// Greedy rows are selected together with indexed batched argmax. Non-greedy +/// rows still use the existing per-row sampler because each row may have its +/// own random value and sampling parameters. +#[allow(clippy::too_many_arguments)] +pub fn select_batch_tokens_into( + ctx: &DeviceContext, + logits: &HiddenStates, + params: &[&SamplingParams], + random_vals: &[f32], + row_indices_scratch: &mut CudaSlice, + probs_scratch: &mut CudaSlice, + top1_value_scratch: &mut CudaSlice, + row_states_scratch: &mut CudaSlice, + valid_scratch: &mut CudaSlice, + out: &mut CudaSlice, +) -> Result> { + let batch_size = params.len(); + let mut tokens = vec![0; batch_size]; + let greedy_rows = params + .iter() + .enumerate() + .filter_map(|(i, params_i)| params_i.is_greedy().then_some(i as i32)) + .collect::>(); + + if !greedy_rows.is_empty() { + // Batch sampling for greedy rows. + if row_indices_scratch.len() < greedy_rows.len() { + return Err(anyhow!( + "row_indices_scratch too small: have {}, need {}", + row_indices_scratch.len(), + greedy_rows.len() + )); + } + + ctx.stream + .memcpy_htod(&greedy_rows, row_indices_scratch) + .map_err(|e| anyhow!("H2D indexed argmax rows failed: {}", e))?; + + argmax_batch_bf16_indexed_into( + ctx, + logits, + row_indices_scratch, + greedy_rows.len(), + top1_value_scratch, + out, + )?; + + let out_host = ctx + .stream + .clone_dtoh(out) + .map_err(|e| anyhow!("D2H indexed batch argmax read failed: {}", e))?; + ctx.sync()?; + + for (i, row) in greedy_rows.iter().enumerate() { + tokens[*row as usize] = out_host[i] as u32; + } + } + + // Per-row sampling for non-greedy rows. + for (i, params_i) in params.iter().enumerate() { + if params_i.is_greedy() { + continue; + } + let logits_i = pegainfer_kernels::ops::extract_vec(ctx, logits, i)?; + tokens[i] = gpu_sample_into( + ctx, + &logits_i, + probs_scratch, + top1_value_scratch, + row_states_scratch, + valid_scratch, + out, + params_i, + random_vals[i], + )?; + } + + Ok(tokens) +} diff --git a/pegainfer-deepseek-v2-lite/src/engine.rs b/pegainfer-deepseek-v2-lite/src/engine.rs index 4b995984..b0917aa8 100644 --- a/pegainfer-deepseek-v2-lite/src/engine.rs +++ b/pegainfer-deepseek-v2-lite/src/engine.rs @@ -4,9 +4,8 @@ use std::{ }; use anyhow::{Context, Result}; -use pegainfer_engine::{ - engine::{EngineHandle, EngineLoadOptions, FinishReason, GenerateRequest, TokenEvent}, - sampler::SamplingParams, +use pegainfer_engine::engine::{ + EngineHandle, EngineLoadOptions, FinishReason, GenerateRequest, TokenEvent, }; use tokio::sync::mpsc; @@ -42,7 +41,7 @@ fn handle_request(generator: &mut DeepSeekV2LiteEp2Generator, req: &GenerateRequ logprobs: vec![None; prompt_tokens], }); } - if !is_greedy(req.params) { + if !req.params.is_greedy() { reject_request( req, prompt_tokens, @@ -110,10 +109,6 @@ fn emit_generation_result( }); } -fn is_greedy(params: SamplingParams) -> bool { - (params.temperature <= 0.0 || params.top_k == 1) && params.top_p >= 1.0 -} - fn unix_time_secs() -> f64 { SystemTime::now() .duration_since(UNIX_EPOCH) diff --git a/pegainfer-kernels/KERNELS.md b/pegainfer-kernels/KERNELS.md index 40daf917..c93287d4 100644 --- a/pegainfer-kernels/KERNELS.md +++ b/pegainfer-kernels/KERNELS.md @@ -11,6 +11,7 @@ Use this file as the LLM entrypoint before editing kernels. Start from `op_id`, | `shared.linear.gemm_per_token` | model-specific decode accuracy gates | `ops::gemm_per_token` / `ops::gemm_per_token_into_checked` | `gemm_per_token_cuda` | `csrc/shared/linear.cu` | cuBLAS | computes each row through the N=1 decode GEMM boundary; used when row-wise parity is required before performance optimization | | `shared.sampling.argmax_batch_bf16` | batched greedy gates | `ops::argmax_batch_bf16_into` | `argmax_batch_bf16_cuda` | `csrc/shared/argmax.cu` | CUDA | one greedy top-1 result per row over contiguous `HiddenStates` logits | | `shared.elementwise.accumulate_bf16_token_scaled_to_f32` | DeepSeek-V2-Lite NCCL device combine | `ops::accumulate_bf16_token_scaled_to_f32_into` | `accumulate_bf16_token_scaled_to_f32_cuda` | `csrc/shared/elementwise.cu` | CUDA | accumulates one bf16 expert-output token into a selected row of reusable f32 device scratch before the NCCL combine all-reduce | +| `shared.sampling.argmax_batch_bf16_indexed` | selected batched greedy gates | `ops::argmax_batch_bf16_indexed_into` | `argmax_batch_bf16_indexed_cuda` | `csrc/shared/argmax.cu` | CUDA | compact greedy top-1 results for selected source rows over `HiddenStates` logits | ## Qwen3-4B Dense Full-Attention Path diff --git a/pegainfer-kernels/csrc/shared/argmax.cu b/pegainfer-kernels/csrc/shared/argmax.cu index 431e999e..a0cb2be1 100644 --- a/pegainfer-kernels/csrc/shared/argmax.cu +++ b/pegainfer-kernels/csrc/shared/argmax.cu @@ -93,6 +93,55 @@ __global__ void argmax_batch_bf16_kernel( } } +__global__ void argmax_batch_bf16_indexed_kernel( + const __nv_bfloat16* __restrict__ x, + const int* __restrict__ row_indices, + __nv_bfloat16* __restrict__ values, + int* __restrict__ indices, + int rows, + int n) { + extern __shared__ char shared_mem[]; + float* shared_vals = reinterpret_cast(shared_mem); + int* shared_idxs = + reinterpret_cast(shared_mem + blockDim.x * sizeof(float)); + + int row = blockIdx.x; + if (row >= rows) return; + int source_row = row_indices[row]; + const __nv_bfloat16* row_x = x + static_cast(source_row) * n; + int tid = threadIdx.x; + + float local_max = -INFINITY; + int local_idx = 0; + for (int i = tid; i < n; i += blockDim.x) { + float val = __bfloat162float(row_x[i]); + if (argmax_better(val, i, local_max, local_idx)) { + local_max = val; + local_idx = i; + } + } + shared_vals[tid] = local_max; + shared_idxs[tid] = local_idx; + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + float rhs_val = shared_vals[tid + s]; + int rhs_idx = shared_idxs[tid + s]; + if (argmax_better(rhs_val, rhs_idx, shared_vals[tid], shared_idxs[tid])) { + shared_vals[tid] = rhs_val; + shared_idxs[tid] = rhs_idx; + } + } + __syncthreads(); + } + + if (tid == 0) { + indices[row] = shared_idxs[0]; + values[row] = __float2bfloat16(shared_vals[0]); + } +} + __global__ void argmax_batch_bf16_partial_kernel( const __nv_bfloat16* __restrict__ x, float* __restrict__ partial_values, @@ -208,6 +257,16 @@ void argmax_batch_bf16_cuda(const __nv_bfloat16* x, __nv_bfloat16* values, stream>>>(x, values, indices, rows, n); } +void argmax_batch_bf16_indexed_cuda(const __nv_bfloat16* x, + const int* row_indices, + __nv_bfloat16* values, int* indices, + int rows, int n, + cudaStream_t stream) { + argmax_batch_bf16_indexed_kernel<<>>(x, row_indices, values, indices, rows, n); +} + void argmax_batch_bf16_split_cuda(const __nv_bfloat16* x, __nv_bfloat16* values, int* indices, float* partial_values, int* partial_indices, int rows, int n, diff --git a/pegainfer-kernels/src/ffi/shared.rs b/pegainfer-kernels/src/ffi/shared.rs index 702f6256..1c7e7dc7 100644 --- a/pegainfer-kernels/src/ffi/shared.rs +++ b/pegainfer-kernels/src/ffi/shared.rs @@ -411,6 +411,16 @@ unsafe extern "C" { stream: CUstream, ); + pub fn argmax_batch_bf16_indexed_cuda( + x: *const Half, + row_indices: *const i32, + values: *mut Half, + indices: *mut i32, + rows: i32, + n: i32, + stream: CUstream, + ); + pub fn bf16_to_f32_cuda( input: *const Half, output: *mut f32, diff --git a/pegainfer-kernels/src/ops.rs b/pegainfer-kernels/src/ops.rs index da07b498..ffbc6f2c 100644 --- a/pegainfer-kernels/src/ops.rs +++ b/pegainfer-kernels/src/ops.rs @@ -46,7 +46,7 @@ pub use norm::{ rms_norm_into, rms_norm_offset_into, }; pub use sampling::{ - BatchSamplingRow, BatchSamplingScratch, argmax, argmax_batch_bf16_into, - argmax_batch_bf16_split_partials_len, flashinfer_top1_batch_into, + BatchSamplingRow, BatchSamplingScratch, argmax, argmax_batch_bf16_indexed_into, + argmax_batch_bf16_into, argmax_batch_bf16_split_partials_len, flashinfer_top1_batch_into, flashinfer_topk_row_states_bytes, gpu_sample, gpu_sample_batch_into, gpu_sample_into, }; diff --git a/pegainfer-kernels/src/ops/sampling.rs b/pegainfer-kernels/src/ops/sampling.rs index 23de7aee..095a4a6b 100644 --- a/pegainfer-kernels/src/ops/sampling.rs +++ b/pegainfer-kernels/src/ops/sampling.rs @@ -273,6 +273,59 @@ pub fn argmax_batch_bf16_into( Ok(()) } +pub fn argmax_batch_bf16_indexed_into( + ctx: &DeviceContext, + logits: &HiddenStates, + row_indices: &CudaSlice, + rows: usize, + values: &mut CudaSlice, + out: &mut CudaSlice, +) -> Result<()> { + if rows == 0 { + return Err(anyhow!("argmax indexed batch requires at least one row")); + } + if row_indices.len() < rows { + return Err(anyhow!( + "argmax indexed row scratch too small: have {}, need {}", + row_indices.len(), + rows + )); + } + if values.len() < rows { + return Err(anyhow!( + "argmax indexed values scratch too small: have {}, need {}", + values.len(), + rows + )); + } + if out.len() < rows { + return Err(anyhow!( + "argmax indexed output too small: have {}, need {}", + out.len(), + rows + )); + } + + let (logits_ptr, _gl) = logits.data.device_ptr(&ctx.stream); + let (row_indices_ptr, _gr) = row_indices.device_ptr(&ctx.stream); + let (values_ptr, _gv) = values.device_ptr_mut(&ctx.stream); + let (out_ptr, _go) = out.device_ptr_mut(&ctx.stream); + + unsafe { + ffi::argmax_batch_bf16_indexed_cuda( + logits_ptr as *const ffi::Half, + row_indices_ptr as *const i32, + values_ptr as *mut ffi::Half, + out_ptr as *mut i32, + rows as i32, + logits.hidden_dim as i32, + ctx.stream.cu_stream(), + ); + } + + Ok(()) +} + pub fn argmax_batch_bf16_split_partials_len(rows: usize, vocab: usize) -> usize { const TILE_ELEMS: usize = 4096; rows * vocab.div_ceil(TILE_ELEMS) diff --git a/pegainfer-qwen3-4b/src/executor.rs b/pegainfer-qwen3-4b/src/executor.rs index 10358a44..25e49de9 100644 --- a/pegainfer-qwen3-4b/src/executor.rs +++ b/pegainfer-qwen3-4b/src/executor.rs @@ -183,6 +183,43 @@ fn build_decode_request_results( Ok(outputs) } +fn build_batch_decode_request_results( + lane: &mut LocalQwen3Lane, + requests: &[DecodeStepItem], +) -> Result> { + let params: Vec<&SamplingParams> = requests.iter().map(|req| &req.params).collect(); + let random_vals: Vec = requests.iter().map(|req| req.random_val).collect(); + let tokens = pegainfer_core::ops::select_batch_tokens_into( + lane.model.device_ctx(), + &lane.bufs.logits, + ¶ms, + &random_vals, + &mut lane.sample_scratch.row_indices, + &mut lane.sample_scratch.probs, + &mut lane.sample_scratch.top1_values, + &mut lane.sample_scratch.row_states, + &mut lane.sample_scratch.valid, + &mut lane.sample_scratch.out, + )?; + + let mut outputs = Vec::with_capacity(requests.len()); + for (i, req) in requests.iter().enumerate() { + let token = tokens[i]; + let logprob = if req.logprobs > 0 { + let logits_i = ops::extract_vec(lane.model.device_ctx(), &lane.bufs.logits, i)?; + Some(lane.extract_logprobs(&logits_i, token, req.logprobs)?) + } else { + None + }; + outputs.push(DecodeRequestResult { + request_id: req.request_id, + token, + logprob, + }); + } + Ok(outputs) +} + fn execute_step_on_lane( lane: &mut LocalQwen3Lane, step: &StepCommand, @@ -223,11 +260,8 @@ fn execute_step_on_lane( .collect(); lane.execute_decode(&token_ids, kv_views, &lora_adapters)?; if collect_result { - let logits: Vec = (0..requests.len()) - .map(|i| ops::extract_vec(lane.model.device_ctx(), &lane.bufs.logits, i)) - .collect::>>()?; Ok(WorkerStepOutcome::Decode(DecodeResult { - requests: build_decode_request_results(lane, requests, &logits)?, + requests: build_batch_decode_request_results(lane, requests)?, })) } else { Ok(WorkerStepOutcome::Ack) @@ -293,23 +327,25 @@ impl Drop for CublasThreadGuard { } struct SamplingScratch { + row_indices: cudarc::driver::CudaSlice, probs: cudarc::driver::CudaSlice, - top1_value: cudarc::driver::CudaSlice, + top1_values: cudarc::driver::CudaSlice, row_states: cudarc::driver::CudaSlice, valid: cudarc::driver::CudaSlice, out: cudarc::driver::CudaSlice, } impl SamplingScratch { - fn new(ctx: &DeviceContext, vocab_size: usize) -> Result { + fn new(ctx: &DeviceContext, vocab_size: usize, max_batch_bucket: usize) -> Result { Ok(Self { + row_indices: ctx.stream.alloc_zeros(max_batch_bucket)?, probs: ctx.stream.alloc_zeros(vocab_size)?, - top1_value: ctx.stream.alloc_zeros(1)?, + top1_values: ctx.stream.alloc_zeros(max_batch_bucket)?, row_states: ctx .stream .alloc_zeros(pegainfer_core::ops::flashinfer_topk_row_states_bytes())?, valid: ctx.stream.alloc_zeros(1)?, - out: ctx.stream.alloc_zeros(1)?, + out: ctx.stream.alloc_zeros(max_batch_bucket)?, }) } } @@ -1216,7 +1252,8 @@ impl LocalQwen3Lane { padding_block_id, model.local_num_attention_heads(), )?; - let sample_scratch = SamplingScratch::new(model.device_ctx(), model.config().vocab_size)?; + let sample_scratch = + SamplingScratch::new(model.device_ctx(), model.config().vocab_size, max_bucket)?; Ok(Self { model, kv_buffer, @@ -1241,7 +1278,7 @@ impl LocalQwen3Lane { self.model.device_ctx(), logits, &mut self.sample_scratch.probs, - &mut self.sample_scratch.top1_value, + &mut self.sample_scratch.top1_values, &mut self.sample_scratch.row_states, &mut self.sample_scratch.valid, &mut self.sample_scratch.out, diff --git a/pegainfer-qwen35-4b/src/batch_decode.rs b/pegainfer-qwen35-4b/src/batch_decode.rs index b809c14b..689362d9 100644 --- a/pegainfer-qwen35-4b/src/batch_decode.rs +++ b/pegainfer-qwen35-4b/src/batch_decode.rs @@ -18,26 +18,19 @@ impl Qwen35Model { params: &[&pegainfer_core::sampler::SamplingParams], rng: &mut rand::rngs::StdRng, ) -> Result> { - let batch_size = params.len(); - - let mut tokens = Vec::with_capacity(batch_size); - for (i, params_i) in params.iter().enumerate().take(batch_size) { - let logits_i = ops::extract_vec(&self.ctx, &bufs.logits, i)?; - let random_val: f32 = rand::RngExt::random(rng); - let token = ops::gpu_sample_into( - &self.ctx, - &logits_i, - &mut bufs.sample_probs, - &mut bufs.sample_top1_value, - &mut bufs.sample_row_states, - &mut bufs.sample_valid, - &mut bufs.sample_out, - params_i, - random_val, - )?; - tokens.push(token); - } - Ok(tokens) + let random_vals: Vec = params.iter().map(|_| rand::RngExt::random(rng)).collect(); + ops::select_batch_tokens_into( + &self.ctx, + &bufs.logits, + params, + &random_vals, + &mut bufs.sample_row_indices, + &mut bufs.sample_probs, + &mut bufs.sample_top1_value, + &mut bufs.sample_row_states, + &mut bufs.sample_valid, + &mut bufs.sample_out, + ) } fn batch_decode_full_attention( diff --git a/pegainfer-qwen35-4b/src/decode_buffers.rs b/pegainfer-qwen35-4b/src/decode_buffers.rs index bba21309..2aa6fd2c 100644 --- a/pegainfer-qwen35-4b/src/decode_buffers.rs +++ b/pegainfer-qwen35-4b/src/decode_buffers.rs @@ -56,6 +56,7 @@ pub(crate) struct BatchDecodeBuffers35 { pub(crate) kv_chunk_size_d: CudaSlice, // Sampling scratch + pub(crate) sample_row_indices: CudaSlice, pub(crate) sample_probs: CudaSlice, pub(crate) sample_top1_value: CudaSlice, pub(crate) sample_row_states: CudaSlice, @@ -126,8 +127,9 @@ impl BatchDecodeBuffers35 { kv_tile_indices_d: ctx.stream.alloc_zeros(bs)?, kv_chunk_size_d: ctx.stream.alloc_zeros(bs)?, + sample_row_indices: ctx.stream.alloc_zeros(bs)?, sample_probs: ctx.stream.alloc_zeros(config.vocab_size)?, - sample_top1_value: ctx.stream.alloc_zeros(1)?, + sample_top1_value: ctx.stream.alloc_zeros(bs)?, sample_row_states: ctx .stream .alloc_zeros(crate::ops::flashinfer_topk_row_states_bytes())?, diff --git a/pegainfer-qwen35-4b/src/ops.rs b/pegainfer-qwen35-4b/src/ops.rs index 2a100192..dcc74c3c 100644 --- a/pegainfer-qwen35-4b/src/ops.rs +++ b/pegainfer-qwen35-4b/src/ops.rs @@ -5,7 +5,8 @@ pub(crate) use pegainfer_core::ops::{ add_batch, add_batch_into, embedding_batch, extract_vec, extract_vec_into, flashinfer_topk_row_states_bytes, gemm, gemm_into, gpu_sample_into, linear, paged_attention_batch_decode_hd256_into, qk_norm_partial_rope_batched_decode_hd256_into, - rms_norm_gated_batch_into, silu_mul_batch, silu_mul_batch_into, write_vec_into, + rms_norm_gated_batch_into, select_batch_tokens_into, silu_mul_batch, silu_mul_batch_into, + write_vec_into, }; pub use pegainfer_core::ops::{rms_norm_batch_offset_into, rms_norm_offset_into}; pub use recurrent::gated_delta_rule_prefill_chunkwise_into;