test: validate fused silu-and-mul fp8 quantization #264
Code Review
This pull request refactors kernel tests to improve robustness and reduce external dependencies by defining local helper functions (like round_up and make_tensor_with_pad) and wrapping activation tests with a mock VLLM configuration. The review feedback highlights several critical improvements: ensuring quantization scales are strictly positive to avoid division-by-zero, increasing the absolute tolerance (atol) for FP8 comparisons to prevent flaky tests, catching ImportError instead of ModuleNotFoundError for safer imports, and truncating input lists in make_tensor_with_pad to avoid potential runtime crashes.
scale = (torch.randn((1), device=device, dtype=torch.float32))
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
# Make inputs
scale = (torch.randn((1), device=device, dtype=torch.float32))
The quantization scale is generated using torch.randn, which can produce negative or zero values. Quantization scales must be strictly positive. A negative scale will incorrectly flip the signs of the quantized values, and a scale close to zero can cause division-by-zero or overflow issues. Consider using the absolute value of the random tensor and adding a small epsilon to ensure a valid, positive scale.
- scale = (torch.randn((1), device=device, dtype=torch.float32))
+ scale = torch.randn((1), device=device, dtype=torch.float32).abs() + 1e-5
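The effect of the suggested change can be illustrated without a GPU. The sketch below uses plain Python floats in place of the torch tensor and assumes the quantizer divides the input by the scale before rounding; `make_positive_scale` and `quantize_ratio` are hypothetical helper names, not part of the test file.

```python
import random


def make_positive_scale(raw: float, eps: float = 1e-5) -> float:
    """Map any raw sample to a strictly positive scale (abs + epsilon)."""
    return abs(raw) + eps


def quantize_ratio(x: float, scale: float) -> float:
    """Pre-rounding step of scaled quantization (assumed to be x / scale)."""
    return x / scale


# Normally distributed samples stand in for torch.randn; about half are negative.
rng = random.Random(0)
raw_samples = [rng.gauss(0.0, 1.0) for _ in range(1000)]
scales = [make_positive_scale(r) for r in raw_samples]

# A negative scale flips the sign of the scaled value; the repaired one does not.
flipped = quantize_ratio(2.0, -0.5)
kept = quantize_ratio(2.0, make_positive_scale(-0.5))
```

Every entry of `scales` is at least `eps`, so the division can no longer change sign or divide by zero.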
assert torch.allclose(ref_out.to(dtype=torch.float32),
                      ops_out.to(dtype=torch.float32),
                      atol=1 / 128,
                      rtol=0)
Using a fixed atol=1 / 128 with rtol=0 for comparing FP8 quantized outputs can lead to flaky tests. Since FP8 is a highly discrete format, any tiny rounding difference in the intermediate float calculations between the fused kernel and the reference implementation will cause them to round to different FP8 buckets. The difference between adjacent FP8 values (1 ULP) for values around 1.0 is 0.125 (for E4M3), which is much larger than 1/128 (0.0078125). If even a single element rounds differently, the test will fail. Consider using a larger tolerance or allowing a small percentage of mismatched elements to prevent flaky CI runs.
- assert torch.allclose(ref_out.to(dtype=torch.float32),
-                       ops_out.to(dtype=torch.float32),
-                       atol=1 / 128,
-                       rtol=0)
+ assert torch.allclose(ref_out.to(dtype=torch.float32),
+                       ops_out.to(dtype=torch.float32),
+                       atol=0.125,
+                       rtol=0)
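The ULP figure quoted in the comment can be checked with a few lines of arithmetic. This is a minimal sketch, assuming the E4M3 layout used by torch.float8_e4m3fn (3 mantissa bits, exponent bias 7, smallest normal exponent -6); `e4m3_ulp` is a hypothetical helper and it ignores the format's maximum value.

```python
import math

E4M3_MANTISSA_BITS = 3
E4M3_MIN_NORMAL_EXP = -6  # assumption: float8_e4m3fn with exponent bias 7


def e4m3_ulp(x: float) -> float:
    """Spacing between adjacent E4M3 values near |x| (saturation ignored)."""
    x = abs(x)
    if x == 0.0:
        exponent = E4M3_MIN_NORMAL_EXP
    else:
        # math.frexp returns (m, e) with x == m * 2**e and 0.5 <= m < 1,
        # so floor(log2(x)) is e - 1; subnormals share the smallest exponent.
        exponent = max(math.frexp(x)[1] - 1, E4M3_MIN_NORMAL_EXP)
    return 2.0 ** (exponent - E4M3_MANTISSA_BITS)


ulp_at_one = e4m3_ulp(1.0)
ulp_at_300 = e4m3_ulp(300.0)
```

Under these assumptions the spacing grows with magnitude, so a fixed absolute tolerance of 0.125 covers a one-bucket difference only for values below 2.0, while the spacing never exceeds one eighth of the value for normal numbers.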
try:
    from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
except ModuleNotFoundError:
Catching only ModuleNotFoundError when importing from vllm.attention is less robust than catching ImportError. If vllm.attention exists but fails to import due to other issues (such as circular dependencies or missing sub-modules), an ImportError will be raised and the test suite will crash. Using except ImportError is more robust and consistent with the other import blocks in this file.
- try:
-     from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
- except ModuleNotFoundError:
+ try:
+     from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
+ except ImportError:
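The difference between the two exception types can be demonstrated with the standard library alone. In this sketch, `import_failure_type` is a hypothetical helper and both the module name and the imported name are made up so that the imports fail.

```python
def import_failure_type(statement: str) -> str:
    """Run an import statement and report which exception class it raised."""
    try:
        exec(statement, {})
    except ModuleNotFoundError:
        # Raised only when the module itself cannot be located.
        return "ModuleNotFoundError"
    except ImportError:
        # Raised, for example, when the module exists but a name does not.
        return "ImportError"
    return "ok"


missing_module = import_failure_type("import no_such_module_for_demo_xyz")
missing_name = import_failure_type("from math import no_such_name_for_demo")
```

Because ModuleNotFoundError is a subclass of ImportError, `except ImportError` handles both cases, whereas `except ModuleNotFoundError` lets the second one propagate.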
    dtype: torch.dtype,
    device: Union[torch.device, str],
) -> torch.Tensor:
    padded = [item + [pad] * (max_len - len(item)) for item in x]
If any list in x has a length greater than max_len, max_len - len(item) will be negative, resulting in [pad] * negative_number which evaluates to []. This means the item will not be truncated, and the resulting padded list will contain sublists of different lengths, causing torch.tensor(padded) to crash with a ValueError. To make this helper robust, truncate the items to max_len before padding.
- padded = [item + [pad] * (max_len - len(item)) for item in x]
+ padded = [item[:max_len] + [pad] * (max_len - len(item)) for item in x]
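The padding behavior with and without truncation can be compared on plain lists. This sketch leaves out the tensor conversion; `pad_rows` and its `truncate` flag are hypothetical and exist only to show both variants side by side.

```python
from typing import List


def pad_rows(x: List[List[int]], max_len: int, pad: int,
             truncate: bool = True) -> List[List[int]]:
    """Pad each row to max_len; optionally cut over-length rows first."""
    rows = []
    for item in x:
        head = item[:max_len] if truncate else item
        # [pad] * n is the empty list whenever n is zero or negative.
        rows.append(head + [pad] * (max_len - len(item)))
    return rows


data = [[1, 2, 3], [4]]
ragged = pad_rows(data, max_len=2, pad=0, truncate=False)
rectangular = pad_rows(data, max_len=2, pad=0)
```

Only the truncated variant yields rows of equal length, which is what a tensor constructor needs.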
18b1e8a to 9b3a170
Incorporates Gemini Code Assist review on PR MetaX-MACA#264:
- Ensure the quantization scale is strictly positive (abs + epsilon).
- Use rtol=0.125 (1 ULP for e4m3) with a small atol for the fp8 compare instead of an absolute-only tolerance that was flaky.
- Catch ImportError (not just ModuleNotFoundError) for vllm.attention.
- Truncate over-length items in make_tensor_with_pad before padding.
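Made several changes per the review: the scale now takes the absolute value plus 1e-5 to avoid negative or near-zero values; the fp8 comparison now uses rtol=0.125 (roughly 1 ULP for e4m3) plus a small atol; the ModuleNotFoundError in the import guard is replaced with ImportError; and make_tensor_with_pad now truncates, so over-length items no longer raise an error.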
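The combined relative and absolute check follows the documented torch.allclose criterion, |input - other| <= atol + rtol * |other|. The sketch below restates it for scalars; `within_tolerance` is a hypothetical helper, and the atol of 1e-3 is an illustrative value because the reply does not state the one actually used.

```python
def within_tolerance(ref: float, out: float, rtol: float, atol: float) -> bool:
    """Scalar form of the torch.allclose criterion."""
    return abs(ref - out) <= atol + rtol * abs(out)


# 16.0 and 18.0 are adjacent E4M3 values (spacing 2.0 in the range [16, 32)).
old_check = within_tolerance(16.0, 18.0, rtol=0.0, atol=1 / 128)
new_check = within_tolerance(16.0, 18.0, rtol=0.125, atol=1e-3)
```

A one-bucket difference at this magnitude fails the absolute-only check and passes the relative one.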
Summary
This PR improves kernel-test coverage and reproducibility for MetaX fused activation/quantization paths.
What changed
- SiluAndMul construction in a default VllmConfig context required by the current vLLM custom-op runtime.
- Tolerance-based comparison (atol=1/128, rtol=0) instead of strict float equality after fp8 output is converted to fp32.
- opcheck coverage for torch.ops._C.silu_and_mul_quant.
- MacaPlatform.seed_everything() so kernel tests can seed Python, NumPy, Torch, and CUDA consistently through the platform API.
Why this matters
Several kernel tests rely on current_platform.seed_everything(seed) before generating randomized inputs. On MetaX, that method was missing, so tests such as fp8 quantization and activation kernels failed before reaching the actual kernel assertions. Implementing it makes these tests deterministic and allows them to validate real kernel behavior on hardware.
The fused silu_and_mul_quant test then covers the combined SiLU activation, multiply, and fp8 quantization path across fp16/bf16 inputs and irregular shapes that are more likely to expose layout or rounding bugs.
Hardware validation
Environment:
- torch.float8_e4m3fn
Validation commands run on the server:
Before seed_everything was added, test_fp8_quant.py failed immediately with:
at current_platform.seed_everything(seed). After the platform method was added, the fp8 quantization kernel suite completed successfully.
Representative direct fused silu_and_mul_quant cases also matched the native SiluAndMul.forward_native + ops.scaled_fp8_quant reference path with max_diff=0.0 for fp16 and bf16 irregular shapes including (tokens=1234, hidden=1562).
Additional checks
- python -m py_compile vllm_metax/platform.py tests/kernels/test_fused_quant_activation.py tests/kernels/quant_utils.py tests/kernels/utils.py
- git diff --check
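For reference, the seeding behavior described under "Why this matters" could look roughly like the following. This is a minimal sketch and not the actual MacaPlatform.seed_everything implementation: it is written as a free function, and it treats NumPy and Torch as optional so that it also runs where they are not installed.

```python
import random


def seed_everything(seed: int) -> None:
    """Seed Python, and NumPy / Torch / CUDA when they are available."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # assumption: NumPy is optional in this sketch
    try:
        import torch
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass  # assumption: Torch is optional in this sketch


# Re-seeding with the same value reproduces the same random draws.
seed_everything(0)
first = [random.random() for _ in range(3)]
seed_everything(0)
second = [random.random() for _ in range(3)]
```

Calling it once at the top of each test, before any randomized inputs are generated, is what makes repeated runs comparable.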