5 changes: 3 additions & 2 deletions csrc/layernorm_kernels.cu
@@ -391,7 +391,9 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
const int max_block_size =
      (num_tokens < 256 && !batch_invariant_launch) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -414,7 +416,6 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
wt_ptr % req_alignment_bytes == 0;
bool offsets_are_multiple_of_vector_width =
hidden_size % vector_width == 0 && input_stride % vector_width == 0;
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width &&
!batch_invariant_launch) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
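The reason for pinning max_block_size when batch-invariant mode is on can be seen outside CUDA as well: splitting the same float32 reduction into differently sized chunks changes the accumulation order, and usually the bitwise result. A minimal PyTorch sketch (illustrative only, not part of the PR; chunked_sum is a made-up helper):

import torch

torch.manual_seed(0)
x = torch.randn(4096, dtype=torch.float32)

def chunked_sum(t: torch.Tensor, chunk: int) -> torch.Tensor:
    # Sum each chunk, then combine the partials left to right, mimicking a
    # per-block reduction followed by a cross-block combine on the GPU.
    total = torch.zeros((), dtype=t.dtype)
    for i in range(0, t.numel(), chunk):
        total = total + t[i:i + chunk].sum()
    return total

# Different "block sizes" -> different accumulation order -> results that are
# generally not bit-identical. Fixing the launch configuration removes this
# batch-size-dependent source of variation.
print(torch.equal(chunked_sum(x, 256), chunked_sum(x, 1024)))  # often False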
4 changes: 1 addition & 3 deletions csrc/moe/topk_softmax_kernels.cu
@@ -21,7 +21,6 @@
#include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h"
#include "../cub_helpers.h"
#include "../core/batch_invariant.hpp"

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -406,8 +405,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
const bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
const int num_warps = batch_invariant_launch ? 32 : (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
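For reference, the launch math this revert restores is plain ceiling division over rows and warps; a small standalone sketch with illustrative values (the helper name is made up):

def launch_dims(num_rows: int, rows_per_warp: int, warps_per_tb: int) -> tuple[int, int]:
    # Matches the (x + y - 1) // y pattern used in the kernel launcher above.
    num_warps = (num_rows + rows_per_warp - 1) // rows_per_warp
    num_blocks = (num_warps + warps_per_tb - 1) // warps_per_tb
    return num_warps, num_blocks

print(launch_dims(num_rows=1000, rows_per_warp=4, warps_per_tb=8))  # (250, 32)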
84 changes: 51 additions & 33 deletions tests/v1/generation/test_batch_invariance.py
@@ -76,18 +76,21 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
seed.
- Keep max_tokens and max_model_len bounded for speed and memory use.
"""
random.seed(12345)
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)

# Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64"))
assert batch_size >= 2, "Batch size should be >= 2 to mix needle."
max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
max_random_prompt = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
assert max_batch_size >= 2, "Batch size should be >= 2 to mix needle."

# Keep GPU memory usage low to avoid startup allocation failures.
gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3"))
max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096"))
gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4"))
max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))
swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))

# Sampling parameters: longer outputs with a more random-sounding
@@ -111,7 +114,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
# Engine with bs=1 behavior
llm_bs1 = LLM_with_max_seqs(
model=model,
max_num_seqs=1,
max_num_seqs=128,
gpu_memory_utilization=gpu_mem_util,
max_model_len=max_model_len,
swap_space=swap_space_gb,
@@ -126,7 +129,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
# Engine with larger batch limit (e.g., 64)
llm_bsN = LLM_with_max_seqs(
model=model,
max_num_seqs=batch_size,
max_num_seqs=128,
gpu_memory_utilization=gpu_mem_util,
max_model_len=max_model_len,
swap_space=swap_space_gb,
@@ -135,15 +138,17 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
mismatches = 0

for trial in range(num_trials):
# Create a batch of size `batch_size` and insert the needle at
# Create a batch of size `max_batch_size` and insert the needle at
# a random index
prompts: list[str] = []
batch_size = random.randint(max_batch_size // 2, max_batch_size)
needle_pos = random.randint(0, batch_size - 1)
for i in range(batch_size):
if i == needle_pos:
prompts.append(needle_prompt)
else:
prompts.append(_random_prompt())
prompts.append(
_random_prompt(min_random_prompt, max_random_prompt))

# Generate with the larger-batch engine
outputs = llm_bsN.generate(prompts, sampling)
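_random_prompt(min_random_prompt, max_random_prompt) is defined elsewhere in this test module and is only called in this hunk. A hypothetical sketch of such a helper (the word list and length logic are assumptions, not the test's actual implementation):

import random

def _random_prompt(min_tokens: int = 1024, max_tokens: int = 2048) -> str:
    # Draw a target length uniformly from the range so each trial's batch
    # mixes prompts of varying size around the needle prompt.
    n = random.randint(min_tokens, max_tokens)
    filler = ["alpha", "bravo", "charlie", "delta", "echo", "foxtrot"]
    return "Continue this story: " + " ".join(random.choice(filler) for _ in range(n))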
@@ -154,17 +159,19 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
text = needle_output.outputs[0].text

if text != baseline_text:
print(
f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
mismatches += 1

passes = num_trials - mismatches
# Dump how many passed vs failed
print(f"[determinism] total={num_trials}, passed={passes}, "
f"failed={mismatches}, batch_size={batch_size}")
f"failed={mismatches}, max_batch_size={max_batch_size}")

if mismatches > 0:
pytest.fail(
f"Nondeterministic outputs detected: {mismatches} failed out "
f"of {num_trials} trials (batch_size={batch_size}).")
f"of {num_trials} trials (max_batch_size={max_batch_size}).")

finally:
# Ensure engines are shutdown to free GPU/VRAM across test sessions
@@ -196,9 +203,10 @@ def _extract_step_logprobs(request_output):
not torch.cuda.is_available(),
reason="Requires CUDA to match production inference path.",
)
def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN():

#model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m")
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

@@ -207,25 +215,27 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
model=model_name,
tensor_parallel_size=tp_size,
enforce_eager=True, # helps reduce nondeterminism from some backends
max_num_seqs=128,
max_num_batched_tokens=8192,
gpu_memory_utilization=0.9,
enable_prefix_caching=False,
)

prompts = [
"The capital of France is",
"The capital of Germany is",
]
prompts = [_random_prompt(10, 1024) for i in range(1000)]

sp = SamplingParams(
temperature=0.0,
temperature=0.6,
top_p=1.0,
max_tokens=8,
max_tokens=4,
# Seed shouldn't matter at temperature=0, but keeping it stable anyway.
seed=1234,
logprobs=5,
)

# BS=1: run prompts individually and collect logprobs per step.
N_test = 1000
bs1_logprobs_per_prompt = []
for p in prompts:
for p in prompts[:N_test]:
outs = llm.generate([p], sp, use_tqdm=False)
assert len(outs) == 1
step_logprobs = _extract_step_logprobs(outs[0])
@@ -234,32 +244,40 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
"enable logprobs return to run this test.")
bs1_logprobs_per_prompt.append(step_logprobs)

# BS=2: run prompts in a batch and collect logprobs per step for each
# BS=N: run prompts in a batch and collect logprobs per step for each
# prompt.
outs_batched = llm.generate(prompts, sp, use_tqdm=False)
assert len(outs_batched) == len(prompts)
bs2_logprobs_per_prompt = []
for o in outs_batched:
bsN_logprobs_per_prompt = []
for o in outs_batched[:N_test]:
step_logprobs = _extract_step_logprobs(o)
if step_logprobs is None:
pytest.skip("Logits are not available on RequestOutput; "
"enable logprobs return to run this test.")
bs2_logprobs_per_prompt.append(step_logprobs)
bsN_logprobs_per_prompt.append(step_logprobs)

# Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.

# Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs.
for i, (logprobs_bs1, logprobs_bs2) in enumerate(
zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)):
assert len(logprobs_bs1) == len(logprobs_bs2), (
exact_match_count = 0
total_count = 0
for i, (logprobs_bs1, logprobs_bsN) in enumerate(
zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt)):
assert len(logprobs_bs1) == len(logprobs_bsN), (
f"Different number of generation steps for prompt index {i}: "
f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)")
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)):
f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)")
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
assert a.shape == b.shape, (
f"Logits shape mismatch at prompt {i}, step {t}: "
f"{a.shape} vs {b.shape}")
# Bitwise exact equality.
assert torch.equal(
a, b), (f"Bitwise logprobs mismatch at prompt {i}, step {t} "
f"(dtype={a.dtype}, shape={a.shape}).")
if torch.equal(a, b):
exact_match_count += 1
else:
print(f"Bitwise logprobs mismatch at prompt {i}, step {t} "
f"(dtype={a.dtype}, shape={a.shape}).")
total_count += 1
assert exact_match_count == total_count, \
f"only {exact_match_count} / {total_count} matched exactly"


def LLM_with_max_seqs(
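The comparison in the test above deliberately uses torch.equal, which requires bit-exact equality, rather than torch.allclose, which tolerates rounding noise; batch invariance is defined as bitwise agreement. A small illustration of the difference (float64 values chosen so the rounding is visible):

import torch

a = torch.tensor([0.1 + 0.2], dtype=torch.float64)
b = torch.tensor([0.3], dtype=torch.float64)

print(torch.allclose(a, b))  # True: within default tolerances
print(torch.equal(a, b))     # False: 0.1 + 0.2 and 0.3 are not the same double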
126 changes: 125 additions & 1 deletion vllm/model_executor/layers/batch_invariant.py
@@ -8,6 +8,7 @@

import torch

import vllm.envs as envs
from vllm.triton_utils import tl, triton


@@ -492,6 +493,126 @@ def mean_batch_invariant(input,
return result


@triton.jit
def _rms_norm_kernel(
input_ptr,
weight_ptr,
output_ptr,
input_row_stride,
output_row_stride,
n_cols,
eps,
BLOCK_SIZE: tl.constexpr,
):
"""
Compute RMS normalization along the last dimension of a 2D tensor.
RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight
Each block handles one row of the input tensor.
"""
row_idx = tl.program_id(0).to(tl.int64)
row_start_ptr = input_ptr + row_idx * input_row_stride
output_row_start_ptr = output_ptr + row_idx * output_row_stride

# Step 1: Compute sum of squares
sum_sq = 0.0
for col_offset in range(0, n_cols, BLOCK_SIZE):
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
mask = col_idx < n_cols

vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
sq_vals = vals * vals
sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))

# Step 2: Compute RMS (root mean square)
mean_sq = sum_sq / n_cols
rms = tl.sqrt(mean_sq + eps)
inv_rms = 1.0 / rms

# Step 3: Normalize and apply weight
for col_offset in range(0, n_cols, BLOCK_SIZE):
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
mask = col_idx < n_cols
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
output = vals * inv_rms * weight
tl.store(output_row_start_ptr + col_idx, output, mask=mask)


def rms_norm(input: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6) -> torch.Tensor:
"""
Compute RMS normalization using Triton kernel.

RMS Norm normalizes the input by the root mean square and scales by weight:
output = input / sqrt(mean(input^2) + eps) * weight

Args:
input: Input tensor of shape (..., hidden_size)
weight: Weight tensor of shape (hidden_size,)
eps: Small constant for numerical stability

Returns:
Tensor with RMS normalization applied along the last dimension
"""
assert input.is_cuda, "Input must be a CUDA tensor"
assert weight.is_cuda, "Weight must be a CUDA tensor"
assert weight.dim() == 1, "Weight must be 1-dimensional"
assert input.shape[-1] == weight.shape[0], (
f"Input last dimension ({input.shape[-1]}) must match "
f"weight dimension ({weight.shape[0]})")

# Flatten all dimensions except the last one
original_shape = input.shape
input_2d = input.reshape(-1, input.shape[-1])
input_2d = input_2d.contiguous()
weight = weight.contiguous()

n_rows, n_cols = input_2d.shape

output = torch.empty_like(input_2d)
BLOCK_SIZE = 1024
grid = (n_rows, )
_rms_norm_kernel[grid](
input_2d,
weight,
output,
input_2d.stride(0),
output.stride(0),
n_cols,
eps,
BLOCK_SIZE=BLOCK_SIZE,
)
return output.reshape(original_shape)


def rms_norm_batch_invariant(input: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6) -> torch.Tensor:
"""
Batch-invariant wrapper for RMS normalization.

This function provides a deterministic, batch-invariant implementation
of RMS normalization for use with the batch_invariant mode.

Args:
input: Input tensor of shape (..., hidden_size)
weight: Weight tensor of shape (hidden_size,)
eps: Small constant for numerical stability

Returns:
RMS normalized tensor
"""
return rms_norm(input, weight, eps=eps)


def linear_batch_invariant(input, weight, bias=None):
output = torch.mm(input, weight.t())
if bias is not None:
output = output + bias
return output


_batch_invariant_MODE = False
_batch_invariant_LIB = None
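
A usage sketch for the Triton RMS norm added above, checked against the reference formula from its docstring (shapes, device, and tolerances are illustrative; a CUDA device is required):

import torch
from vllm.model_executor.layers.batch_invariant import rms_norm_batch_invariant

x = torch.randn(4, 2048, dtype=torch.float32, device="cuda")
w = torch.randn(2048, dtype=torch.float32, device="cuda")
eps = 1e-6

out = rms_norm_batch_invariant(x, w, eps=eps)

# Reference: y = x / sqrt(mean(x^2) + eps) * weight
ref = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * w

# The Triton kernel accumulates in its own fixed order, so compare with a
# tolerance here; bitwise equality is expected across batch sizes for the
# same kernel, not against a different reference implementation.
print(torch.allclose(out, ref, atol=1e-5, rtol=1e-5))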

@@ -509,6 +630,7 @@ def enable_batch_invariant_mode():
_batch_invariant_LIB = torch.library.Library("aten", "IMPL")
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::_log_softmax",
_log_softmax_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
@@ -557,5 +679,7 @@ def vllm_kernel_override_batch_invariant():
def init_batch_invariance():
# this will hit all the csrc overrides as well
if vllm_kernel_override_batch_invariant():
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
if curr_attn_backend not in ["FLEX_ATTENTION", "FLASHINFER"]:
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
enable_batch_invariant_mode()
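
Wiring it together, as a sketch: the environment override checked by vllm_kernel_override_batch_invariant() turns on the aten-level overrides (including the new aten::linear one) and steers the attention backend unless FLASHINFER is already selected. The environment-variable spelling below is inferred from the function name and should be treated as an assumption:

import os

# Assumed flag name read by vllm_kernel_override_batch_invariant().
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"

from vllm.model_executor.layers.batch_invariant import init_batch_invariance

# Picks a compatible attention backend (FLEX_ATTENTION unless FLASHINFER is
# already set) and registers the batch-invariant CUDA overrides, so ops like
# torch.mm and torch.nn.functional.linear route through the Triton paths.
init_batch_invariance()

print(os.environ.get("VLLM_ATTENTION_BACKEND"))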