Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import sys
sys.path.insert(0, '/tmp')
import patch_infer_schema
# SPDX-License-Identifier: Apache-2.0
# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
Expand Down Expand Up @@ -40,7 +43,10 @@
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import ConvertOption, RunnerOption, _get_and_verify_dtype
try:
from vllm.config import ConvertOption, RunnerOption, _get_and_verify_dtype
except ImportError:
from vllm.config.model import ConvertOption, RunnerOption, _get_and_verify_dtype
from vllm.connections import global_http_connection
from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
Expand All @@ -51,7 +57,7 @@
from vllm.multimodal.utils import fetch_image
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.sequence import Logprob
from vllm.logprobs import Logprob
from vllm.transformers_utils.utils import maybe_model_redirect

logger = init_logger(__name__)
Expand Down Expand Up @@ -1386,3 +1392,15 @@ def image_urls(request, local_asset_server) -> list[str]:
"""Indirect fixture: takes a list of names, returns list of full URLs."""
names: list[str] = request.param
return [local_asset_server.url_for(name) for name in names]


@pytest.fixture
def default_vllm_config():
"""Set a default VllmConfig for tests that directly test CustomOps or pathways
that use get_current_vllm_config() outside of a full engine context.
"""
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.config.device import DeviceConfig

with set_current_vllm_config(VllmConfig(device_config=DeviceConfig("cuda"))):
yield
203 changes: 203 additions & 0 deletions tests/kernels/core/test_apply_rotary_emb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for ApplyRotaryEmb CustomOp dispatch behavior.

This test ensures that RotaryEmbedding classes correctly call the appropriate
ApplyRotaryEmb methods based on the calling context:

1. RotaryEmbedding.forward_native() -> ApplyRotaryEmb.forward_native()
2. RotaryEmbedding.forward_cuda() -> ApplyRotaryEmb.forward() (auto-dispatch)
3. RotaryEmbedding.forward_hip() -> ApplyRotaryEmb.forward() (auto-dispatch)
"""

from dataclasses import dataclass

import pytest
import torch

from vllm.config import (
CompilationConfig,
VllmConfig,
get_cached_compilation_config,
set_current_vllm_config,
)
from vllm.platforms import current_platform

CUDA_DEVICES = ["cuda:0"]


@dataclass
class RotaryEmbeddingTestCase:
"""Test case configuration for RotaryEmbedding dispatch tests."""

name: str
rope_class: type
rope_kwargs: dict
method_name: str # forward_native, forward_cuda, forward
positions_shape: tuple # (num_tokens,) or (3, num_tokens) or (4, num_tokens)
expect_forward_native: bool # Should call ApplyRotaryEmb.forward_native()
expect_forward: bool # Should call ApplyRotaryEmb.forward()


def get_test_cases() -> list[RotaryEmbeddingTestCase]:
"""Generate test cases for all RotaryEmbedding classes."""
from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import (
Ernie4_5_VLRotaryEmbedding,
)
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding
from vllm.model_executor.layers.rotary_embedding.xdrope import XDRotaryEmbedding

common_kwargs = {
"head_size": 128,
"rotary_dim": 128,
"max_position_embeddings": 4096,
"base": 10000,
"is_neox_style": True,
"dtype": torch.bfloat16,
}

return [
# MRotaryEmbedding tests
RotaryEmbeddingTestCase(
name="MRotaryEmbedding.forward_native",
rope_class=MRotaryEmbedding,
rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]},
method_name="forward_native",
positions_shape=(3, 32), # 2D for multimodal
expect_forward_native=True,
expect_forward=False,
),
RotaryEmbeddingTestCase(
name="MRotaryEmbedding.forward_cuda_1d",
rope_class=MRotaryEmbedding,
rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]},
method_name="forward_cuda",
positions_shape=(32,), # 1D triggers apply_rotary_emb path
expect_forward_native=False,
expect_forward=True,
),
# XDRotaryEmbedding tests
RotaryEmbeddingTestCase(
name="XDRotaryEmbedding.forward",
rope_class=XDRotaryEmbedding,
rope_kwargs={
**common_kwargs,
"scaling_alpha": 1.0,
"xdrope_section": [16, 16, 16, 16],
},
method_name="forward",
positions_shape=(4, 32), # 4D for P/W/H/T
expect_forward_native=False,
expect_forward=True,
),
# Ernie4_5_VLRotaryEmbedding tests
RotaryEmbeddingTestCase(
name="Ernie4_5_VLRotaryEmbedding.forward_native",
rope_class=Ernie4_5_VLRotaryEmbedding,
rope_kwargs={**common_kwargs, "mrope_section": [22, 22, 20]},
method_name="forward_native",
positions_shape=(3, 32), # 2D for multimodal
expect_forward_native=True,
expect_forward=False,
),
]


def run_dispatch_test(
test_case: RotaryEmbeddingTestCase,
device: str,
):
"""Run a dispatch test for a RotaryEmbedding class."""
vllm_config = VllmConfig(
compilation_config=CompilationConfig(custom_ops=["all", "+apply_rotary_emb"])
)
get_cached_compilation_config.cache_clear()

with set_current_vllm_config(vllm_config):
rope = test_case.rope_class(**test_case.rope_kwargs).to(device=device)

apply_rotary_emb = rope.apply_rotary_emb

# Verify custom op is enabled
if test_case.expect_forward_native:
assert (
apply_rotary_emb._forward_method != apply_rotary_emb.forward_native
), "Test setup error: ApplyRotaryEmb custom op should be enabled"

# Setup call tracking
call_tracker = {"forward_native_called": False, "forward_called": False}
original_forward_native = apply_rotary_emb.forward_native
original_forward = apply_rotary_emb.forward

def tracked_forward_native(*args, **kwargs):
call_tracker["forward_native_called"] = True
return original_forward_native(*args, **kwargs)

def tracked_forward(*args, **kwargs):
call_tracker["forward_called"] = True
return original_forward(*args, **kwargs)

apply_rotary_emb.forward_native = tracked_forward_native
apply_rotary_emb.forward = tracked_forward

try:
num_tokens = test_case.positions_shape[-1]
num_q_heads = 8
num_kv_heads = 2
head_size = test_case.rope_kwargs["head_size"]
max_position = test_case.rope_kwargs["max_position_embeddings"]

positions = torch.randint(
0, max_position // 4, test_case.positions_shape, device=device
)
query = torch.randn(
num_tokens, num_q_heads * head_size, dtype=torch.bfloat16, device=device
)
key = torch.randn(
num_tokens,
num_kv_heads * head_size,
dtype=torch.bfloat16,
device=device,
)

# Call the method under test
method = getattr(rope, test_case.method_name)
method(positions, query.clone(), key.clone())

# Verify expectations
if test_case.expect_forward_native:
assert call_tracker["forward_native_called"], (
f"{test_case.name} should call ApplyRotaryEmb.forward_native()"
)
if not test_case.expect_forward:
assert not call_tracker["forward_called"], (
f"{test_case.name} should NOT call ApplyRotaryEmb.forward(). "
"Bug: when +apply_rotary_emb is enabled, forward_native() "
"incorrectly dispatches to CUDA/HIP kernels."
)
if test_case.expect_forward:
assert call_tracker["forward_called"], (
f"{test_case.name} should call ApplyRotaryEmb.forward()"
)
finally:
apply_rotary_emb.forward_native = original_forward_native
apply_rotary_emb.forward = original_forward


@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests."
)
@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda tc: tc.name)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_rotary_embedding_dispatch(
test_case: RotaryEmbeddingTestCase,
device: str,
):
"""
Test that RotaryEmbedding classes dispatch to the correct ApplyRotaryEmb method.

- forward_native methods should call ApplyRotaryEmb.forward_native()
- forward_cuda/forward methods should call ApplyRotaryEmb.forward()
"""
run_dispatch_test(test_case, device)
111 changes: 111 additions & 0 deletions tests/kernels/core/test_cpu_activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from tests.kernels.utils import opcheck
from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils.torch_utils import set_random_seed

if not current_platform.is_cpu():
pytest.skip("skipping CPU-only tests", allow_module_level=True)

from vllm.model_executor.layers.activation import (
GELU,
FastGELU,
GeluAndMul,
NewGELU,
QuickGELU,
SiluAndMul,
)

DTYPES = [torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83]
D = [512, 2048]
SEEDS = [0]


@pytest.mark.parametrize(
("activation_cls", "fn"),
[
(SiluAndMul, torch.ops._C.silu_and_mul),
(GeluAndMul, torch.ops._C.gelu_and_mul),
(GeluAndMul, torch.ops._C.gelu_tanh_and_mul),
],
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_cpu_act_and_mul(
default_vllm_config,
activation_cls: type[torch.nn.Module],
fn: object,
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
) -> None:
set_random_seed(seed)
x = torch.randn(num_tokens, 2 * d, dtype=dtype)

layer = activation_cls()
out = layer(x)
ref_out = layer.forward_native(x)

torch.testing.assert_close(
out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out)
)

output_shape = x.shape[:-1] + (x.shape[-1] // 2,)
raw_out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
opcheck(fn, (raw_out, x))


@pytest.mark.parametrize(
("activation_cls", "fn", "op_args"),
[
(NewGELU, torch.ops._C.gelu_new, ()),
(FastGELU, torch.ops._C.gelu_fast, ()),
(QuickGELU, torch.ops._C.gelu_quick, ()),
pytest.param(
GELU,
getattr(torch.ops._C, "activation_lut_bf16", None),
("gelu",),
marks=pytest.mark.skipif(
current_platform.get_cpu_architecture() != CpuArchEnum.ARM,
reason="activation_lut_bf16 is only built on Arm CPU",
),
),
],
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_cpu_unary_activation(
default_vllm_config,
activation_cls: type[torch.nn.Module],
fn: object,
op_args: tuple[str, ...],
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
) -> None:
set_random_seed(seed)
x = torch.randn(num_tokens, d, dtype=dtype)
layer = activation_cls()
out = layer(x)
ref_out = layer.forward_native(x)
torch.testing.assert_close(
out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out)
)
# gelu with activation_lut_bf16 only makes sense for BF16
if not (activation_cls is GELU and dtype != torch.bfloat16):
raw_out = torch.empty_like(x)
opcheck(fn, (raw_out, x, *op_args))
Loading