Skip to content

test: add real-GPU numerical test for mctlass W8A8 scaled_mm#268

Open
sxvvv wants to merge 1 commit into
MetaX-MACA:masterfrom
sxvvv:test/mctlass-w8a8-scaled-mm
Open

test: add real-GPU numerical test for mctlass W8A8 scaled_mm#268
sxvvv wants to merge 1 commit into
MetaX-MACA:masterfrom
sxvvv:test/mctlass-w8a8-scaled-mm

Conversation

@sxvvv
Copy link
Copy Markdown

@sxvvv sxvvv commented Jun 6, 2026

Purpose

Add the first test coverage for the MetaX mctlassEx W8A8 scaled_mm op
(mctlassEx_w8a8_scaled_mm_azp in
vllm_metax/model_executor/layers/quantization/_python_api_ops.py), which
currently has no tests. This op is on the active mctlass quantization path, so
it is worth guarding against silent numerical regressions and layout/scale
mistakes.

The test runs the real kernel on MetaX hardware and compares its output
against a PyTorch dequant + matmul reference. It is written to be CI-safe: the
module is skipped via pytest.importorskip when mctlassEx is unavailable, so
it collects cleanly on non-MetaX runners.

What it covers

  • Numerical correctness across several token/hidden shapes (incl. irregular
    M) and two seeds. Inputs start from realistic fp activations/weights, then
    go through per-token (a_scales) and per-channel (b_scales) int8
    quantization — mirroring how the op is actually used in a W8A8 model.
  • Scale linearity — doubling a_scales doubles the output.
  • K % 16 alignment constraint (see note below).

Reference math used (kernel takes b as (N, K)):

out[m, n] = a_scales[m] * b_scales[n] * sum_k a_q[m, k] * b_q[n, k]

Comparison uses a bf16-scaled tolerance (rtol = 2/128, atol scaled to output
magnitude), since the kernel accumulates in higher precision and rounds to
bfloat16.

A note on shape alignment (found while writing this)

While exercising shapes I found the kernel requires K and N to both be
multiples of 16
, and the two failure modes differ:

  • Misaligned K → the rule selector raises IndexError
    (Can not find valid kid by rule) before touching device memory. This is
    the safe case the test asserts.
  • Misaligned N → does not raise; it triggers a device-side
    Memory Violation (out-of-bounds write) that disables the MACA runtime for
    the rest of the process. The test therefore does not exercise misaligned
    N (it would poison the session), and callers must guarantee N % 16
    upstream.

Flagging the N case in case it is worth a defensive shape check (or a clearer
error) at the Python wrapper level.

Test Result

Validated on:

  • GPU: MetaX C500
  • MACA: 3.5.3.20
  • torch: 2.8.0+metax3.5.3.9
python -m pytest tests/kernels/quantization/test_mctlass_w8a8_scaled_mm.py -v
# 15 passed

On a non-MetaX environment the module is skipped (no mctlassEx), so it does
not affect existing CI.

Adds the first test coverage for the mctlassEx W8A8 scaled_mm op in
vllm_metax/model_executor/layers/quantization/_python_api_ops.py, which
previously had none. The test runs the real kernel on MetaX hardware and
checks output against a PyTorch dequant+matmul reference across several
token/hidden shapes and two seeds, plus a scale-linearity check, and
documents the kernel's K%16 alignment requirement.

Skips cleanly via importorskip when mctlassEx is absent, so it is safe to
collect on non-MetaX CI.

Validated on MetaX C500 (MACA 3.5.3.20, torch 2.8.0+metax3.5.3.9): 15 passed.
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new test suite, test_mctlass_w8a8_scaled_mm.py, to verify the numerical correctness of the MetaX mctlass W8A8 scaled_mm operator. The tests cover symmetric per-token and per-channel int8 quantization, ensure that scales are applied correctly, and document the alignment constraints for the operator. There are no review comments, and I have no additional feedback to provide.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

@sxvvv
Copy link
Copy Markdown
Author

sxvvv commented Jun 6, 2026

给 mctlass 的 W8A8 scaled_mm 补了个真机测试,之前这个算子没有测试覆盖。在 C500 上和 PyTorch 参考实现对了几组 shape,15 个用例都过了;没有 mctlassEx 的环境会自动 skip,不影响现有 CI。

写的时候发现一个点想提一下:K 和 N 都要求是 16 的倍数。K 不对齐会安全报错(rule selector 抛 IndexError),但 N 不对齐不会报错,会直接 device 端 Memory Violation 越界、把 MACA runtime 弄挂。所以测试里没有去触发 N 不对齐的情况(会污染整个进程),调用方需要保证 N%16。要不要在 wrapper 层加个防御性的 shape 检查,你们看着定。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant