test: add real-GPU numerical test for mctlass W8A8 scaled_mm#268
Conversation
Adds the first test coverage for the mctlassEx W8A8 scaled_mm op in vllm_metax/model_executor/layers/quantization/_python_api_ops.py, which previously had none. The test runs the real kernel on MetaX hardware and checks output against a PyTorch dequant+matmul reference across several token/hidden shapes and two seeds, plus a scale-linearity check, and documents the kernel's K%16 alignment requirement. Skips cleanly via importorskip when mctlassEx is absent, so it is safe to collect on non-MetaX CI. Validated on MetaX C500 (MACA 3.5.3.20, torch 2.8.0+metax3.5.3.9): 15 passed.
There was a problem hiding this comment.
Code Review
This pull request introduces a new test suite, test_mctlass_w8a8_scaled_mm.py, to verify the numerical correctness of the MetaX mctlass W8A8 scaled_mm operator. The tests cover symmetric per-token and per-channel int8 quantization, ensure that scales are applied correctly, and document the alignment constraints for the operator. There are no review comments, and I have no additional feedback to provide.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
|
给 mctlass 的 W8A8 scaled_mm 补了个真机测试,之前这个算子没有测试覆盖。在 C500 上和 PyTorch 参考实现对了几组 shape,15 个用例都过了;没有 mctlassEx 的环境会自动 skip,不影响现有 CI。 写的时候发现一个点想提一下:K 和 N 都要求是 16 的倍数。K 不对齐会安全报错(rule selector 抛 IndexError),但 N 不对齐不会报错,会直接 device 端 Memory Violation 越界、把 MACA runtime 弄挂。所以测试里没有去触发 N 不对齐的情况(会污染整个进程),调用方需要保证 N%16。要不要在 wrapper 层加个防御性的 shape 检查,你们看着定。 |
Purpose
Add the first test coverage for the MetaX mctlassEx W8A8
scaled_mmop(
mctlassEx_w8a8_scaled_mm_azpinvllm_metax/model_executor/layers/quantization/_python_api_ops.py), whichcurrently has no tests. This op is on the active mctlass quantization path, so
it is worth guarding against silent numerical regressions and layout/scale
mistakes.
The test runs the real kernel on MetaX hardware and compares its output
against a PyTorch dequant + matmul reference. It is written to be CI-safe: the
module is skipped via
pytest.importorskipwhenmctlassExis unavailable, soit collects cleanly on non-MetaX runners.
What it covers
M) and two seeds. Inputs start from realistic fp activations/weights, thengo through per-token (
a_scales) and per-channel (b_scales) int8quantization — mirroring how the op is actually used in a W8A8 model.
a_scalesdoubles the output.K % 16alignment constraint (see note below).Reference math used (kernel takes
bas(N, K)):Comparison uses a bf16-scaled tolerance (
rtol = 2/128,atolscaled to outputmagnitude), since the kernel accumulates in higher precision and rounds to
bfloat16.
A note on shape alignment (found while writing this)
While exercising shapes I found the kernel requires
KandNto both bemultiples of 16, and the two failure modes differ:
K→ the rule selector raisesIndexError(
Can not find valid kid by rule) before touching device memory. This isthe safe case the test asserts.
N→ does not raise; it triggers a device-sideMemory Violation(out-of-bounds write) that disables the MACA runtime forthe rest of the process. The test therefore does not exercise misaligned
N(it would poison the session), and callers must guaranteeN % 16upstream.
Flagging the
Ncase in case it is worth a defensive shape check (or a clearererror) at the Python wrapper level.
Test Result
Validated on:
On a non-MetaX environment the module is skipped (no
mctlassEx), so it doesnot affect existing CI.