diff --git a/tests/conftest.py b/tests/conftest.py index deee177a1..f6089a011 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,6 @@ +import sys +sys.path.insert(0, '/tmp') +import patch_infer_schema # SPDX-License-Identifier: Apache-2.0 # 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. # SPDX-FileCopyrightText: Copyright contributors to the vLLM project @@ -40,7 +43,10 @@ from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset -from vllm.config import ConvertOption, RunnerOption, _get_and_verify_dtype +try: + from vllm.config import ConvertOption, RunnerOption, _get_and_verify_dtype +except ImportError: + from vllm.config.model import ConvertOption, RunnerOption, _get_and_verify_dtype from vllm.connections import global_http_connection from vllm.distributed import (cleanup_dist_env_and_memory, init_distributed_environment, @@ -51,7 +57,7 @@ from vllm.multimodal.utils import fetch_image from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams -from vllm.sequence import Logprob +from vllm.logprobs import Logprob from vllm.transformers_utils.utils import maybe_model_redirect logger = init_logger(__name__) @@ -1386,3 +1392,15 @@ def image_urls(request, local_asset_server) -> list[str]: """Indirect fixture: takes a list of names, returns list of full URLs.""" names: list[str] = request.param return [local_asset_server.url_for(name) for name in names] + + +@pytest.fixture +def default_vllm_config(): + """Set a default VllmConfig for tests that directly test CustomOps or pathways + that use get_current_vllm_config() outside of a full engine context. 
+ """ + from vllm.config import VllmConfig, set_current_vllm_config + from vllm.config.device import DeviceConfig + + with set_current_vllm_config(VllmConfig(device_config=DeviceConfig("cuda"))): + yield diff --git a/tests/kernels/core/test_apply_rotary_emb.py b/tests/kernels/core/test_apply_rotary_emb.py new file mode 100644 index 000000000..23c722fa5 --- /dev/null +++ b/tests/kernels/core/test_apply_rotary_emb.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for ApplyRotaryEmb CustomOp dispatch behavior. + +This test ensures that RotaryEmbedding classes correctly call the appropriate +ApplyRotaryEmb methods based on the calling context: + +1. RotaryEmbedding.forward_native() -> ApplyRotaryEmb.forward_native() +2. RotaryEmbedding.forward_cuda() -> ApplyRotaryEmb.forward() (auto-dispatch) +3. RotaryEmbedding.forward_hip() -> ApplyRotaryEmb.forward() (auto-dispatch) +""" + +from dataclasses import dataclass + +import pytest +import torch + +from vllm.config import ( + CompilationConfig, + VllmConfig, + get_cached_compilation_config, + set_current_vllm_config, +) +from vllm.platforms import current_platform + +CUDA_DEVICES = ["cuda:0"] + + +@dataclass +class RotaryEmbeddingTestCase: + """Test case configuration for RotaryEmbedding dispatch tests.""" + + name: str + rope_class: type + rope_kwargs: dict + method_name: str # forward_native, forward_cuda, forward + positions_shape: tuple # (num_tokens,) or (3, num_tokens) or (4, num_tokens) + expect_forward_native: bool # Should call ApplyRotaryEmb.forward_native() + expect_forward: bool # Should call ApplyRotaryEmb.forward() + + +def get_test_cases() -> list[RotaryEmbeddingTestCase]: + """Generate test cases for all RotaryEmbedding classes.""" + from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import ( + Ernie4_5_VLRotaryEmbedding, + ) + from vllm.model_executor.layers.rotary_embedding.mrope import 
MRotaryEmbedding + from vllm.model_executor.layers.rotary_embedding.xdrope import XDRotaryEmbedding + + common_kwargs = { + "head_size": 128, + "rotary_dim": 128, + "max_position_embeddings": 4096, + "base": 10000, + "is_neox_style": True, + "dtype": torch.bfloat16, + } + + return [ + # MRotaryEmbedding tests + RotaryEmbeddingTestCase( + name="MRotaryEmbedding.forward_native", + rope_class=MRotaryEmbedding, + rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]}, + method_name="forward_native", + positions_shape=(3, 32), # 2D for multimodal + expect_forward_native=True, + expect_forward=False, + ), + RotaryEmbeddingTestCase( + name="MRotaryEmbedding.forward_cuda_1d", + rope_class=MRotaryEmbedding, + rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]}, + method_name="forward_cuda", + positions_shape=(32,), # 1D triggers apply_rotary_emb path + expect_forward_native=False, + expect_forward=True, + ), + # XDRotaryEmbedding tests + RotaryEmbeddingTestCase( + name="XDRotaryEmbedding.forward", + rope_class=XDRotaryEmbedding, + rope_kwargs={ + **common_kwargs, + "scaling_alpha": 1.0, + "xdrope_section": [16, 16, 16, 16], + }, + method_name="forward", + positions_shape=(4, 32), # 4D for P/W/H/T + expect_forward_native=False, + expect_forward=True, + ), + # Ernie4_5_VLRotaryEmbedding tests + RotaryEmbeddingTestCase( + name="Ernie4_5_VLRotaryEmbedding.forward_native", + rope_class=Ernie4_5_VLRotaryEmbedding, + rope_kwargs={**common_kwargs, "mrope_section": [22, 22, 20]}, + method_name="forward_native", + positions_shape=(3, 32), # 2D for multimodal + expect_forward_native=True, + expect_forward=False, + ), + ] + + +def run_dispatch_test( + test_case: RotaryEmbeddingTestCase, + device: str, +): + """Run a dispatch test for a RotaryEmbedding class.""" + vllm_config = VllmConfig( + compilation_config=CompilationConfig(custom_ops=["all", "+apply_rotary_emb"]) + ) + get_cached_compilation_config.cache_clear() + + with set_current_vllm_config(vllm_config): + rope = 
test_case.rope_class(**test_case.rope_kwargs).to(device=device) + + apply_rotary_emb = rope.apply_rotary_emb + + # Verify custom op is enabled + if test_case.expect_forward_native: + assert ( + apply_rotary_emb._forward_method != apply_rotary_emb.forward_native + ), "Test setup error: ApplyRotaryEmb custom op should be enabled" + + # Setup call tracking + call_tracker = {"forward_native_called": False, "forward_called": False} + original_forward_native = apply_rotary_emb.forward_native + original_forward = apply_rotary_emb.forward + + def tracked_forward_native(*args, **kwargs): + call_tracker["forward_native_called"] = True + return original_forward_native(*args, **kwargs) + + def tracked_forward(*args, **kwargs): + call_tracker["forward_called"] = True + return original_forward(*args, **kwargs) + + apply_rotary_emb.forward_native = tracked_forward_native + apply_rotary_emb.forward = tracked_forward + + try: + num_tokens = test_case.positions_shape[-1] + num_q_heads = 8 + num_kv_heads = 2 + head_size = test_case.rope_kwargs["head_size"] + max_position = test_case.rope_kwargs["max_position_embeddings"] + + positions = torch.randint( + 0, max_position // 4, test_case.positions_shape, device=device + ) + query = torch.randn( + num_tokens, num_q_heads * head_size, dtype=torch.bfloat16, device=device + ) + key = torch.randn( + num_tokens, + num_kv_heads * head_size, + dtype=torch.bfloat16, + device=device, + ) + + # Call the method under test + method = getattr(rope, test_case.method_name) + method(positions, query.clone(), key.clone()) + + # Verify expectations + if test_case.expect_forward_native: + assert call_tracker["forward_native_called"], ( + f"{test_case.name} should call ApplyRotaryEmb.forward_native()" + ) + if not test_case.expect_forward: + assert not call_tracker["forward_called"], ( + f"{test_case.name} should NOT call ApplyRotaryEmb.forward(). " + "Bug: when +apply_rotary_emb is enabled, forward_native() " + "incorrectly dispatches to CUDA/HIP kernels." 
+ ) + if test_case.expect_forward: + assert call_tracker["forward_called"], ( + f"{test_case.name} should call ApplyRotaryEmb.forward()" + ) + finally: + apply_rotary_emb.forward_native = original_forward_native + apply_rotary_emb.forward = original_forward + + +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests." +) +@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda tc: tc.name) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_rotary_embedding_dispatch( + test_case: RotaryEmbeddingTestCase, + device: str, +): + """ + Test that RotaryEmbedding classes dispatch to the correct ApplyRotaryEmb method. + + - forward_native methods should call ApplyRotaryEmb.forward_native() + - forward_cuda/forward methods should call ApplyRotaryEmb.forward() + """ + run_dispatch_test(test_case, device) diff --git a/tests/kernels/core/test_cpu_activation.py b/tests/kernels/core/test_cpu_activation.py new file mode 100644 index 000000000..40b5f0454 --- /dev/null +++ b/tests/kernels/core/test_cpu_activation.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from tests.kernels.allclose_default import get_default_atol, get_default_rtol +from tests.kernels.utils import opcheck +from vllm.platforms import CpuArchEnum, current_platform +from vllm.utils.torch_utils import set_random_seed + +if not current_platform.is_cpu(): + pytest.skip("skipping CPU-only tests", allow_module_level=True) + +from vllm.model_executor.layers.activation import ( + GELU, + FastGELU, + GeluAndMul, + NewGELU, + QuickGELU, + SiluAndMul, +) + +DTYPES = [torch.bfloat16, torch.float32] +NUM_TOKENS = [7, 83] +D = [512, 2048] +SEEDS = [0] + + +@pytest.mark.parametrize( + ("activation_cls", "fn"), + [ + (SiluAndMul, torch.ops._C.silu_and_mul), + (GeluAndMul, torch.ops._C.gelu_and_mul), + (GeluAndMul, torch.ops._C.gelu_tanh_and_mul), + ], 
+) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("d", D) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_cpu_act_and_mul( + default_vllm_config, + activation_cls: type[torch.nn.Module], + fn: object, + num_tokens: int, + d: int, + dtype: torch.dtype, + seed: int, +) -> None: + set_random_seed(seed) + x = torch.randn(num_tokens, 2 * d, dtype=dtype) + + layer = activation_cls() + out = layer(x) + ref_out = layer.forward_native(x) + + torch.testing.assert_close( + out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out) + ) + + output_shape = x.shape[:-1] + (x.shape[-1] // 2,) + raw_out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + opcheck(fn, (raw_out, x)) + + +@pytest.mark.parametrize( + ("activation_cls", "fn", "op_args"), + [ + (NewGELU, torch.ops._C.gelu_new, ()), + (FastGELU, torch.ops._C.gelu_fast, ()), + (QuickGELU, torch.ops._C.gelu_quick, ()), + pytest.param( + GELU, + getattr(torch.ops._C, "activation_lut_bf16", None), + ("gelu",), + marks=pytest.mark.skipif( + current_platform.get_cpu_architecture() != CpuArchEnum.ARM, + reason="activation_lut_bf16 is only built on Arm CPU", + ), + ), + ], +) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("d", D) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_cpu_unary_activation( + default_vllm_config, + activation_cls: type[torch.nn.Module], + fn: object, + op_args: tuple[str, ...], + num_tokens: int, + d: int, + dtype: torch.dtype, + seed: int, +) -> None: + set_random_seed(seed) + x = torch.randn(num_tokens, d, dtype=dtype) + layer = activation_cls() + out = layer(x) + ref_out = layer.forward_native(x) + torch.testing.assert_close( + out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out) + ) + # gelu with activation_lut_bf16 only makes sense for BF16 + if not (activation_cls 
is GELU and dtype != torch.bfloat16): + raw_out = torch.empty_like(x) + opcheck(fn, (raw_out, x, *op_args)) diff --git a/tests/kernels/core/test_fused_q_kv_rmsnorm.py b/tests/kernels/core/test_fused_q_kv_rmsnorm.py new file mode 100644 index 000000000..f11458fb9 --- /dev/null +++ b/tests/kernels/core/test_fused_q_kv_rmsnorm.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Correctness + large-token-count launch tests for fused_q_kv_rmsnorm. + +Before the grid-dim fix the kernel used grid ``(2, num_tokens)``, which hit +CUDA's 65535 grid-y cap for ``num_tokens >= 65536`` and failed with +``Triton Error [CUDA]: invalid argument`` at every large chunked-prefill +profile run. These tests pin the new grid layout. +""" + +from __future__ import annotations + +import pytest +import torch + +from vllm.platforms import current_platform +try: + from vllm.v1.attention.ops.deepseek_v4_ops import fused_q_kv_rmsnorm +except ImportError: + pytest.skip("deepseek_v4_ops not available", allow_module_level=True) + +pytestmark = pytest.mark.skipif( + not current_platform.is_cuda_alike(), + reason="fused_q_kv_rmsnorm requires a CUDA/ROCm device", +) + + +def _ref_rmsnorm(x: torch.Tensor, w: torch.Tensor, eps: float) -> torch.Tensor: + x_f32 = x.to(torch.float32) + variance = x_f32.pow(2).mean(dim=-1, keepdim=True) + y = x_f32 * torch.rsqrt(variance + eps) * w.to(torch.float32) + return y.to(x.dtype) + + +@pytest.mark.parametrize("num_tokens", [1, 17, 1024, 8192]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +def test_fused_q_kv_rmsnorm_correctness(num_tokens: int, dtype: torch.dtype): + torch.manual_seed(0) + device = "cuda" + q_size, kv_size = 192, 576 + qr = torch.randn(num_tokens, q_size, dtype=dtype, device=device) + kv = torch.randn(num_tokens, kv_size, dtype=dtype, device=device) + qw = torch.randn(q_size, dtype=dtype, device=device) + kvw = torch.randn(kv_size, dtype=dtype, 
device=device) + eps = 1e-6 + + qr_out, kv_out = fused_q_kv_rmsnorm(qr, kv, qw, kvw, eps) + + qr_ref = _ref_rmsnorm(qr, qw, eps) + kv_ref = _ref_rmsnorm(kv, kvw, eps) + + tol = dict(rtol=1e-2, atol=1e-2) + torch.testing.assert_close(qr_out, qr_ref, **tol) + torch.testing.assert_close(kv_out, kv_ref, **tol) + + +@pytest.mark.parametrize("num_tokens", [65535, 65536, 131072]) +def test_fused_q_kv_rmsnorm_launches_past_grid_y_cap(num_tokens: int): + """Regression guard: grid used to be (2, num_tokens), hitting CUDA's + 65535 grid-y cap at num_tokens >= 65536. The new grid (num_tokens, 2) + lifts that bound to 2**31-1.""" + device = "cuda" + dtype = torch.bfloat16 + q_size, kv_size = 192, 576 + qr = torch.randn(num_tokens, q_size, dtype=dtype, device=device) + kv = torch.randn(num_tokens, kv_size, dtype=dtype, device=device) + qw = torch.randn(q_size, dtype=dtype, device=device) + kvw = torch.randn(kv_size, dtype=dtype, device=device) + + qr_out, kv_out = fused_q_kv_rmsnorm(qr, kv, qw, kvw, 1e-6) + # spot-check a couple of rows against the torch reference + for row in (0, num_tokens // 2, num_tokens - 1): + torch.testing.assert_close( + qr_out[row], + _ref_rmsnorm(qr[row : row + 1], qw, 1e-6)[0], + rtol=1e-2, + atol=1e-2, + ) + torch.testing.assert_close( + kv_out[row], + _ref_rmsnorm(kv[row : row + 1], kvw, 1e-6)[0], + rtol=1e-2, + atol=1e-2, + ) diff --git a/tests/kernels/core/test_fused_qk_norm_rope.py b/tests/kernels/core/test_fused_qk_norm_rope.py new file mode 100644 index 000000000..43737f4f2 --- /dev/null +++ b/tests/kernels/core/test_fused_qk_norm_rope.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.platforms import current_platform +from vllm.utils.torch_utils import 
set_random_seed + +DTYPES = [torch.bfloat16, torch.float16] +IS_NEOX = [True, False] +EPS_VALUES = [1e-5, 1e-6] +SEEDS = [13] +PARTIAL_ROPE = [True, False] +CUDA_DEVICES = ["cuda:0"] + + +def _apply_qk_norm_rope( + qkv: torch.Tensor, + positions: torch.Tensor, + q_norm: RMSNorm, + k_norm: RMSNorm, + rope: RotaryEmbedding, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, +) -> torch.Tensor: + q_size = num_heads_q * head_dim + kv_size = num_heads_kv * head_dim + + q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1) + + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim) + q_by_head = q_norm.forward_native(q_by_head) + q = q_by_head.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim) + k_by_head = k_norm.forward_native(k_by_head) + k = k_by_head.view(k.shape) + + q, k = rope.forward_native(positions, q, k) + return torch.cat([q, k, v], dim=-1) + + +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), + reason="fused_qk_norm_rope custom op requires cuda and rocm platform", +) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("is_neox", IS_NEOX) +@pytest.mark.parametrize("eps", EPS_VALUES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("rotary_ratio", [1.0, 0.5, 0.25]) +@torch.inference_mode() +def test_fused_qk_norm_rope_matches_reference( + default_vllm_config, + device: str, + dtype: torch.dtype, + is_neox: bool, + eps: float, + seed: int, + rotary_ratio: float, +): + torch.set_default_device(device) + set_random_seed(seed) + num_heads, num_kv_heads, head_dim = 16, 4, 128 + num_tokens = 4 + + total_dim = (num_heads + 2 * num_kv_heads) * head_dim + qkv_base = torch.randn(num_tokens, total_dim, dtype=dtype, device=device) + qkv_fused = qkv_base.clone() + positions = torch.arange(num_tokens, dtype=torch.long, device=device) + + q_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype) + k_norm = 
RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype) + q_norm.weight.data.normal_(mean=1.0, std=0.1) + k_norm.weight.data.normal_(mean=1.0, std=0.1) + q_weight = q_norm.weight.data + k_weight = k_norm.weight.data + rotary_dim = int(head_dim * rotary_ratio) + rope = RotaryEmbedding( + head_size=head_dim, + rotary_dim=rotary_dim, + max_position_embeddings=4096, + base=10000.0, + is_neox_style=is_neox, + dtype=dtype, + ).to(device) + + ref_result = _apply_qk_norm_rope( + qkv=qkv_base, + positions=positions, + q_norm=q_norm, + k_norm=k_norm, + rope=rope, + num_heads_q=num_heads, + num_heads_kv=num_kv_heads, + head_dim=head_dim, + ) + + opcheck( + torch.ops._C.fused_qk_norm_rope, + ( + qkv_fused.clone(), + num_heads, + num_kv_heads, + num_kv_heads, + head_dim, + eps, + q_weight, + k_weight, + rope.cos_sin_cache, + is_neox, + positions.view(-1), + ), + ) + + torch.ops._C.fused_qk_norm_rope( + qkv_fused, + num_heads, + num_kv_heads, + num_kv_heads, + head_dim, + eps, + q_weight, + k_weight, + rope.cos_sin_cache, + is_neox, + positions.view(-1), + ) + + if dtype == torch.float16: + ATOL, RTOL = (2e-3, 2e-3) + else: + ATOL, RTOL = (1e-2, 1e-2) + + torch.testing.assert_close( + qkv_fused, + ref_result, + atol=ATOL, + rtol=RTOL, + ) diff --git a/tests/kernels/core/test_fused_rms_norm_gated.py b/tests/kernels/core/test_fused_rms_norm_gated.py new file mode 100644 index 000000000..793dd02a9 --- /dev/null +++ b/tests/kernels/core/test_fused_rms_norm_gated.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Tests that FusedRMSNormGated decomposes correctly under torch.compile, +matching the eager triton kernel output.""" + +import pytest +import torch + +from vllm.model_executor.layers.fla.ops.kda import FusedRMSNormGated +from vllm.utils.torch_utils import set_random_seed + +DTYPES = [torch.bfloat16] +HIDDEN_SIZES = [128, 512] +NUM_TOKENS = [64, 128] +ACTIVATIONS = ["swish", "sigmoid"] 
+ELEMENTWISE_AFFINE = [True, False] +SEEDS = [0] + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("activation", ACTIVATIONS) +@pytest.mark.parametrize("elementwise_affine", ELEMENTWISE_AFFINE) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_compiled_vs_eager( + default_vllm_config, + num_tokens: int, + hidden_size: int, + activation: str, + elementwise_affine: bool, + dtype: torch.dtype, + seed: int, +) -> None: + """forward_native decomposition matches forward_cuda triton kernel.""" + torch._dynamo.reset() + set_random_seed(seed) + device = torch.device("cuda:0") + + module = FusedRMSNormGated( + hidden_size, + elementwise_affine=elementwise_affine, + eps=1e-5, + activation=activation, + device=device, + dtype=dtype, + ) + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) + g = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) + + # forward_cuda may modify x in-place, so clone inputs + cuda_out = module.forward_cuda(x.clone(), g.clone()) + compiled_native = torch.compile(module.forward_native, fullgraph=True) + native_out = compiled_native(x.clone(), g.clone()) + + torch.testing.assert_close(native_out, cuda_out, atol=1e-3, rtol=1e-2) + + +@pytest.mark.parametrize( + "shape", + [ + (1, 16, 32, 128), + (2, 8, 16, 64), + ], +) +@pytest.mark.parametrize("activation", ACTIVATIONS) +@pytest.mark.parametrize("elementwise_affine", ELEMENTWISE_AFFINE) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_compiled_vs_eager_multidim( + default_vllm_config, + shape: tuple, + activation: str, + elementwise_affine: bool, + dtype: torch.dtype, + seed: int, +) -> None: + """forward_native decomposition handles multi-dimensional inputs.""" + torch._dynamo.reset() + set_random_seed(seed) + device = torch.device("cuda:0") + 
head_dim = shape[-1] + + module = FusedRMSNormGated( + head_dim, + elementwise_affine=elementwise_affine, + eps=1e-5, + activation=activation, + device=device, + dtype=dtype, + ) + x = torch.randn(*shape, dtype=dtype, device=device) + g = torch.randn(*shape, dtype=dtype, device=device) + + # forward_cuda may modify x in-place, so clone inputs + cuda_out = module.forward_cuda(x.clone(), g.clone()) + compiled_native = torch.compile(module.forward_native, fullgraph=True) + native_out = compiled_native(x.clone(), g.clone()) + + torch.testing.assert_close(native_out, cuda_out, atol=1e-3, rtol=1e-2) diff --git a/tests/kernels/core/test_fused_silu_mul_block_quant.py b/tests/kernels/core/test_fused_silu_mul_block_quant.py new file mode 100644 index 000000000..37b76056c --- /dev/null +++ b/tests/kernels/core/test_fused_silu_mul_block_quant.py @@ -0,0 +1,189 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch +import torch.nn.functional as F + +import vllm._custom_ops as ops +from tests.kernels.utils import opcheck +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, +) +from vllm.model_executor.layers.quantization.utils.int8_utils import ( + per_token_group_quant_int8, +) +from vllm.platforms import current_platform + +DTYPES = [torch.float16, torch.bfloat16] +QUANT_DTYPES = [current_platform.fp8_dtype(), torch.int8] +VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029] +NUM_TOKENS_HIDDEN_SIZES = [ + *[(1, i) for i in [64, *VEC_HIDDEN_SIZES, 2048, 5120]], + *[(16, i) for i in [64, *VEC_HIDDEN_SIZES, 5120]], + *[(128, i) for i in [64, *VEC_HIDDEN_SIZES]], + *[(512, i) for i in [64, 5120]], +] +SCALE_UBS = [False] +GROUP_SIZES = [64, 128] +IS_SCALE_TRANSPOSED = [False, True] +SEEDS = [0] +CUDA_DEVICES = [i for i in range(1 if torch.accelerator.device_count() == 1 else 2)] + + +def ref_silu_and_mul_per_block_quant( + x: torch.Tensor, + quant_dtype: 
torch.dtype, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Reference implementation: unfused SiLU+Mul then group quantization.""" + hidden = x.shape[-1] // 2 + gate, up = x.split(hidden, dim=-1) + silu_out = F.silu(gate) * up + + if quant_dtype == current_platform.fp8_dtype(): + return per_token_group_quant_fp8( + silu_out, group_size=group_size, use_ue8m0=False + ) + elif quant_dtype == torch.int8: + return per_token_group_quant_int8(silu_out, group_size=group_size) + else: + raise ValueError(f"Unsupported quant_dtype: {quant_dtype}") + + +@pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES) +@pytest.mark.parametrize("has_scale_ub", SCALE_UBS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES) +@pytest.mark.parametrize("group_size", GROUP_SIZES) +@pytest.mark.parametrize("is_scale_transposed", IS_SCALE_TRANSPOSED) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device_idx", CUDA_DEVICES) +@torch.inference_mode() +def test_silu_and_mul_per_block_quant( + default_vllm_config, + num_tokens: int, + hidden_size: int, + has_scale_ub: bool, + dtype: torch.dtype, + quant_dtype: torch.dtype, + group_size: int, + is_scale_transposed: bool, + seed: int, + device_idx: str, +) -> None: + """Test SiLU+Mul+Block Quantization kernel correctness.""" + torch.accelerator.set_device_index(device_idx) + device = f"cuda:{device_idx}" + torch.random.manual_seed(seed) + torch.set_default_device(device) + + if hidden_size % group_size != 0: + return + + if has_scale_ub: + pytest.skip("Scale upper bound not yet supported") + + scale = 1 / hidden_size + x = torch.randn(num_tokens, hidden_size * 2, dtype=dtype, device=device) * scale + + # Reference implementation + ref_out, ref_scales = ref_silu_and_mul_per_block_quant(x, quant_dtype, group_size) + + # Fused kernel implementation + ops_out, ops_scales = ops.silu_and_mul_per_block_quant( + x, group_size, quant_dtype, None, 
is_scale_transposed + ) + + # Check for NaN/Inf + assert not torch.isnan(ops_out.float()).any(), "Kernel output contains NaN" + assert not torch.isinf(ops_out.float()).any(), "Kernel output contains Inf" + assert not torch.isnan(ops_scales).any(), "Kernel scales contain NaN" + assert not torch.isinf(ops_scales).any(), "Kernel scales contain Inf" + + # Check dtypes + assert ref_out.dtype == quant_dtype + assert ops_out.dtype == quant_dtype + + # Check scales match + torch.testing.assert_close(ref_scales, ops_scales, rtol=1e-5, atol=1e-5) + + # Check output correctness via dequantized values + ref_scales_expanded = ref_scales.repeat_interleave(group_size, dim=1) + ops_scales_expanded = ops_scales.repeat_interleave(group_size, dim=1) + ref_deq = ref_out.to(dtype=torch.float32) * ref_scales_expanded + ops_deq = ops_out.to(dtype=torch.float32) * ops_scales_expanded + torch.testing.assert_close(ref_deq, ops_deq, atol=5e-2, rtol=5e-2) + + # opcheck + output = torch.empty(num_tokens, hidden_size, device=device, dtype=quant_dtype) + num_groups = hidden_size // group_size + if is_scale_transposed: + scales = torch.empty(num_groups, num_tokens, device=device, dtype=torch.float32) + else: + scales = torch.empty(num_tokens, num_groups, device=device, dtype=torch.float32) + opcheck( + torch.ops._C.silu_and_mul_per_block_quant, + (output, x, scales, group_size, None, is_scale_transposed), + ) + + +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("num_tokens", [128]) +@pytest.mark.parametrize("group_size", [128]) +def test_silu_block_quant_shapes( + default_vllm_config, + dtype: torch.dtype, + hidden_size: int, + num_tokens: int, + group_size: int, +): + """Test that output shapes are correct.""" + torch.set_default_device("cuda") + x = torch.randn(num_tokens, hidden_size * 2, dtype=dtype, device="cuda") + + # Row-major scales + out, scales = ops.silu_and_mul_per_block_quant( + x, + group_size=group_size, 
+ quant_dtype=current_platform.fp8_dtype(), + is_scale_transposed=False, + ) + assert out.shape == (num_tokens, hidden_size) + assert scales.shape == (num_tokens, hidden_size // group_size) + + # Column-major scales (logical shape same after .t() in _custom_ops) + out, scales = ops.silu_and_mul_per_block_quant( + x, + group_size=group_size, + quant_dtype=current_platform.fp8_dtype(), + is_scale_transposed=True, + ) + assert out.shape == (num_tokens, hidden_size) + assert scales.shape == (num_tokens, hidden_size // group_size) + + +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("batch_size", [1, 16, 256]) +@pytest.mark.parametrize("hidden_size", [1024, 5120, 14336]) +def test_silu_block_quant_edge_cases( + default_vllm_config, dtype: torch.dtype, batch_size: int, hidden_size: int +): + """Test edge cases: single token, large batch, large hidden size.""" + torch.set_default_device("cuda") + x = torch.randn(batch_size, hidden_size * 2, dtype=dtype, device="cuda") + + out, scales = ops.silu_and_mul_per_block_quant( + x, + group_size=128, + quant_dtype=current_platform.fp8_dtype(), + is_scale_transposed=False, + ) + + assert out.shape == (batch_size, hidden_size) + assert out.dtype == current_platform.fp8_dtype() + assert scales.dtype == torch.float32 + assert not torch.isnan(out.float()).any() + assert not torch.isnan(scales).any() + assert not torch.isinf(scales).any() diff --git a/tests/kernels/core/test_minimax_reduce_rms.py b/tests/kernels/core/test_minimax_reduce_rms.py new file mode 100644 index 000000000..d17a448bd --- /dev/null +++ b/tests/kernels/core/test_minimax_reduce_rms.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for MiniMax QK RMS-norm: NCCL reference vs Lamport fused kernel.""" + +import pytest +import torch +import torch.nn as nn +from torch.multiprocessing import spawn + +from tests.kernels.utils import opcheck +from 
tests.utils import ensure_current_vllm_config, init_test_distributed_environment +from vllm.distributed import cleanup_dist_env_and_memory +from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP +from vllm.platforms import current_platform +from vllm.utils.network_utils import get_open_port +from vllm.utils.torch_utils import set_random_seed + + +@ensure_current_vllm_config() +def _worker_forward_qk( + local_rank, + world_size, + port, + num_tokens, + hidden_q_full, + hidden_k_full, + dtype, + seed, + eps, +): + """Per-rank worker: compare NCCL allreduce path vs Lamport fused kernel.""" + + if not hasattr(torch.ops._C, "minimax_allreduce_rms_qk"): + cleanup_dist_env_and_memory() + return + device = torch.device(f"cuda:{local_rank}") + torch.accelerator.set_device_index(device) + init_test_distributed_environment( + world_size, 1, local_rank, port, local_rank=local_rank + ) + + hq = hidden_q_full // world_size + hk = hidden_k_full // world_size + + q_norm = MiniMaxText01RMSNormTP(hidden_q_full, eps=eps).cuda() + k_norm = MiniMaxText01RMSNormTP(hidden_k_full, eps=eps).cuda() + + set_random_seed(seed) + qw = torch.randn(hidden_q_full, dtype=dtype, device="cuda") + kw = torch.randn(hidden_k_full, dtype=dtype, device="cuda") + q_norm.weight = nn.Parameter(qw[local_rank * hq : (local_rank + 1) * hq]) + k_norm.weight = nn.Parameter(kw[local_rank * hk : (local_rank + 1) * hk]) + + torch.manual_seed(seed + 1000 + local_rank) + qkv = torch.randn(num_tokens, hq + hk + hk, dtype=dtype, device="cuda") + + q_ref, k_ref, v_ref = qkv.clone().split([hq, hk, hk], dim=-1) + ref_q, ref_k = MiniMaxText01RMSNormTP.forward_qk(q_norm, k_norm, q_ref, k_ref) + + # Set up Lamport workspace. 
+ from vllm.distributed.parallel_state import get_tp_group + from vllm.model_executor.layers.mamba.lamport_workspace import ( + get_allreduce_workspace, + ) + + workspace = get_allreduce_workspace( + rank=local_rank, + world_size=world_size, + max_tokens=num_tokens, + process_group=get_tp_group().cpu_group, + ) + + opcheck( + torch.ops._C.minimax_allreduce_rms_qk, + ( + qkv.clone(), + q_norm.weight, + k_norm.weight, + workspace, + hq, + hk, + local_rank, + world_size, + eps, + ), + ) + fused_q, fused_k = torch.ops._C.minimax_allreduce_rms_qk( + qkv.clone(), + q_norm.weight, + k_norm.weight, + workspace, + hq, + hk, + local_rank, + world_size, + eps, + ) + _, _, fused_v = qkv.split([hq, hk, hk], dim=-1) + torch.accelerator.synchronize() + + torch.testing.assert_close( + fused_q, + ref_q, + atol=3e-2, + rtol=3e-2, + ) + torch.testing.assert_close(fused_k, ref_k, atol=3e-2, rtol=3e-2) + + cleanup_dist_env_and_memory() + + +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="CUDA required", +) +@pytest.mark.parametrize("world_size", [2, 4, 8]) +@pytest.mark.parametrize("num_tokens", [1, 128, 333]) +@pytest.mark.parametrize( + "hidden_dims", + [(6144, 1024)], +) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("eps", [1e-6]) +@pytest.mark.parametrize("seed", [42]) +def test_minimax_reduce_rms_qk( + world_size, + num_tokens, + hidden_dims, + dtype, + eps, + seed, +): + num_gpus = current_platform.device_count() + if num_gpus < world_size: + pytest.skip(f"Need >= {world_size} GPUs, have {num_gpus}") + hidden_q_full, hidden_k_full = hidden_dims + port = str(get_open_port()) + spawn( + _worker_forward_qk, + args=( + world_size, + port, + num_tokens, + hidden_q_full, + hidden_k_full, + dtype, + seed, + eps, + ), + nprocs=world_size, + join=True, + ) diff --git a/tests/kernels/core/test_rotary_embedding_mla_cache_fused.py b/tests/kernels/core/test_rotary_embedding_mla_cache_fused.py new file mode 100644 index 
000000000..181f10f31 --- /dev/null +++ b/tests/kernels/core/test_rotary_embedding_mla_cache_fused.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for fused MLA KV-cache write and RoPE fused kernel +""" + +import random + +import pytest +import torch + +from tests.kernels.allclose_default import get_default_atol, get_default_rtol +from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck +from vllm import _custom_ops as ops +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.platforms import current_platform +from vllm.utils.torch_utils import set_random_seed + + +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16, torch.float]) +@pytest.mark.parametrize("is_neox_style", [False, True]) +@pytest.mark.parametrize("seq_len", [11, 42]) +@pytest.mark.parametrize("qk_rope_head_dim", [64, 128]) +@pytest.mark.parametrize("num_q_heads", [128]) +@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) +@pytest.mark.parametrize("kv_lora_rank", [512]) +@pytest.mark.parametrize("num_blocks", [64]) +@pytest.mark.parametrize("block_size", [16, 64, 256]) +@pytest.mark.parametrize("seed", [0]) +@pytest.mark.parametrize( + "device", + [f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)], +) +@torch.inference_mode() +def test_concat_and_cache_mla_rope_fused( + default_vllm_config, + dtype: torch.dtype, + is_neox_style: bool, + seq_len: int, + qk_rope_head_dim: int, + num_q_heads: int, + kv_cache_dtype: str, + kv_lora_rank: int, + num_blocks: int, + block_size: int, + seed: int, + device: str, + max_position: int = 8192, + base: float = 10000, +) -> None: + set_random_seed(seed) + torch.set_default_device(device) + + rope = RotaryEmbedding( + qk_rope_head_dim, + qk_rope_head_dim, + max_position, + base, + is_neox_style, + torch.float32, + ) + + rope = rope.to(dtype=dtype, device=torch.get_default_device()) + + 
positions = torch.randint(0, max_position, (seq_len,)) + + query = torch.randn(seq_len, num_q_heads, qk_rope_head_dim, dtype=dtype) + key = torch.randn(seq_len, 1, qk_rope_head_dim + kv_lora_rank, dtype=dtype) + + k_pe = torch.flatten(key[..., :qk_rope_head_dim], start_dim=1).to(device=device) + kv_c = torch.flatten(key[..., qk_rope_head_dim:], start_dim=1).to(device=device) + + if current_platform.is_rocm(): + # We use forward_hip for the same numerics as the fused custom kernel on ROCm + # when dtype is FP16. The torch-native implementation implicitly upcasts + # FP16 x FP16 multiplications to FP32 before downcasting them, which leads + # to notable output divergences. + # Clone the tensors because the implementation modifies them in-place + ref_q_pe, ref_k_pe = rope.forward_hip(positions, query.clone(), k_pe.clone()) + else: + # NOTE(woosuk): The reference implementation should be executed first + # because the custom kernel is in-place. + ref_q_pe, ref_k_pe = rope.forward_native(positions, query, k_pe) + assert ref_k_pe is not None + + ref_k_pe = torch.flatten(ref_k_pe, start_dim=1).to(device=device) + ref_k_rope = ref_k_pe[..., :qk_rope_head_dim] + + total_available_slots = num_blocks * block_size + total_needed_slots = seq_len + assert total_available_slots >= total_needed_slots, "Not enough kv slots!" 
+ + slot_mapping_lst = random.sample(range(total_available_slots), total_needed_slots) + slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) + + entry_size = kv_lora_rank + qk_rope_head_dim + + kv_cache_scale = torch.tensor([0.1], dtype=torch.float32, device=device) + + kv_cache = torch.zeros( + num_blocks, + block_size, + entry_size, + dtype=torch.uint8 if kv_cache_dtype == "fp8" else dtype, + device=device, + ) + + ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device) + + for i in range(seq_len): + slot = slot_mapping[i].item() + block_idx = slot // block_size + block_offset = slot % block_size + ref_temp[block_idx, block_offset] = torch.cat((kv_c[i], ref_k_rope[i]), -1) + + if kv_cache_dtype == "fp8": + ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype) + ops.convert_fp8( + ref_kv_cache, ref_temp, kv_cache_scale.item(), kv_dtype=kv_cache_dtype + ) + else: + ref_kv_cache = ref_temp + + opcheck( + torch.ops._C_cache_ops.concat_and_cache_mla_rope_fused, + ( + positions, + query, + k_pe, + kv_c, + rope.cos_sin_cache, + is_neox_style, + slot_mapping, + kv_cache, + kv_cache_dtype, + kv_cache_scale, + ), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) + + ops.concat_and_cache_mla_rope_fused( + positions, + query, + k_pe, + kv_c, + rope.cos_sin_cache, + is_neox_style, + slot_mapping, + kv_cache, + kv_cache_dtype, + kv_cache_scale, + ) + + if kv_cache_dtype == "fp8": + result_temp = torch.empty_like(kv_cache, dtype=torch.float16) + ops.convert_fp8( + result_temp, + kv_cache.contiguous(), + kv_cache_scale.item(), + kv_dtype=kv_cache_dtype, + ) + expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16) + ops.convert_fp8( + expected_temp, ref_kv_cache, kv_cache_scale.item(), kv_dtype=kv_cache_dtype + ) + torch.testing.assert_close(result_temp, expected_temp, atol=0.001, rtol=0.1) + else: + torch.testing.assert_close(kv_cache, ref_kv_cache) + + torch.testing.assert_close( + query, ref_q_pe, 
atol=get_default_atol(query), rtol=get_default_rtol(query) + ) diff --git a/tests/kernels/core/test_vit_bilinear_pos_embed.py b/tests/kernels/core/test_vit_bilinear_pos_embed.py new file mode 100644 index 000000000..d0354d25f --- /dev/null +++ b/tests/kernels/core/test_vit_bilinear_pos_embed.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Accuracy tests for the fused Triton bilinear position-embedding kernel. + +Compares ``triton_pos_embed_interpolate`` against the pure-PyTorch +``pos_embed_interpolate_native`` across a variety of grid shapes and dtypes. +""" + +import pytest +import torch + +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + try: + from vllm.model_executor.models.qwen3_vl import ( + pos_embed_interpolate_native, + triton_pos_embed_interpolate, + ) + except ImportError: + pytest.skip("qwen3_vl kernels not available", allow_module_level=True) + + +DTYPES = [torch.float32, torch.bfloat16] +# Qwen3-VL default +NUM_GRID_PER_SIDE = 48 +SPATIAL_MERGE_SIZE = 2 +HIDDEN_DIM = 1152 + +# 4 square + 4 non-square grids (h, w divisible by spatial_merge_size=2) +SQUARE_GRIDS = [(1, 4, 4), (1, 16, 16), (1, 32, 32), (1, 48, 48)] +NON_SQUARE_GRIDS = [(1, 8, 16), (1, 14, 20), (1, 32, 48), (1, 60, 80)] +ALL_GRIDS = SQUARE_GRIDS + NON_SQUARE_GRIDS + + +@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available") +@pytest.mark.parametrize("dtype", DTYPES, ids=lambda d: str(d).split(".")[-1]) +@pytest.mark.parametrize( + "grid_thw", + ALL_GRIDS, + ids=[f"{t}x{h}x{w}" for t, h, w in ALL_GRIDS], +) +def test_triton_matches_native( + grid_thw: tuple[int, int, int], + dtype: torch.dtype, +) -> None: + """Triton kernel output must match the native PyTorch implementation.""" + t, h, w = grid_thw + device = "cuda" + + # Scale to match real Qwen3-VL pos_embed weight distribution (std~0.23). 
+ torch.manual_seed(42) + embed_weight = ( + torch.randn( + NUM_GRID_PER_SIDE * NUM_GRID_PER_SIDE, + HIDDEN_DIM, + device=device, + dtype=dtype, + ) + * 0.25 + ) + + native_out = pos_embed_interpolate_native( + embed_weight, t, h, w, NUM_GRID_PER_SIDE, SPATIAL_MERGE_SIZE, dtype + ) + triton_out = triton_pos_embed_interpolate( + embed_weight, t, h, w, NUM_GRID_PER_SIDE, SPATIAL_MERGE_SIZE, dtype + ) + + assert native_out.shape == triton_out.shape, ( + f"Shape mismatch: native {native_out.shape} vs triton {triton_out.shape}" + ) + + # Small numerical differences arise from the precomputed h/w_scale + # in the triton kernel vs torch.linspace in the native path, which can + # cause single-ULP output differences + # in a handful of elements. + atol = {torch.float32: 5e-5, torch.bfloat16: 1e-2}[dtype] + rtol = {torch.float32: 1e-5, torch.bfloat16: 1e-2}[dtype] + torch.testing.assert_close(triton_out, native_out, atol=atol, rtol=rtol) + + +@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available") +@pytest.mark.parametrize("dtype", DTYPES, ids=lambda d: str(d).split(".")[-1]) +def test_temporal_repeat(dtype: torch.dtype) -> None: + """Verify temporal dimension t > 1 correctly repeats the spatial pattern.""" + device = "cuda" + h, w = 16, 16 + t_single, t_multi = 1, 3 + + # Scale to match real Qwen3-VL pos_embed weight distribution (std~0.23). 
+ torch.manual_seed(42) + embed_weight = ( + torch.randn( + NUM_GRID_PER_SIDE * NUM_GRID_PER_SIDE, + HIDDEN_DIM, + device=device, + dtype=dtype, + ) + * 0.25 + ) + + out_single = triton_pos_embed_interpolate( + embed_weight, + t_single, + h, + w, + NUM_GRID_PER_SIDE, + SPATIAL_MERGE_SIZE, + dtype, + ) + out_multi = triton_pos_embed_interpolate( + embed_weight, + t_multi, + h, + w, + NUM_GRID_PER_SIDE, + SPATIAL_MERGE_SIZE, + dtype, + ) + + expected = out_single.repeat(t_multi, 1) + torch.testing.assert_close(out_multi, expected, atol=0, rtol=0) diff --git a/tests/kernels/core/test_vit_fp8_attn.py b/tests/kernels/core/test_vit_fp8_attn.py new file mode 100644 index 000000000..43948ee80 --- /dev/null +++ b/tests/kernels/core/test_vit_fp8_attn.py @@ -0,0 +1,282 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the full FP8 ViT attention path (quantize -> cuDNN -> un-pad).""" + +import contextlib + +import pytest +import torch + +from vllm.triton_utils import HAS_TRITON +try: + from vllm.utils.flashinfer import ( + is_flashinfer_cudnn_fp8_prefill_attn_supported, + ) +except ImportError: + pytest.skip("flashinfer cudnn fp8 not available", allow_module_level=True) +from vllm.v1.attention.backends.registry import AttentionBackendEnum + + +def _has_flashinfer_cudnn() -> bool: + """Check if FlashInfer cuDNN backend is available.""" + try: + from flashinfer.prefill import ( + cudnn_batch_prefill_with_kv_cache, # noqa: F401 + ) + + return True + except ImportError: + return False + + +HEAD_DIMS = [72, 80] +SEQ_LENS = [256] +NUM_HEADS = [16] + + +@pytest.fixture +def _fp8_attention(): + """Create FP8-enabled MMEncoderAttention via config.""" + from types import SimpleNamespace + from unittest.mock import patch + + from vllm.config import VllmConfig, set_current_vllm_config + from vllm.config.multimodal import MultiModalConfig + + if not is_flashinfer_cudnn_fp8_prefill_attn_supported(): + 
pytest.skip("FlashInfer cuDNN FP8 prefill attention not supported") + + mm_config = MultiModalConfig(mm_encoder_attn_dtype="fp8") + vllm_config = VllmConfig() + vllm_config.model_config = SimpleNamespace(multimodal_config=mm_config) + + # MMEncoderAttention reads torch.get_default_dtype() during init + # to determine the output dtype. In real model loading this is bf16. + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.bfloat16) + + with ( + set_current_vllm_config(vllm_config), + patch( + "vllm.model_executor.layers.attention.mm_encoder_attention" + ".get_vit_attn_backend", + return_value=AttentionBackendEnum.FLASHINFER, + ), + ): + yield + + torch.set_default_dtype(old_dtype) + + +def _build_cu_seqlens_and_meta( + seq_len: int, + num_heads: int, + head_dim: int, + fp8_padded_hidden_size: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Build cu_seqlens, max_seqlen, sequence_lengths for a single sequence.""" + import numpy as np + + from vllm.model_executor.layers.attention.mm_encoder_attention import ( + MMEncoderAttention, + ) + + cu_seqlens_np = np.array([0, seq_len], dtype=np.int32) + + sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens( + AttentionBackendEnum.FLASHINFER, + cu_seqlens_np, + torch.device("cuda"), + ) + + max_seqlen = torch.tensor( + MMEncoderAttention.compute_max_seqlen( + AttentionBackendEnum.FLASHINFER, cu_seqlens_np + ), + dtype=torch.int32, + ) + + cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens( + AttentionBackendEnum.FLASHINFER, + cu_seqlens_np, + num_heads * head_dim, + 1, # tp_size + torch.device("cuda"), + fp8_padded_hidden_size=fp8_padded_hidden_size, + ) + + return cu_seqlens, max_seqlen, sequence_lengths + + +@pytest.mark.skipif( + not (HAS_TRITON and _has_flashinfer_cudnn()), + reason="Triton and FlashInfer cuDNN required", +) +@pytest.mark.parametrize("head_dim", HEAD_DIMS) +@pytest.mark.parametrize("seq_len", SEQ_LENS) +@pytest.mark.parametrize("num_heads", 
NUM_HEADS) +def test_fp8_attn_output_shape( + head_dim: int, + seq_len: int, + num_heads: int, + _fp8_attention, +) -> None: + """Verify FP8 attention produces correct output shape after un-padding.""" + from vllm.model_executor.layers.attention.mm_encoder_attention import ( + MMEncoderAttention, + ) + from vllm.utils.math_utils import round_up + + attn = None + with contextlib.suppress(ValueError, ImportError): + attn = MMEncoderAttention( + num_heads=num_heads, + head_size=head_dim, + prefix="visual.blocks.0.attn", + ).to("cuda") + + if attn is None or not attn.fp8_enabled: + pytest.skip("FP8 MMEncoderAttention not available") + assert attn is not None # mypy narrowing + + # FP8 always needs fp8_padded_hidden_size for correct cu_seqlens + fp8_padded_hidden_size = num_heads * round_up(head_dim, 16) + + cu_seqlens, max_seqlen, sequence_lengths = _build_cu_seqlens_and_meta( + seq_len, num_heads, head_dim, fp8_padded_hidden_size=fp8_padded_hidden_size + ) + + q = torch.randn( + seq_len, + num_heads, + head_dim, + device="cuda", + dtype=torch.bfloat16, + ) + k = torch.randn_like(q) + v = torch.randn_like(q) + + output = attn._forward_flashinfer(q, k, v, cu_seqlens, max_seqlen, sequence_lengths) + + # Output should have original head_dim (un-padded) + assert output.shape[-1] == head_dim + assert output.dtype == torch.bfloat16 + + +@pytest.mark.skipif( + not (HAS_TRITON and _has_flashinfer_cudnn()), + reason="Triton and FlashInfer cuDNN required", +) +@pytest.mark.parametrize("head_dim", HEAD_DIMS) +@pytest.mark.parametrize("seq_len", SEQ_LENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +def test_fp8_vs_bf16_close( + head_dim: int, seq_len: int, num_heads: int, _fp8_attention +) -> None: + """FP8 attention output should be reasonably close to BF16 baseline.""" + from vllm.model_executor.layers.attention.mm_encoder_attention import ( + MMEncoderAttention, + ) + from vllm.utils.math_utils import round_up + + torch.manual_seed(42) + q = torch.randn( + 1, + seq_len, 
+ num_heads, + head_dim, + device="cuda", + dtype=torch.bfloat16, + ) + k = torch.randn_like(q) + v = torch.randn_like(q) + + # FP8 path + attn_fp8 = None + with contextlib.suppress(ValueError, ImportError): + attn_fp8 = MMEncoderAttention( + num_heads=num_heads, + head_size=head_dim, + prefix="visual.blocks.0.attn", + ).to("cuda") + + if attn_fp8 is None or not attn_fp8.fp8_enabled: + pytest.skip("FP8 MMEncoderAttention not available") + assert attn_fp8 is not None # mypy narrowing + + fp8_padded_hidden_size = num_heads * round_up(head_dim, 16) + cu_seqlens, max_seqlen, seq_lengths = _build_cu_seqlens_and_meta( + seq_len, + num_heads, + head_dim, + fp8_padded_hidden_size=fp8_padded_hidden_size, + ) + + out_fp8 = attn_fp8._forward_flashinfer( + q.clone(), + k.clone(), + v.clone(), + cu_seqlens, + max_seqlen, + seq_lengths, + ) + + # BF16 baseline (create non-FP8 attention by using scale=attn_fp8.scale + # and calling the wrapper directly without FP8 quantization) + from vllm.model_executor.layers.attention.mm_encoder_attention import ( + _get_flashinfer_workspace_buffer, + ) + from vllm.v1.attention.ops.vit_attn_wrappers import ( + vit_flashinfer_wrapper, + ) + + out_bf16 = vit_flashinfer_wrapper( + q=q.clone(), + k=k.clone(), + v=v.clone(), + scale=attn_fp8.scale, + workspace_buffer=_get_flashinfer_workspace_buffer(), + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + sequence_lengths=seq_lengths, + ) + + out_fp8_f = out_fp8.float() + out_bf16_f = out_bf16.float() + + abs_diff = (out_fp8_f - out_bf16_f).abs() + abs_diff_flat = abs_diff.flatten() + + # Relative diff (avoid division by zero) + denom = out_bf16_f.abs().clamp(min=1e-6) + rel_diff_flat = (abs_diff / denom).flatten() + + cosine_sim = torch.nn.functional.cosine_similarity( + out_fp8_f.flatten().unsqueeze(0), + out_bf16_f.flatten().unsqueeze(0), + ).item() + + pcts = [50, 90, 95, 99, 99.9] + abs_pct = {p: torch.quantile(abs_diff_flat, p / 100).item() for p in pcts} + rel_pct = {p: 
torch.quantile(rel_diff_flat, p / 100).item() for p in pcts} + + print(f"\nFP8 vs BF16 (head_dim={head_dim}, seq_len={seq_len}):") + print(f" cosine_sim={cosine_sim:.6f}") + print( + f" abs_diff: max={abs_diff_flat.max().item():.6f}, " + f"mean={abs_diff_flat.mean().item():.6f}, " + + ", ".join(f"p{p}={abs_pct[p]:.6f}" for p in pcts) + ) + print( + f" rel_diff: max={rel_diff_flat.max().item():.6f}, " + f"mean={rel_diff_flat.mean().item():.6f}, " + + ", ".join(f"p{p}={rel_pct[p]:.6f}" for p in pcts) + ) + + assert abs_diff_flat.max().item() < 0.3, ( + f"FP8 vs BF16 max abs diff too large: {abs_diff_flat.max().item()}" + ) + assert abs_diff_flat.mean().item() < 0.03, ( + f"FP8 vs BF16 mean abs diff too large: {abs_diff_flat.mean().item()}" + ) + assert cosine_sim > 0.99, f"Cosine similarity too low: {cosine_sim:.6f}" diff --git a/tests/kernels/core/test_vit_fp8_quant.py b/tests/kernels/core/test_vit_fp8_quant.py new file mode 100644 index 000000000..e5507370d --- /dev/null +++ b/tests/kernels/core/test_vit_fp8_quant.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the stride-aware FP8 quantization kernel with head_dim padding.""" + +import pytest +import torch + +from vllm.platforms import current_platform +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + try: + from vllm.kernels.triton.qkv_padded_fp8_quant import ( + quantize_fp8_pad_head_dim_triton, + ) + except ImportError: + pytest.skip("qkv_padded_fp8_quant not available", allow_module_level=True) + +HEAD_DIMS = [72, 80, 128] +SEQ_LENS = [64, 256] +NUM_HEADS = [16] +SCALES = [0.01, 0.1, 1.0] + + +def _naive_fp8_quantize( + tensor: torch.Tensor, scale: torch.Tensor, skip_scale: bool +) -> torch.Tensor: + """Reference FP8 quantization in PyTorch.""" + fp8_dtype = current_platform.fp8_dtype() + fp8_max = torch.finfo(fp8_dtype).max + fp8_min = -fp8_max + + x = tensor.float() + if not skip_scale: + x = x / 
scale.item() + x = x.clamp(fp8_min, fp8_max) + return x.to(fp8_dtype) + + +@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available") +@pytest.mark.parametrize("head_dim", HEAD_DIMS) +@pytest.mark.parametrize("seq_len", SEQ_LENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("scale_val", SCALES) +def test_quantize_contiguous( + head_dim: int, seq_len: int, num_heads: int, scale_val: float +) -> None: + """Test quantization of contiguous 3D tensors.""" + torch.manual_seed(42) + tensor = torch.randn( + seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16 + ) + scale = torch.tensor([scale_val], dtype=torch.float32, device="cuda").view( + 1, 1, 1, 1 + ) + + result = quantize_fp8_pad_head_dim_triton(tensor, scale) + + padded_dim = (head_dim + 15) // 16 * 16 + assert result.shape == (seq_len, num_heads, padded_dim) + assert result.is_contiguous() + assert result.dtype == current_platform.fp8_dtype() + + # Compare unpadded portion against reference + ref = _naive_fp8_quantize(tensor, scale, skip_scale=False) + torch.testing.assert_close(result[:, :, :head_dim].float(), ref.float()) + + # Padded region should be zero + if padded_dim > head_dim: + assert (result[:, :, head_dim:].float() == 0).all() + + +@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available") +@pytest.mark.parametrize("head_dim", [72, 80]) +def test_quantize_non_contiguous(head_dim: int) -> None: + """Test quantization from non-contiguous QKV views (interleaved buffer).""" + seq_len, num_heads = 64, 16 + # Simulate interleaved QKV buffer: shape (seq_len, 3 * num_heads, head_dim) + qkv = torch.randn( + seq_len, 3 * num_heads, head_dim, device="cuda", dtype=torch.bfloat16 + ) + # Q is every 3rd head slice - non-contiguous view + q = qkv[:, 0::3, :] + assert not q.is_contiguous() + + scale = torch.tensor([0.1], dtype=torch.float32, device="cuda").view(1, 1, 1, 1) + result = quantize_fp8_pad_head_dim_triton(q, scale) + + padded_dim = (head_dim + 15) 
// 16 * 16 + assert result.shape == (seq_len, num_heads, padded_dim) + assert result.is_contiguous() + + # Compare against contiguous reference + ref = _naive_fp8_quantize(q.contiguous(), scale, skip_scale=False) + torch.testing.assert_close(result[:, :, :head_dim].float(), ref.float()) + + +@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available") +def test_skip_scale() -> None: + """Test skip_scale=True produces cast-only output (no division).""" + seq_len, num_heads, head_dim = 32, 8, 80 + tensor = torch.randn( + seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16 + ) + scale = torch.tensor([0.5], dtype=torch.float32, device="cuda").view(1, 1, 1, 1) + + result_skip = quantize_fp8_pad_head_dim_triton(tensor, scale, skip_scale=True) + result_noskip = quantize_fp8_pad_head_dim_triton(tensor, scale, skip_scale=False) + + # skip_scale should just cast, not divide + ref_cast = _naive_fp8_quantize(tensor, scale, skip_scale=True) + torch.testing.assert_close(result_skip[:, :, :head_dim].float(), ref_cast.float()) + + # With scale != 1.0, skip and no-skip should differ + assert not torch.equal(result_skip.float(), result_noskip.float()) + + +@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available") +def test_4d_input() -> None: + """Test that 4D input (B, S, H, D) is handled correctly.""" + B, S, H, D = 2, 32, 8, 72 + tensor = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) + scale = torch.tensor([0.1], dtype=torch.float32, device="cuda").view(1, 1, 1, 1) + + result = quantize_fp8_pad_head_dim_triton(tensor, scale) + padded_dim = (D + 15) // 16 * 16 + assert result.shape == (B, S, H, padded_dim) diff --git a/tests/kernels/core/test_vit_fp8_scaling.py b/tests/kernels/core/test_vit_fp8_scaling.py new file mode 100644 index 000000000..19b9e57db --- /dev/null +++ b/tests/kernels/core/test_vit_fp8_scaling.py @@ -0,0 +1,257 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM 
project +"""Tests for FP8 scaling (dynamic and static) in MMEncoderAttention.""" + +import contextlib +import json +from types import SimpleNamespace +from unittest.mock import patch + +import pytest +import torch + +try: + from vllm.model_executor.layers.attention.mm_encoder_attention import ( + _FP8_AMAX_HISTORY_LEN, + _FP8_MAX, + ) +except ImportError: + pytest.skip("mm_encoder_attention fp8 constants not available", allow_module_level=True) +try: + from vllm.utils.flashinfer import ( + is_flashinfer_cudnn_fp8_prefill_attn_supported, + ) +except ImportError: + pytest.skip("flashinfer cudnn fp8 not available", allow_module_level=True) + +LAYER_0 = "visual.blocks.0.attn.attn" +LAYER_1 = "visual.blocks.1.attn.attn" +NUM_HEADS = 16 +HEAD_DIM = 72 + + +@contextlib.contextmanager +def _build_attention(mm_config): + """Yield an MMEncoderAttention with the given multimodal config. + + The VllmConfig context stays active while the test runs so that + ``get_multimodal_config()`` calls during the forward path resolve. Also + invokes ``process_weights_after_loading`` to simulate the model loader's + auto-scan. Yields ``None`` if FlashInfer cuDNN is not available. 
+ """ + from vllm.config import VllmConfig, set_current_vllm_config + from vllm.model_executor.layers.attention.mm_encoder_attention import ( + MMEncoderAttention, + ) + from vllm.v1.attention.backends.registry import AttentionBackendEnum + + if not is_flashinfer_cudnn_fp8_prefill_attn_supported(): + yield None + return + + vllm_config = VllmConfig() + vllm_config.model_config = SimpleNamespace(multimodal_config=mm_config) + + with ( + set_current_vllm_config(vllm_config), + patch( + "vllm.model_executor.layers.attention.mm_encoder_attention" + ".get_vit_attn_backend", + return_value=AttentionBackendEnum.FLASHINFER, + ), + ): + attn = MMEncoderAttention( + num_heads=NUM_HEADS, + head_size=HEAD_DIM, + prefix=LAYER_0, + ) + attn.process_weights_after_loading(torch.bfloat16) + yield attn + + +@pytest.fixture +def _make_attention(): + """Create an MMEncoderAttention with dynamic FP8 scaling.""" + from vllm.config.multimodal import MultiModalConfig + + with _build_attention(MultiModalConfig(mm_encoder_attn_dtype="fp8")) as attn: + yield attn + + +@pytest.fixture +def _make_static_attention(tmp_path): + """Create an MMEncoderAttention with static FP8 scales from a file.""" + from vllm.config.multimodal import MultiModalConfig + + scale_file = tmp_path / "scales.json" + scale_file.write_text( + json.dumps( + { + LAYER_0: {"q": 224.0, "k": 198.0, "v": 210.0}, + LAYER_1: {"q": 100.0, "k": 110.0, "v": 120.0}, + } + ) + ) + with _build_attention( + MultiModalConfig( + mm_encoder_attn_dtype="fp8", + mm_encoder_fp8_scale_path=str(scale_file), + ) + ) as attn: + yield attn + + +def test_dynamic_scaling_updates_scales(_make_attention) -> None: + """Verify that _record_amax_and_update_scales updates scale buffers.""" + attn = _make_attention + if attn is None or not attn.fp8_enabled: + pytest.skip("FP8 attention not available (FlashInfer backend required)") + + attn = attn.to("cuda") + + S, H, D = 32, NUM_HEADS, HEAD_DIM + q = torch.full((S, H, D), 2.0, device="cuda", 
dtype=torch.bfloat16) + k = torch.full((S, H, D), 3.0, device="cuda", dtype=torch.bfloat16) + v = torch.full((S, H, D), 4.0, device="cuda", dtype=torch.bfloat16) + + attn._record_amax_and_update_scales(q, k, v) + + expected_q_scale = 2.0 / _FP8_MAX + expected_k_scale = 3.0 / _FP8_MAX + expected_v_scale = 4.0 / _FP8_MAX + + torch.testing.assert_close(attn._fp8_q_scale.item(), expected_q_scale) + torch.testing.assert_close(attn._fp8_k_scale.item(), expected_k_scale) + torch.testing.assert_close(attn._fp8_v_scale.item(), expected_v_scale) + + +def test_circular_buffer_wraps(_make_attention) -> None: + """Verify the amax circular buffer wraps at HISTORY_LEN.""" + attn = _make_attention + if attn is None or not attn.fp8_enabled: + pytest.skip("FP8 attention not available (FlashInfer backend required)") + + attn = attn.to("cuda") + S, H, D = 16, NUM_HEADS, HEAD_DIM + + for i in range(_FP8_AMAX_HISTORY_LEN + 2): + mag = float(i + 1) + q = torch.full((S, H, D), mag, device="cuda", dtype=torch.bfloat16) + k = torch.full((S, H, D), mag, device="cuda", dtype=torch.bfloat16) + v = torch.full((S, H, D), mag, device="cuda", dtype=torch.bfloat16) + attn._record_amax_and_update_scales(q, k, v) + + assert attn._fp8_amax_pos == 2 + + expected_max = float(_FP8_AMAX_HISTORY_LEN + 2) + expected_scale = expected_max / _FP8_MAX + torch.testing.assert_close(attn._fp8_q_scale.item(), expected_scale) + + +def test_static_scales_loaded(_make_static_attention) -> None: + """Verify static scales are loaded from the JSON file.""" + attn = _make_static_attention + if attn is None or not attn.fp8_enabled: + pytest.skip("FP8 attention not available (FlashInfer backend required)") + + assert attn.fp8_enabled + assert not attn._fp8_dynamic_scale + + # Layer 0 scales (the layer this attention was created with). 
+ assert attn._fp8_q_scale.item() == 224.0 + assert attn._fp8_k_scale.item() == 198.0 + assert attn._fp8_v_scale.item() == 210.0 + + assert not attn.skip_scale_q + assert not attn.skip_scale_k + assert not attn.skip_scale_v + + # No amax history buffers for static scaling. + assert not hasattr(attn, "_fp8_q_amax") + + +def test_static_scales_missing_layer(tmp_path) -> None: + """Verify error when requested layer is not in the scale file.""" + from vllm.config import VllmConfig, set_current_vllm_config + from vllm.config.multimodal import MultiModalConfig + from vllm.v1.attention.backends.registry import AttentionBackendEnum + + if not is_flashinfer_cudnn_fp8_prefill_attn_supported(): + pytest.skip("FlashInfer cuDNN not available") + + scale_file = tmp_path / "wrong_layer.json" + scale_file.write_text( + json.dumps({"visual.blocks.99.attn": {"q": 1.0, "k": 1.0, "v": 1.0}}) + ) + mm_config = MultiModalConfig( + mm_encoder_attn_dtype="fp8", + mm_encoder_fp8_scale_path=str(scale_file), + ) + vllm_config = VllmConfig() + vllm_config.model_config = SimpleNamespace(multimodal_config=mm_config) + + from vllm.model_executor.layers.attention.mm_encoder_attention import ( + MMEncoderAttention, + ) + + with ( + set_current_vllm_config(vllm_config), + patch( + "vllm.model_executor.layers.attention.mm_encoder_attention" + ".get_vit_attn_backend", + return_value=AttentionBackendEnum.FLASHINFER, + ), + ): + attn = MMEncoderAttention( + num_heads=NUM_HEADS, + head_size=HEAD_DIM, + prefix=LAYER_0, + ) + with pytest.raises(ValueError, match="scales not found for layer"): + attn.process_weights_after_loading(torch.bfloat16) + + +def test_dynamic_scales_auto_save(tmp_path) -> None: + """Verify scales are saved to disk after the amax buffer fills.""" + import vllm.model_executor.layers.attention.mm_encoder_attention as _mod + from vllm.config.multimodal import MultiModalConfig + + if not is_flashinfer_cudnn_fp8_prefill_attn_supported(): + pytest.skip("FlashInfer cuDNN not available") + 
+ # Reset module-level state between runs (other tests may have left + # state behind after triggering a save). + _mod._fp8_scale_save_path = None + _mod._fp8_saved_scale_refs.clear() + + save_file = tmp_path / "auto_scales.json" + with _build_attention( + MultiModalConfig( + mm_encoder_attn_dtype="fp8", + mm_encoder_fp8_scale_save_path=str(save_file), + ) + ) as attn: + if attn is None or not attn.fp8_enabled: + pytest.skip("FP8 attention not available") + + attn = attn.to("cuda") + S, H, D = 16, NUM_HEADS, HEAD_DIM + + # Run exactly _FP8_AMAX_HISTORY_LEN forward passes. + for i in range(_FP8_AMAX_HISTORY_LEN): + mag = float(i + 1) + q = torch.full((S, H, D), mag, device="cuda", dtype=torch.bfloat16) + k = torch.full((S, H, D), mag * 0.5, device="cuda", dtype=torch.bfloat16) + v = torch.full((S, H, D), mag * 0.3, device="cuda", dtype=torch.bfloat16) + attn._record_amax_and_update_scales(q, k, v) + + # File should have been written on the 16th call (buffer wrap). + assert save_file.is_file(), "Scale file was not saved" + scales = json.loads(save_file.read_text()) + assert LAYER_0 in scales + assert set(scales[LAYER_0].keys()) == {"q", "k", "v"} + for val in scales[LAYER_0].values(): + assert isinstance(val, float) and val > 0 + + # Path is cleared after the one-shot save fires. + assert _mod._fp8_scale_save_path is None diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index e0c3947c7..e254e04c6 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( group_broadcast) from vllm.platforms import current_platform -from vllm.utils import round_up +from vllm.utils.math_utils import round_up # Using the default value (240.0) from pytorch will cause accuracy # issue on dynamic quantization models. Here use 224.0 for rocm. 
diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 01e1b9f2a..1ad4bf24f 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -15,13 +15,25 @@ from torch._prims_common import TensorLikeType from tests.kernels.quant_utils import native_w8a8_block_matmul -from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType +try: + from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType +except ImportError: + from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata, AttentionType from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) -from vllm.platforms.interface import _Backend -from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, - STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) +try: + from vllm.platforms.interface import _Backend +except ImportError: + from typing import Any as _Backend +try: + from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, + STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) +except ImportError: + STR_BACKEND_ENV_VAR = "VLLM_ATTENTION_BACKEND" + STR_FLASH_ATTN_VAL = "FLASH_ATTN" + STR_XFORMERS_ATTN_VAL = "XFORMERS" + from vllm.utils.torch_utils import make_tensor_with_pad # For now, disable "test_aot_dispatch_dynamic" since there are some # bugs related to this test in PyTorch 2.4. 
diff --git a/tests/models/registry.py b/tests/models/registry.py index 4884fdb3a..c446e0fbb 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -11,7 +11,14 @@ from packaging.version import Version from transformers import __version__ as TRANSFORMERS_VERSION -from vllm.config import ModelDType, TokenizerMode +try: + from vllm.config import TokenizerMode +except ImportError: + from vllm.config.model import TokenizerMode +try: + from vllm.config import ModelDType +except ImportError: + from vllm.config.model import ModelDType @dataclass(frozen=True) diff --git a/tests/models/utils.py b/tests/models/utils.py index 9d32dfaf4..c608a38ed 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -11,9 +11,19 @@ import torch.nn.functional as F from transformers import PretrainedConfig -from vllm.config import ModelConfig, ModelDType, RunnerOption -from vllm.inputs import InputContext -from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs +from vllm.config import ModelConfig +try: + from vllm.config import ModelDType, RunnerOption +except ImportError: + from vllm.config.model import ModelDType, RunnerOption +try: + from vllm.inputs import InputContext +except ImportError: + class InputContext:  # fallback stub, defined only when importing InputContext from vllm.inputs raises ImportError + def __init__(self, model_config):  # keeps only the model_config it is given; no other attributes or methods + self.model_config = model_config +from vllm.logprobs import Logprob +from vllm.logprobs import PromptLogprobs, SampleLogprobs from .registry import HF_EXAMPLE_MODELS diff --git a/tests/utils.py b/tests/utils.py index d371083fc..27cba01d9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -40,8 +40,13 @@ from vllm.model_executor.model_loader import get_model_loader from vllm.platforms import current_platform from vllm.transformers_utils.tokenizer import get_tokenizer -from vllm.utils import (FlexibleArgumentParser, GB_bytes, - cuda_device_count_stateless, get_open_port) +try: + from vllm.utils import FlexibleArgumentParser, GB_bytes, cuda_device_count_stateless, get_open_port +except ImportError: + from
vllm.utils.argparse_utils import FlexibleArgumentParser + from vllm.utils.mem_constants import GB_bytes + from vllm.utils.torch_utils import cuda_device_count_stateless + from vllm.utils.network_utils import get_open_port if current_platform.is_rocm(): from amdsmi import (amdsmi_get_gpu_vram_usage, @@ -1151,3 +1156,14 @@ def override_cutlass_fp8_supported(value: bool): "vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported", return_value=value): yield + + +# Compatibility shim for tests imported from newer upstream vLLM; it performs no configuration itself +class ensure_current_vllm_config: + """No-op stand-in: an instance works as a ``with`` context manager or as a decorator, and does nothing in either role.""" + def __call__(self, func):  # decorator role: hand back func itself, unwrapped + return func + def __enter__(self):  # nothing to set up; the ``as`` target is None + pass + def __exit__(self, *args):  # returns None, so exceptions raised in the with-body propagate + pass