
Commit 5ccd770

Test more ops for batch invariance

Signed-off-by: Bram Wasti <[email protected]>
1 parent 68024d1 commit 5ccd770

1 file changed: 341 additions, 0 deletions
@@ -0,0 +1,341 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from functools import reduce

import pytest
import torch
import torch.multiprocessing as mp

from tests.utils import multi_gpu_test
from vllm.distributed.parallel_state import (init_distributed_environment,
                                             initialize_model_parallel)
from vllm.model_executor.layers.batch_invariant import init_batch_invariance
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables

def get_open_port():
    """Get an available port for distributed testing."""
    import socket
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]

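# Each spawned worker below pins one GPU, exports the standard
# torch.distributed environment variables (RANK/WORLD_SIZE/MASTER_*), and
# enables the batch-invariant kernels before running the op-specific test fn.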
def run_parallel_op_test_worker(local_rank: int, world_size: int,
                                master_port: int, test_config: dict, fn):
    """Worker function that runs on each GPU process."""
    # Set up the distributed environment
    device = f"cuda:{local_rank}"
    current_platform.set_device(device)
    torch.cuda.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    # Initialize distributed state and tensor parallelism
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Set seed for reproducibility
    current_platform.seed_everything(42)
    init_batch_invariance()

    # Run the specific test function passed in by the parametrized test
    fn(local_rank, world_size, test_config)

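# Commentary: ULPChecker measures how many representable floats lie between
# two values. Raw IEEE-754 bit patterns are sign-magnitude, so integer
# subtraction of the raw bits is not monotonic across zero;
# ulp_distance_int() first remaps the bits into a totally ordered integer
# space (negatives become sign_bit - magnitude, positives are offset by
# sign_bit), after which |a - b| is the ULP distance. A distance of 0 means
# the tensors are bitwise identical, which is the bar these tests set.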
class ULPChecker:
    FP_SPECS = {
        torch.float8_e4m3fn: {
            'mantissa_bits': 3,
            'exponent_bits': 4,
            'total_bits': 8,
            'int_dtype': torch.uint8
        },
        torch.float8_e5m2: {
            'mantissa_bits': 2,
            'exponent_bits': 5,
            'total_bits': 8,
            'int_dtype': torch.uint8
        },
        torch.bfloat16: {
            'mantissa_bits': 7,
            'exponent_bits': 8,
            'total_bits': 16,
            'int_dtype': torch.int16
        },
        torch.float16: {
            'mantissa_bits': 10,
            'exponent_bits': 5,
            'total_bits': 16,
            'int_dtype': torch.int16
        },
        torch.float32: {
            'mantissa_bits': 23,
            'exponent_bits': 8,
            'total_bits': 32,
            'int_dtype': torch.int32
        },
        torch.float64: {
            'mantissa_bits': 52,
            'exponent_bits': 11,
            'total_bits': 64,
            'int_dtype': torch.int64
        },
    }

    @staticmethod
    def to_int_bits(tensor: torch.Tensor) -> torch.Tensor:
        dtype = tensor.dtype
        if dtype not in ULPChecker.FP_SPECS:
            raise ValueError(f"Unsupported dtype: {dtype}")

        spec = ULPChecker.FP_SPECS[dtype]
        int_dtype = spec['int_dtype']

        return tensor.view(int_dtype)

    @staticmethod
    def ulp_distance_int(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        if a.dtype != b.dtype:
            raise ValueError(f"Dtype mismatch: {a.dtype} vs {b.dtype}")

        if a.shape != b.shape:
            raise ValueError(f"Shape mismatch: {a.shape} vs {b.shape}")

        spec = ULPChecker.FP_SPECS[a.dtype]
        total_bits = spec['total_bits']

        a_int = ULPChecker.to_int_bits(a)
        b_int = ULPChecker.to_int_bits(b)

        sign_bit = 1 << (total_bits - 1)

        a_ordered = torch.where(
            (a_int & sign_bit) != 0,
            sign_bit - (a_int & ~sign_bit),  # Negative: flip magnitude bits
            a_int + sign_bit  # Positive: offset by sign bit
        )
        b_ordered = torch.where((b_int & sign_bit) != 0,
                                sign_bit - (b_int & ~sign_bit),
                                b_int + sign_bit)

        ulp_dist = torch.abs(a_ordered - b_ordered)
        return ulp_dist

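# The "needle" trick: plant a fixed, recognizable pattern at one row of an
# otherwise random batch. Running the same op at several batch sizes and
# comparing the needle row's outputs (ULP distance must be 0) checks batch
# invariance: the result for a given row must not depend on what else is in
# the batch or on how large the batch is.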
def create_needle_tensor(
        batch_size: int,
        shape: tuple[int, ...],
        device: torch.device,
        dtype: torch.dtype,
        needle_idx: int = 0) -> torch.Tensor:
    input_tensor = torch.randn(batch_size, *shape, device=device, dtype=dtype)

    numel = reduce(lambda x, y: x * y, shape)
    needle_pattern = torch.sin(
        torch.arange(numel, device=device).float().view(*shape) *
        0.1).to(dtype)

    assert needle_idx < input_tensor.shape[0]
    input_tensor[needle_idx] = needle_pattern

    return input_tensor

def verify_needle_consistency(outputs: list[torch.Tensor],
                              needle_idxs: list[int]) -> bool:
    if len(outputs) < 2:
        return True

    needle_outputs = []
    for output, needle_idx in zip(outputs, needle_idxs):
        needle_outputs.append(output[needle_idx])

    reference = needle_outputs[0]
    for i, needle_output in enumerate(needle_outputs[1:], 1):
        dist_t = ULPChecker.ulp_distance_int(reference, needle_output)
        if torch.max(dist_t) != 0:
            print(f"Needle consistency failed at batch size comparison {i}")
            print(f"Max difference (ULP): {torch.max(dist_t)}")
            print(f"Max difference: {torch.max(reference - needle_output)}")
            return False

    return True

def validate(func, batch_sizes, shape, device, dtype):
    random.seed(123)
    outputs = []
    needle_idxs = []

    for batch_size in batch_sizes:
        needle_idx = random.randint(0, batch_size - 1)
        input_tensor = create_needle_tensor(batch_size, shape, device, dtype,
                                            needle_idx)

        with torch.no_grad():
            output = func(input_tensor)
        assert isinstance(output, torch.Tensor)
        outputs.append(output)
        needle_idxs.append(needle_idx)

    assert verify_needle_consistency(outputs, needle_idxs), \
        "Needle consistency failed"

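# Illustrative example (not part of the test matrix): any elementwise op is
# trivially batch-invariant, so a call like the following would pass as-is:
#   validate(torch.nn.functional.gelu, [1, 8, 32], (128, ),
#            torch.device("cuda:0"), torch.bfloat16)
# The helpers below exercise the interesting cases instead: ops whose
# reduction order could change with batch size (matmuls, norms, MoE).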
def _test_column_parallel_linear(local_rank: int, world_size: int,
                                 config: dict):
    device = torch.device(f"cuda:{local_rank}")
    batch_sizes = [1, 8, 32]
    dtype = config['dtype']
    hidden_size = config['reduction_size']
    seq_len = 4096
    input_size = hidden_size
    output_size = hidden_size * 2
    layer = ColumnParallelLinear(
        input_size=input_size,
        output_size=output_size,
        bias=True,
        gather_output=False,
        params_dtype=dtype,
    )
    layer = layer.to(device)
    validate(lambda x: layer(x)[0], batch_sizes, (seq_len, hidden_size),
             device, dtype)

def _test_row_parallel_linear(local_rank: int, world_size: int, config: dict):
    device = torch.device(f"cuda:{local_rank}")
    batch_sizes = [1, 8, 32]
    dtype = config['dtype']
    hidden_size = config['reduction_size']
    seq_len = 4096
    input_size = hidden_size * 2
    output_size = hidden_size
    layer = RowParallelLinear(
        input_size=input_size,
        output_size=output_size,
        bias=True,
        reduce_results=True,
        params_dtype=dtype,
    )
    layer = layer.to(device)
    validate(lambda x: layer(x)[0], batch_sizes,
             (seq_len, input_size // world_size), device, dtype)

def _test_rms_norm_needle_consistency(local_rank: int, world_size: int,
                                      config: dict):
    """Test RMSNorm with needle consistency."""
    device = torch.device(f"cuda:{local_rank}")
    dtype = config['dtype']
    hidden_size = config['reduction_size']
    batch_sizes = [1, 32, 1024]

    layer = RMSNorm(hidden_size, eps=1e-6)
    layer = layer.to(device).to(dtype)
    validate(layer, batch_sizes, (hidden_size, ), device, dtype)

def _test_fused_rms_norm_needle_consistency(local_rank: int, world_size: int,
                                            config: dict):
    """Test the fused RMSNorm (with residual add) for needle consistency."""
    device = torch.device(f"cuda:{local_rank}")
    dtype = config['dtype']
    hidden_size = config['reduction_size']
    batch_sizes = [1, 32, 1024]

    layer = RMSNorm(hidden_size, eps=1e-6)
    layer = layer.to(device).to(dtype)
    validate(lambda x: layer(x, x)[0], batch_sizes, (hidden_size, ), device,
             dtype)

def _test_fused_moe_needle_consistency(local_rank: int, world_size: int,
                                       config: dict):
    """Test FusedMoE with needle consistency."""
    device = torch.device(f"cuda:{local_rank}")
    dtype = config['dtype']
    hidden_size = config['reduction_size']
    batch_sizes = [1, 8, 32]

    # MoE configuration parameters
    num_experts = 8
    top_k = 2
    intermediate_size = hidden_size * 4

    from vllm.config import VllmConfig
    from vllm.forward_context import get_forward_context, set_forward_context
    from vllm.model_executor.layers.fused_moe import FusedMoE

    vllm_config = VllmConfig()

    # Create a FusedMoE layer similar to how it's used in models
    layer = FusedMoE(
        num_experts=num_experts,
        top_k=top_k,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        params_dtype=dtype,
        reduce_results=True,
        renormalize=True,
        use_grouped_topk=False,
    )
    layer = layer.to(device)

    # Test function that takes hidden states and generates router logits
    def test_func(hidden_states):
        # Generate router logits (these would normally come from a router
        # layer)
        router_logits = torch.randn(hidden_states.shape[0],
                                    hidden_states.shape[1],
                                    num_experts,
                                    device=device,
                                    dtype=dtype)

        # Set the forward context with minimal required parameters;
        # attn_metadata can be None for testing purposes
        with set_forward_context(attn_metadata=None,
                                 vllm_config=vllm_config,
                                 num_tokens=hidden_states.shape[0] *
                                 hidden_states.shape[1]):
            fwdctx = get_forward_context()
            fwdctx.no_compile_layers[''] = layer
            return layer(hidden_states, router_logits)

    validate(test_func, batch_sizes, (hidden_size, ), device, dtype)

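# The reduction sizes below span an aligned case (1024) and awkward ones
# (1, 5, 1025), presumably so tiled or split reductions also hit
# partial-tile code paths rather than only the aligned fast path.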
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("reduction_size", [1, 5, 1024, 1024 + 1])
@pytest.mark.parametrize("func", [
    _test_column_parallel_linear,
    _test_row_parallel_linear,
    _test_rms_norm_needle_consistency,
    _test_fused_rms_norm_needle_consistency,
    _test_fused_moe_needle_consistency,
])
def test_parallel_reduction_batch_invariance(world_size: int,
                                             dtype: torch.dtype,
                                             reduction_size: int, func):
    """Test parallel operators on 2 GPUs."""
    test_config = {
        "dtype": dtype,
        "reduction_size": reduction_size,
    }

    mp.spawn(run_parallel_op_test_worker,
             args=(world_size, get_open_port(), test_config, func),
             nprocs=world_size)
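
# Note: mp.spawn() prepends the process index as the first positional
# argument, so run_parallel_op_test_worker receives local_rank even though
# args only carries (world_size, master_port, test_config, func).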
