Commit c50f2f7

moe working

1 parent 5114e61

6 files changed, +162 −34 lines
csrc/layernorm_kernels.cu

Lines changed: 3 additions & 2 deletions

@@ -391,7 +391,9 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
      When num_tokens is large, a smaller block size allows
      for increased block occupancy on CUs and better latency
      hiding on global mem ops. */
-  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
+  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  const int max_block_size =
+      (num_tokens < 256 && batch_invariant_launch) ? 1024 : 256;
   dim3 block(std::min(hidden_size, max_block_size));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -414,7 +416,6 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
       wt_ptr % req_alignment_bytes == 0;
   bool offsets_are_multiple_of_vector_width =
       hidden_size % vector_width == 0 && input_stride % vector_width == 0;
-  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
   if (ptrs_are_aligned && offsets_are_multiple_of_vector_width &&
       !batch_invariant_launch) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
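
This hunk moves the `batch_invariant_launch` lookup ahead of the block-size choice, so the launch configuration itself (not just the vectorized-load path below) respects the override. For readers unfamiliar with the flag, a minimal Python-side sketch of an environment-driven override like `vllm_kernel_override_batch_invariant()` might look as follows; the exact variable name and lookup here are assumptions, not the committed implementation:

    import os

    def vllm_kernel_override_batch_invariant() -> bool:
        # Hypothetical sketch: treat batch-invariant mode as an env-var toggle.
        # The real flag name and plumbing in vLLM may differ.
        return os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "0") == "1"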

csrc/moe/topk_softmax_kernels.cu

Lines changed: 1 addition & 3 deletions

@@ -21,7 +21,6 @@
 #include <c10/cuda/CUDAGuard.h>
 #include "../cuda_compat.h"
 #include "../cub_helpers.h"
-#include "../core/batch_invariant.hpp"

 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -406,8 +405,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
     using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
     static constexpr int VPT = Constants::VPT;
     static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
-    const bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
-    const int num_warps = batch_invariant_launch ? 32 : (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
+    const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
     const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

     dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
tests/v1/generation/batch_invariance/test_multi_gpu_ops.py

Lines changed: 9 additions & 14 deletions

@@ -135,12 +135,11 @@ def ulp_distance_int(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
     return ulp_dist


-def create_needle_tensor(
-        batch_size: int,
-        shape: list[int],
-        device: torch.device,
-        dtype: torch.dtype,
-        needle_idx: int = 0) -> torch.Tensor:
+def create_needle_tensor(batch_size: int,
+                         shape: list[int],
+                         device: torch.device,
+                         dtype: torch.dtype,
+                         needle_idx: int = 0) -> torch.Tensor:
     input_tensor = torch.randn(batch_size, *shape, device=device, dtype=dtype)

     numel = reduce(lambda x, y: x * y, shape)
@@ -154,8 +153,7 @@ def create_needle_tensor(
     return input_tensor


-def verify(outputs: list[torch.Tensor],
-           needle_idxs: list[int]) -> bool:
+def verify(outputs: list[torch.Tensor], needle_idxs: list[int]) -> bool:
     if len(outputs) < 2:
         return True

@@ -236,8 +234,7 @@ def _test_row_parallel_linear(local_rank: int, world_size: int, config: dict):
         (seq_len, input_size // world_size), device, dtype)


-def _test_rms_norm(local_rank: int, world_size: int,
-                   config: dict):
+def _test_rms_norm(local_rank: int, world_size: int, config: dict):
     """Test RMSNorm with needle consistency."""
     device = torch.device(f"cuda:{local_rank}")
     dtype = config['dtype']
@@ -249,8 +246,7 @@ def _test_rms_norm(local_rank: int, world_size: int,
     validate(layer, batch_sizes, (hidden_size, ), device, dtype)


-def _test_fused_rms_norm(local_rank: int, world_size: int,
-                         config: dict):
+def _test_fused_rms_norm(local_rank: int, world_size: int, config: dict):
     device = torch.device(f"cuda:{local_rank}")
     dtype = config['dtype']
     hidden_size = config['reduction_size']
@@ -262,8 +258,7 @@ def _test_fused_rms_norm(local_rank: int, world_size: int,
         dtype)


-def _test_fused_moe(local_rank: int, world_size: int,
-                    config: dict):
+def _test_fused_moe(local_rank: int, world_size: int, config: dict):
     """Test FusedMoE with needle consistency."""
     device = torch.device(f"cuda:{local_rank}")
     dtype = config['dtype']
tests/v1/generation/test_batch_invariance.py

Lines changed: 21 additions & 15 deletions

@@ -215,30 +215,27 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN():
         model=model_name,
         tensor_parallel_size=tp_size,
         enforce_eager=True,  # helps reduce nondeterminism from some backends
+        max_num_seqs=128,
+        max_num_batched_tokens=8192,
+        gpu_memory_utilization=0.9,
+        enable_prefix_caching=False,
     )

-    prompts = [
-        "The capital of France is",
-        "The capital of Germany is",
-        _random_prompt(10, 1024),
-        _random_prompt(10, 1024),
-        _random_prompt(10, 1024),
-        _random_prompt(10, 1024),
-        _random_prompt(10, 1024),
-    ]
+    prompts = [_random_prompt(10, 1024) for i in range(1000)]

     sp = SamplingParams(
         temperature=0.6,
         top_p=1.0,
-        max_tokens=8,
+        max_tokens=4,
         # Seed shouldn't matter at temperature=0, but keeping it stable anyway.
         seed=1234,
         logprobs=5,
     )

     # BS=1: run prompts individually and collect logprobs per step.
+    N_test = 1000
     bs1_logprobs_per_prompt = []
-    for p in prompts:
+    for p in prompts[:N_test]:
         outs = llm.generate([p], sp, use_tqdm=False)
         assert len(outs) == 1
         step_logprobs = _extract_step_logprobs(outs[0])
@@ -252,14 +249,18 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN():
     outs_batched = llm.generate(prompts, sp, use_tqdm=False)
     assert len(outs_batched) == len(prompts)
     bsN_logprobs_per_prompt = []
-    for o in outs_batched:
+    for o in outs_batched[:N_test]:
         step_logprobs = _extract_step_logprobs(o)
         if step_logprobs is None:
             pytest.skip("Logits are not available on RequestOutput; "
                         "enable logprobs return to run this test.")
         bsN_logprobs_per_prompt.append(step_logprobs)

     # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
+
+    exact_match_count = 0
+    total_count = 0
+    mismatch_stats = []
     for i, (logprobs_bs1, logprobs_bsN) in enumerate(
             zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt)):
         assert len(logprobs_bs1) == len(logprobs_bsN), (
@@ -270,9 +271,14 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN():
                 f"Logits shape mismatch at prompt {i}, step {t}: "
                 f"{a.shape} vs {b.shape}")
             # Bitwise exact equality.
-            assert torch.equal(
-                a, b), (f"Bitwise logprobs mismatch at prompt {i}, step {t} "
-                        f"(dtype={a.dtype}, shape={a.shape}).")
+            if torch.equal(a, b):
+                exact_match_count += 1
+            else:
+                mismatch_stats.append((i, t))
+                print(f"Bitwise logprobs mismatch at prompt {i}, step {t} "
+                      f"(dtype={a.dtype}, shape={a.shape}).")
+            total_count += 1
+    assert exact_match_count == total_count, f"only {exact_match_count} / {total_count} matched exactly {mismatch_stats}"


 def LLM_with_max_seqs(
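
The assertion here is restructured from fail-fast to accumulate-then-assert: every (prompt, step) pair is checked, mismatches are recorded, and one run reports how many of the 1000 prompts diverged instead of stopping at the first failure. The pattern in isolation (tensor pairs are stand-ins for per-step logprobs):

    import torch

    def assert_all_bitwise_equal(pairs):
        # Record every diverging index rather than failing on the first,
        # so a single run shows the full extent of any batch-size dependence.
        mismatches = [i for i, (a, b) in enumerate(pairs) if not torch.equal(a, b)]
        assert not mismatches, (
            f"only {len(pairs) - len(mismatches)} / {len(pairs)} matched exactly "
            f"{mismatches}")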

vllm/model_executor/layers/batch_invariant.py

Lines changed: 121 additions & 0 deletions

@@ -493,6 +493,126 @@ def mean_batch_invariant(input,
     return result


+@triton.jit
+def _rms_norm_kernel(
+    input_ptr,
+    weight_ptr,
+    output_ptr,
+    input_row_stride,
+    output_row_stride,
+    n_cols,
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Compute RMS normalization along the last dimension of a 2D tensor.
+    RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight
+    Each block handles one row of the input tensor.
+    """
+    row_idx = tl.program_id(0).to(tl.int64)
+    row_start_ptr = input_ptr + row_idx * input_row_stride
+    output_row_start_ptr = output_ptr + row_idx * output_row_stride
+
+    # Step 1: Compute sum of squares
+    sum_sq = 0.0
+    for col_offset in range(0, n_cols, BLOCK_SIZE):
+        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
+        mask = col_idx < n_cols
+
+        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
+        sq_vals = vals * vals
+        sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))
+
+    # Step 2: Compute RMS (root mean square)
+    mean_sq = sum_sq / n_cols
+    rms = tl.sqrt(mean_sq + eps)
+    inv_rms = 1.0 / rms
+
+    # Step 3: Normalize and apply weight
+    for col_offset in range(0, n_cols, BLOCK_SIZE):
+        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
+        mask = col_idx < n_cols
+        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
+        weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
+        output = vals * inv_rms * weight
+        tl.store(output_row_start_ptr + col_idx, output, mask=mask)
+
+
+def rms_norm(input: torch.Tensor,
+             weight: torch.Tensor,
+             eps: float = 1e-6) -> torch.Tensor:
+    """
+    Compute RMS normalization using a Triton kernel.
+
+    RMS Norm normalizes the input by the root mean square and scales by weight:
+        output = input / sqrt(mean(input^2) + eps) * weight
+
+    Args:
+        input: Input tensor of shape (..., hidden_size)
+        weight: Weight tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        Tensor with RMS normalization applied along the last dimension
+    """
+    assert input.is_cuda, "Input must be a CUDA tensor"
+    assert weight.is_cuda, "Weight must be a CUDA tensor"
+    assert weight.dim() == 1, "Weight must be 1-dimensional"
+    assert input.shape[-1] == weight.shape[0], (
+        f"Input last dimension ({input.shape[-1]}) must match "
+        f"weight dimension ({weight.shape[0]})")
+
+    # Flatten all dimensions except the last one
+    original_shape = input.shape
+    input_2d = input.reshape(-1, input.shape[-1])
+    input_2d = input_2d.contiguous()
+    weight = weight.contiguous()
+
+    n_rows, n_cols = input_2d.shape
+
+    output = torch.empty_like(input_2d)
+    BLOCK_SIZE = 1024
+    grid = (n_rows, )
+    _rms_norm_kernel[grid](
+        input_2d,
+        weight,
+        output,
+        input_2d.stride(0),
+        output.stride(0),
+        n_cols,
+        eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    return output.reshape(original_shape)
+
+
+def rms_norm_batch_invariant(input: torch.Tensor,
+                             weight: torch.Tensor,
+                             eps: float = 1e-6) -> torch.Tensor:
+    """
+    Batch-invariant wrapper for RMS normalization.
+
+    This function provides a deterministic, batch-invariant implementation
+    of RMS normalization for use with the batch_invariant mode.
+
+    Args:
+        input: Input tensor of shape (..., hidden_size)
+        weight: Weight tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        RMS-normalized tensor
+    """
+    return rms_norm(input, weight, eps=eps)
+
+
+def linear_batch_invariant(input, weight, bias=None):
+    output = torch.mm(input, weight.t())
+    if bias is not None:
+        output = output + bias
+    return output
+
+
 _batch_invariant_MODE = False
 _batch_invariant_LIB = None

@@ -510,6 +630,7 @@ def enable_batch_invariant_mode():
     _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
     _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
     _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
+    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
     _batch_invariant_LIB.impl("aten::_log_softmax",
                               _log_softmax_batch_invariant, "CUDA")
     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")

vllm/model_executor/layers/layernorm.py

Lines changed: 7 additions & 0 deletions

@@ -8,6 +8,8 @@

 import vllm.envs as envs
 from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.batch_invariant import (
+    rms_norm_batch_invariant, vllm_kernel_override_batch_invariant)
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op

@@ -19,6 +21,8 @@ def is_rocm_aiter_rmsnorm_enabled() -> bool:

 def rms_norm(x: torch.Tensor, weight: torch.Tensor,
              variance_epsilon: float) -> torch.Tensor:
+    if vllm_kernel_override_batch_invariant():
+        return rms_norm_batch_invariant(x, weight, variance_epsilon)
     from vllm import _custom_ops as ops
     out = torch.empty_like(x)
     ops.rms_norm(
@@ -33,6 +37,9 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor,
 def fused_add_rms_norm(
     x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
     variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
+    if vllm_kernel_override_batch_invariant():
+        return residual + rms_norm_batch_invariant(x, weight,
+                                                   variance_epsilon), residual
     from vllm import _custom_ops as ops
     ops.fused_add_rms_norm(
         x,
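
With the override enabled, both `rms_norm` and `fused_add_rms_norm` short-circuit into the Triton implementation before the custom CUDA ops are imported. A sketch of exercising the dispatch; the environment-variable name is an assumption, as above:

    import os
    # Assumption: the override is toggled through this env var.
    os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"

    import torch
    from vllm.model_executor.layers.layernorm import rms_norm

    x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
    w = torch.ones(4096, device="cuda", dtype=torch.bfloat16)

    # Dispatches to rms_norm_batch_invariant instead of ops.rms_norm.
    y = rms_norm(x, w, 1e-6)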
