Skip to content

test: validate fused silu-and-mul fp8 quantization#264

Open
sxvvv wants to merge 4 commits into
MetaX-MACA:masterfrom
sxvvv:fix/fused-quant-activation-test-github
Open

test: validate fused silu-and-mul fp8 quantization#264
sxvvv wants to merge 4 commits into
MetaX-MACA:masterfrom
sxvvv:fix/fused-quant-activation-test-github

Conversation

@sxvvv

@sxvvv sxvvv commented May 31, 2026

Copy link
Copy Markdown

Summary

This PR improves kernel-test coverage and reproducibility for MetaX fused activation/quantization paths.

What changed

  • Adds/repairs coverage for the fused SiLU + multiply + fp8 quantization kernel.
  • Wraps SiluAndMul construction in a default VllmConfig context required by the current vLLM custom-op runtime.
  • Uses a fixed fp8 tolerance (atol=1/128, rtol=0) instead of strict float equality after fp8 output is converted to fp32.
  • Keeps opcheck coverage for torch.ops._C.silu_and_mul_quant.
  • Adds small test utility compatibility shims for helper APIs that moved across vLLM versions.
  • Implements MacaPlatform.seed_everything() so kernel tests can seed Python, NumPy, Torch, and CUDA consistently through the platform API.

Why this matters

Several kernel tests rely on current_platform.seed_everything(seed) before generating randomized inputs. On MetaX, that method was missing, so tests such as fp8 quantization and activation kernels failed before reaching the actual kernel assertions. Implementing it makes these tests deterministic and allows them to validate real kernel behavior on hardware.

The fused silu_and_mul_quant test then covers the combined SiLU activation, multiply, and fp8 quantization path across fp16/bf16 inputs and irregular shapes that are more likely to expose layout or rounding bugs.

Hardware validation

Environment:

  • GPU: MetaX C500
  • MACA: 3.5.3.20
  • torch: 2.8.0+metax3.5.3.9
  • vLLM: 0.17.0
  • fp8 dtype: torch.float8_e4m3fn

Validation commands run on the server:

python -m pytest -q tests/kernels/test_fused_quant_activation.py --confcutdir=tests/kernels -x -s
# 50 passed, 3 warnings in 17.89s

python -m pytest -q tests/kernels/quantization/test_fp8_quant.py --confcutdir=tests/kernels -x -s
# 109 passed, 3 warnings in 21.68s

Before seed_everything was added, test_fp8_quant.py failed immediately with:

TypeError: 'NoneType' object is not callable

at current_platform.seed_everything(seed). After the platform method was added, the fp8 quantization kernel suite completed successfully.

Representative direct fused silu_and_mul_quant cases also matched the native SiluAndMul.forward_native + ops.scaled_fp8_quant reference path with max_diff=0.0 for fp16 and bf16 irregular shapes including (tokens=1234, hidden=1562).

Additional checks

  • python -m py_compile vllm_metax/platform.py tests/kernels/test_fused_quant_activation.py tests/kernels/quant_utils.py tests/kernels/utils.py
  • git diff --check

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors kernel tests to improve robustness and reduce external dependencies by defining local helper functions (like round_up and make_tensor_with_pad) and wrapping activation tests with a mock VLLM configuration. The review feedback highlights several critical improvements: ensuring quantization scales are strictly positive to avoid division-by-zero, increasing the absolute tolerance (atol) for FP8 comparisons to prevent flaky tests, catching ImportError instead of ModuleNotFoundError for safer imports, and truncating input lists in make_tensor_with_pad to avoid potential runtime crashes.

scale = (torch.randn((1), device=device, dtype=torch.float32))
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
# Make inputs
scale = (torch.randn((1), device=device, dtype=torch.float32))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The quantization scale is generated using torch.randn, which can produce negative or zero values. Quantization scales must be strictly positive. A negative scale will incorrectly flip the signs of the quantized values, and a scale close to zero can cause division-by-zero or overflow issues. Consider using the absolute value of the random tensor and adding a small epsilon to ensure a valid, positive scale.

Suggested change
scale = (torch.randn((1), device=device, dtype=torch.float32))
scale = torch.randn((1), device=device, dtype=torch.float32).abs() + 1e-5

Comment on lines +72 to +75
assert torch.allclose(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32),
atol=1 / 128,
rtol=0)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using a fixed atol=1 / 128 with rtol=0 for comparing FP8 quantized outputs can lead to flaky tests. Since FP8 is a highly discrete format, any tiny rounding difference in the intermediate float calculations between the fused kernel and the reference implementation will cause them to round to different FP8 buckets. The difference between adjacent FP8 values (1 ULP) for values around 1.0 is 0.125 (for E4M3), which is much larger than 1/128 (0.0078125). If even a single element rounds differently, the test will fail. Consider using a larger tolerance or allowing a small percentage of mismatched elements to prevent flaky CI runs.

Suggested change
assert torch.allclose(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32),
atol=1 / 128,
rtol=0)
assert torch.allclose(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32),
atol=0.125,
rtol=0)

Comment thread tests/kernels/utils.py Outdated
Comment on lines +18 to +20
try:
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
except ModuleNotFoundError:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Catching only ModuleNotFoundError when importing from vllm.attention is less robust than catching ImportError. If vllm.attention exists but fails to import due to other issues (such as circular dependencies or missing sub-modules), an ImportError will be raised and the test suite will crash. Using except ImportError is more robust and consistent with the other import blocks in this file.

Suggested change
try:
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
except ModuleNotFoundError:
try:
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
except ImportError:

Comment thread tests/kernels/utils.py Outdated
dtype: torch.dtype,
device: Union[torch.device, str],
) -> torch.Tensor:
padded = [item + [pad] * (max_len - len(item)) for item in x]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If any list in x has a length greater than max_len, max_len - len(item) will be negative, resulting in [pad] * negative_number which evaluates to []. This means the item will not be truncated, and the resulting padded list will contain sublists of different lengths, causing torch.tensor(padded) to crash with a ValueError. To make this helper robust, truncate the items to max_len before padding.

Suggested change
padded = [item + [pad] * (max_len - len(item)) for item in x]
padded = [item[:max_len] + [pad] * (max_len - len(item)) for item in x]

@sxvvv sxvvv changed the title test: fix fused quant activation coverage test: cover fused silu-and-mul fp8 quantization May 31, 2026
@sxvvv sxvvv changed the title test: cover fused silu-and-mul fp8 quantization test: validate fused silu-and-mul fp8 quantization May 31, 2026
@sxvvv sxvvv force-pushed the fix/fused-quant-activation-test-github branch from 18b1e8a to 9b3a170 Compare May 31, 2026 04:22
sxvvv added 2 commits June 6, 2026 09:08
Incorporates Gemini Code Assist review on PR MetaX-MACA#264:
- Ensure the quantization scale is strictly positive (abs + epsilon).
- Use rtol=0.125 (1 ULP for e4m3) with a small atol for the fp8 compare
  instead of an absolute-only tolerance that was flaky.
- Catch ImportError (not just ModuleNotFoundError) for vllm.attention.
- Truncate over-length items in make_tensor_with_pad before padding.
Incorporates Gemini Code Assist review on PR MetaX-MACA#264:
- Ensure the quantization scale is strictly positive (abs + epsilon).
- Use rtol=0.125 (1 ULP for e4m3) with a small atol for the fp8 compare
  instead of an absolute-only tolerance that was flaky.
- Catch ImportError (not just ModuleNotFoundError) for vllm.attention.
- Truncate over-length items in make_tensor_with_pad before padding.
@sxvvv

sxvvv commented Jun 6, 2026

Copy link
Copy Markdown
Author

按 review 改了几处:scale 取绝对值加 1e-5 避免负数或接近 0;fp8 比较改用 rtol=0.125(约 e4m3 的 1 个 ULP)加一个小 atol;import 的 ModuleNotFoundError 换成 ImportError;make_tensor_with_pad 加了截断,防止 item 超长时报错。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant