I encountered a NaN issue after replacing the FlashMLA in vLLM with this version.
The NaN only occurs with MTP num_speculative_token=1, i.e. q_len=2.
python pengcuo_test_flash_mla.py
[512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]
b=128, s_q=2, h_q=8, h_kv=1, d=576, dv=512, causal=True, 0.056 , 41 , 2291519488, 1435, 80278784, mean_sk=512,
b=128, s_q=2, h_q=8, h_kv=1, d=576, dv=512, causal=True, 0.082 , 58 , 4725645312, 1971, 160819712, mean_sk=1024,
b=128, s_q=2, h_q=8, h_kv=1, d=576, dv=512, causal=True, 0.131 , 73 , 9596194816, 2463, 321977600, mean_sk=2048,
b=128, s_q=2, h_q=8, h_kv=1, d=576, dv=512, causal=True, 0.213 , 85 , 18163089408, 2841, 605441024, mean_sk=4096,
b=128, s_q=2, h_q=8, h_kv=1, d=576, dv=512, causal=True, 0.404 , 88 , 35761985536, 2938, 1187757440, mean_sk=8192,
b=128, s_q=2, h_q=8, h_kv=1, d=576, dv=512, causal=True, 0.866 , 88 , 76151922688, 2915, 2524189184, mean_sk=16384,
The run then stops before printing a result for mean_sk=32768:
Traceback (most recent call last):
  File "/mnt/moonfs/xusuting-v6/workspace/xst-FlashMLA-ETAP/tests/pengcuo_test_flash_mla.py", line 243, in <module>
    main(torch_dtype, block_size = args.bs)
  File "/mnt/moonfs/xusuting-v6/workspace/xst-FlashMLA-ETAP/tests/pengcuo_test_flash_mla.py", line 211, in main
    test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, block_size)
  File "/opt/conda/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/moonfs/xusuting-v6/workspace/xst-FlashMLA-ETAP/tests/pengcuo_test_flash_mla.py", line 162, in test_flash_mla
    cal_diff(out_flash, out_torch, "out")
  File "/mnt/moonfs/xusuting-v6/workspace/xst-FlashMLA-ETAP/tests/pengcuo_test_flash_mla.py", line 65, in cal_diff
    assert cos_diff < 2 * 1e-5
AssertionError
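If `out_flash` contains even one NaN, this assertion fails: NaN propagates through the sums, so the cosine difference becomes NaN, and any `<` comparison with NaN evaluates to False. The sketch below illustrates that with plain Python lists. It assumes the common `1 - 2*sum(x*y) / sum(x*x + y*y)` form for the check; it is not the actual `cal_diff()` from `pengcuo_test_flash_mla.py`, and the `nan_query_positions` helper (including its row layout) is hypothetical.

```python
import math


def cos_diff(x, y):
    """Return 1 - 2*sum(x*y) / sum(x*x + y*y) for two flat float sequences.

    Illustrative only: a common shape for this kind of check, not the
    actual cal_diff() from pengcuo_test_flash_mla.py.
    """
    num = 2.0 * sum(a * b for a, b in zip(x, y))
    den = max(sum(a * a + b * b for a, b in zip(x, y)), 1e-12)
    # A single NaN in x or y makes num NaN, so the result is NaN.
    return 1.0 - num / den


def flatten(rows):
    """Concatenate a list of rows into one flat list."""
    return [v for row in rows for v in row]


def nan_query_positions(rows, s_q):
    """Return the sorted query positions (0..s_q-1) whose rows contain NaN.

    Hypothetical helper: assumes rows are ordered as [batch][query position],
    so row r belongs to query position r % s_q.
    """
    bad = {r % s_q for r, row in enumerate(rows)
           if any(math.isnan(v) for v in row)}
    return sorted(bad)


if __name__ == "__main__":
    ref = [[1.0, 2.0], [3.0, 4.0]]           # batch=1, s_q=2, one row per query
    out = [[1.0, 2.0], [float("nan"), 4.0]]  # NaN only at query position 1
    d = cos_diff(flatten(out), flatten(ref))
    print(d, d < 2 * 1e-5)                   # nan False -> the assert would fail
    print(nan_query_positions(out, s_q=2))   # [1]
```

Applied to the real output tensor (after reshaping it to this row order), a check like `nan_query_positions` could show whether the NaN appears only at the second query position of the q_len=2 case or at both.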