# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import socket
from functools import reduce

import pytest
import torch
import torch.multiprocessing as mp

from tests.utils import multi_gpu_test
from vllm.distributed.parallel_state import (init_distributed_environment,
                                             initialize_model_parallel)
from vllm.model_executor.layers.batch_invariant import init_batch_invariance
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables


def get_open_port():
    """Get an available port for distributed testing."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]


def run_parallel_op_test_worker(local_rank: int, world_size: int,
                                master_port: int, test_config: dict, fn):
    """Worker function that runs on each GPU process."""
    # Set up the distributed environment
    device = f"cuda:{local_rank}"
    current_platform.set_device(device)
    torch.cuda.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    # Initialize the process group and tensor parallelism
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Seed for reproducibility, then enable batch-invariant kernels
    current_platform.seed_everything(42)
    init_batch_invariance()

    # Run the test body passed in from the parent process
    fn(local_rank, world_size, test_config)


class ULPChecker:
    """Compare tensors by their distance in units in the last place
    (ULPs), i.e. how many representable floats apart two values are."""

    FP_SPECS = {
        torch.float8_e4m3fn: {
            'mantissa_bits': 3,
            'exponent_bits': 4,
            'total_bits': 8,
            'int_dtype': torch.uint8
        },
        torch.float8_e5m2: {
            'mantissa_bits': 2,
            'exponent_bits': 5,
            'total_bits': 8,
            'int_dtype': torch.uint8
        },
        torch.bfloat16: {
            'mantissa_bits': 7,
            'exponent_bits': 8,
            'total_bits': 16,
            'int_dtype': torch.int16
        },
        torch.float16: {
            'mantissa_bits': 10,
            'exponent_bits': 5,
            'total_bits': 16,
            'int_dtype': torch.int16
        },
        torch.float32: {
            'mantissa_bits': 23,
            'exponent_bits': 8,
            'total_bits': 32,
            'int_dtype': torch.int32
        },
        torch.float64: {
            'mantissa_bits': 52,
            'exponent_bits': 11,
            'total_bits': 64,
            'int_dtype': torch.int64
        },
    }

    @staticmethod
    def to_int_bits(tensor: torch.Tensor) -> torch.Tensor:
        """Reinterpret the tensor's storage as same-width integers."""
        dtype = tensor.dtype
        if dtype not in ULPChecker.FP_SPECS:
            raise ValueError(f"Unsupported dtype: {dtype}")

        spec = ULPChecker.FP_SPECS[dtype]
        int_dtype = spec['int_dtype']

        return tensor.view(int_dtype)

    @staticmethod
    def ulp_distance_int(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Element-wise ULP distance computed on the raw bit patterns.

        The bit patterns are widened to int64 before the sign-bit
        arithmetic so the scalar offsets cannot overflow the narrow
        storage dtype.
        """
        if a.dtype != b.dtype:
            raise ValueError(f"Dtype mismatch: {a.dtype} vs {b.dtype}")

        if a.shape != b.shape:
            raise ValueError(f"Shape mismatch: {a.shape} vs {b.shape}")

        spec = ULPChecker.FP_SPECS[a.dtype]
        total_bits = spec['total_bits']
        if total_bits >= 64:
            # int64 cannot represent the full 2**64 range of orderings.
            raise ValueError(f"ULP distance unsupported for {a.dtype}: "
                             "the int64 ordering below would overflow")

        # Recover the unsigned bit pattern as int64.
        mask = (1 << total_bits) - 1
        a_int = ULPChecker.to_int_bits(a).to(torch.int64) & mask
        b_int = ULPChecker.to_int_bits(b).to(torch.int64) & mask

        sign_bit = 1 << (total_bits - 1)

        # Map sign-magnitude patterns onto one monotonically ordered
        # integer line: negatives are reflected below sign_bit, positives
        # offset above it, so adjacent floats end up exactly 1 apart.
        a_ordered = torch.where(
            (a_int & sign_bit) != 0,
            sign_bit - (a_int & (sign_bit - 1)),  # Negative: flip magnitude
            a_int + sign_bit)  # Positive: offset by sign bit
        b_ordered = torch.where((b_int & sign_bit) != 0,
                                sign_bit - (b_int & (sign_bit - 1)),
                                b_int + sign_bit)

        return torch.abs(a_ordered - b_ordered)


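# A small CPU-only sanity check added for illustration: it verifies that
# ULPChecker reports 0 for identical tensors and exactly 1 for adjacent
# representable floats produced by torch.nextafter.
def test_ulp_checker_adjacent_floats():
    a = torch.tensor([1.0, -2.5, 0.0], dtype=torch.float32)
    b = torch.nextafter(a, torch.full_like(a, float("inf")))
    assert torch.all(ULPChecker.ulp_distance_int(a, a) == 0)
    assert torch.all(ULPChecker.ulp_distance_int(a, b) == 1)

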
def create_needle_tensor(
        batch_size: int,
        shape: list[int],
        device: torch.device,
        dtype: torch.dtype,
        needle_idx: int = 0) -> torch.Tensor:
    """Build a random batch whose row `needle_idx` holds a fixed,
    deterministic "needle" pattern that is identical for every batch
    size."""
    input_tensor = torch.randn(batch_size, *shape, device=device, dtype=dtype)

    numel = reduce(lambda x, y: x * y, shape)
    needle_pattern = torch.sin(
        torch.arange(numel, device=device).float().view(*shape) *
        0.1).to(dtype)

    assert needle_idx < input_tensor.shape[0]
    input_tensor[needle_idx] = needle_pattern

    return input_tensor


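# CPU-only sanity check added for illustration: the needle row is bitwise
# identical across batch sizes by construction, so any divergence flagged
# by the checks below must come from the operator under test.
def test_needle_row_is_batch_independent():
    device = torch.device("cpu")
    a = create_needle_tensor(4, [8], device, torch.float32, needle_idx=1)
    b = create_needle_tensor(16, [8], device, torch.float32, needle_idx=5)
    assert torch.equal(a[1], b[5])

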
def verify_needle_consistency(outputs: list[torch.Tensor],
                              needle_idxs: list[int]) -> bool:
    """Return True iff the needle row's output is identical (0 ULP)
    across all batch sizes."""
    if len(outputs) < 2:
        return True

    needle_outputs = []
    for output, needle_idx in zip(outputs, needle_idxs):
        needle_outputs.append(output[needle_idx])

    reference = needle_outputs[0]
    for i, needle_output in enumerate(needle_outputs[1:], 1):
        dist_t = ULPChecker.ulp_distance_int(reference, needle_output)
        if torch.max(dist_t) != 0:
            print(f"Needle consistency failed at batch size comparison {i}")
            print(f"Max difference (ULP): {torch.max(dist_t)}")
            print("Max absolute difference: "
                  f"{torch.max(torch.abs(reference - needle_output))}")
            return False

    return True


def validate(func, batch_sizes, shape, device, dtype):
    """Run `func` over several batch sizes and assert that the needle
    row's output is bitwise identical across all of them."""
    random.seed(123)
    outputs = []
    needle_idxs = []

    for batch_size in batch_sizes:
        needle_idx = random.randint(0, batch_size - 1)
        input_tensor = create_needle_tensor(batch_size, shape, device, dtype,
                                            needle_idx)

        with torch.no_grad():
            output = func(input_tensor)
            assert isinstance(output, torch.Tensor)
            outputs.append(output)
            needle_idxs.append(needle_idx)

    assert verify_needle_consistency(outputs, needle_idxs), \
        "Needle consistency failed"


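# Hedged usage sketch of the harness, added for illustration: `validate`
# accepts any callable over batched inputs, and an elementwise op is
# trivially batch-invariant, so this exercises the harness itself on CPU
# without the distributed layers below.
def test_validate_harness_with_elementwise_op():
    validate(lambda x: torch.nn.functional.gelu(x), [1, 4, 16], (8, ),
             torch.device("cpu"), torch.float32)

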
def _test_column_parallel_linear(local_rank: int, world_size: int,
                                 config: dict):
    """Test ColumnParallelLinear with needle consistency."""
    device = torch.device(f"cuda:{local_rank}")
    batch_sizes = [1, 8, 32]
    dtype = config['dtype']
    hidden_size = config['reduction_size']
    seq_len = 4096
    input_size = hidden_size
    output_size = hidden_size * 2
    layer = ColumnParallelLinear(
        input_size=input_size,
        output_size=output_size,
        bias=True,
        gather_output=False,
        params_dtype=dtype,
    )
    layer = layer.to(device)
    validate(lambda x: layer(x)[0], batch_sizes, (seq_len, hidden_size),
             device, dtype)


def _test_row_parallel_linear(local_rank: int, world_size: int, config: dict):
    """Test RowParallelLinear with needle consistency."""
    device = torch.device(f"cuda:{local_rank}")
    batch_sizes = [1, 8, 32]
    dtype = config['dtype']
    hidden_size = config['reduction_size']
    seq_len = 4096
    input_size = hidden_size * 2
    output_size = hidden_size
    layer = RowParallelLinear(
        input_size=input_size,
        output_size=output_size,
        bias=True,
        reduce_results=True,
        params_dtype=dtype,
    )
    layer = layer.to(device)
    # Each rank holds a shard of the input dimension, so feed the
    # per-rank shard width here.
    validate(lambda x: layer(x)[0], batch_sizes,
             (seq_len, input_size // world_size), device, dtype)


def _test_rms_norm_needle_consistency(local_rank: int, world_size: int,
                                      config: dict):
    """Test RMSNorm with needle consistency."""
    device = torch.device(f"cuda:{local_rank}")
    dtype = config['dtype']
    hidden_size = config['reduction_size']
    batch_sizes = [1, 32, 1024]

    layer = RMSNorm(hidden_size, eps=1e-6)
    layer = layer.to(device).to(dtype)
    validate(layer, batch_sizes, (hidden_size, ), device, dtype)


def _test_fused_rms_norm_needle_consistency(local_rank: int, world_size: int,
                                            config: dict):
    """Test the fused residual-add RMSNorm path with needle consistency."""
    device = torch.device(f"cuda:{local_rank}")
    dtype = config['dtype']
    hidden_size = config['reduction_size']
    batch_sizes = [1, 32, 1024]

    layer = RMSNorm(hidden_size, eps=1e-6)
    layer = layer.to(device).to(dtype)
    # Passing a residual selects the fused add + norm path; take the
    # normalized output from the returned (output, residual) tuple.
    validate(lambda x: layer(x, x)[0], batch_sizes, (hidden_size, ), device,
             dtype)


def _test_fused_moe_needle_consistency(local_rank: int, world_size: int,
                                       config: dict):
    """Test FusedMoE with needle consistency."""
    device = torch.device(f"cuda:{local_rank}")
    dtype = config['dtype']
    hidden_size = config['reduction_size']
    batch_sizes = [1, 8, 32]

    # MoE configuration parameters
    num_experts = 8
    top_k = 2
    intermediate_size = hidden_size * 4

    # Imported lazily so the module-level imports stay light
    from vllm.config import VllmConfig
    from vllm.forward_context import get_forward_context, set_forward_context
    from vllm.model_executor.layers.fused_moe import FusedMoE

    vllm_config = VllmConfig()

    # Create a FusedMoE layer similar to how it's used in models
    layer = FusedMoE(
        num_experts=num_experts,
        top_k=top_k,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        params_dtype=dtype,
        reduce_results=True,
        renormalize=True,
        use_grouped_topk=False,
    )
    layer = layer.to(device)

    # Test function that takes hidden states and generates router logits
    def test_func(hidden_states):
        # The input is 2-D (num_tokens, hidden_size), so generate one
        # logit per token per expert (these would normally come from a
        # router layer)
        router_logits = torch.randn(hidden_states.shape[0],
                                    num_experts,
                                    device=device,
                                    dtype=dtype)

        # Set a forward context with the minimal required parameters;
        # attn_metadata can be None for testing purposes
        with set_forward_context(attn_metadata=None,
                                 vllm_config=vllm_config,
                                 num_tokens=hidden_states.shape[0]):
            # Register the layer under its (empty) construction prefix so
            # the FusedMoE op can look itself up in the forward context
            fwdctx = get_forward_context()
            fwdctx.no_compile_layers[''] = layer
            return layer(hidden_states, router_logits)

    validate(test_func, batch_sizes, (hidden_size, ), device, dtype)


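# The reduction sizes below span a degenerate case (1), a small odd size
# (5), a power of two (1024), and a power of two plus one (1025), which
# should exercise both the vectorized main loops and the remainder paths
# of the reduction kernels.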
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("reduction_size", [1, 5, 1024, 1024 + 1])
@pytest.mark.parametrize("func", [
    _test_column_parallel_linear,
    _test_row_parallel_linear,
    _test_rms_norm_needle_consistency,
    _test_fused_rms_norm_needle_consistency,
    _test_fused_moe_needle_consistency,
])
def test_parallel_reduction_batch_invariance(world_size: int,
                                             dtype: torch.dtype,
                                             reduction_size: int, func):
    """Test parallel operators on 2 GPUs."""
    test_config = {
        "dtype": dtype,
        "reduction_size": reduction_size,
    }

    mp.spawn(run_parallel_op_test_worker,
             args=(world_size, get_open_port(), test_config, func),
             nprocs=world_size)