[bugfix][deepseek] fix flashmla kernel selection #25956
Conversation
Code Review
This pull request addresses a bug in the FlashMLA kernel selection by using `q.element_size() == 1` to more reliably detect FP8 tensors, which is a good improvement for correctness. I have added one critical comment regarding a potential edge case where the FP8 kernel might be called without the necessary scaling factors, which could lead to a crash. Adding an assertion will make the implementation more robust.
if indices is None and q.element_size() == 1:
    out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
        q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
        causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
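For context (not part of the review itself), here is a minimal sketch of why the `element_size()` check works, assuming a PyTorch build that ships the `float8_e4m3fn` dtype: FP8 tensors store one byte per element, while BF16/FP16 tensors store two, so the comparison identifies FP8 inputs without enumerating dtypes.

```python
import torch

# Bytes per element reported by PyTorch for the dtypes involved:
print(torch.empty(0, dtype=torch.float8_e4m3fn).element_size())  # 1 -> FP8 path
print(torch.empty(0, dtype=torch.bfloat16).element_size())       # 2 -> non-FP8 path
print(torch.empty(0, dtype=torch.float16).element_size())        # 2 -> non-FP8 path
```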
This change correctly identifies FP8 tensors using `q.element_size() == 1`. However, it introduces a potential issue where the FP8 kernel could be called without the necessary scaling factors. If `q` is an FP8 tensor but `descale_q` is `None`, `fwd_kvcache_mla_fp8` would be called with `None` for `descale_q` and `descale_k`. This would likely cause a crash or incorrect computation within the underlying C++ kernel, as FP8 operations require these scaling factors.

To make this function more robust, I suggest adding an assertion to ensure that `descale_q` and `descale_k` are provided when `q` is an FP8 tensor.
if indices is None and q.element_size() == 1:
assert descale_q is not None, (
"descale_q and descale_k must be provided for fp8 attention")
out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
Makes sense! thank you for the fix
LGTM, thanks for the work!
okay, verified locally that it works.
Purpose
It seems we always pass in `descale_q` as tensors in the flashmla backend, so it selects the wrong kernel implementation.

Fixes #25896 (comment) and potentially #25896 (comment)
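To illustrate the failure mode, a hypothetical sketch (not the actual vLLM code; it assumes the previous selection keyed off whether `descale_q` was provided, and the kernel labels are placeholders):

```python
import torch

def select_kernel_old(q: torch.Tensor, descale_q) -> str:
    # Assumed previous behaviour: any call that supplies descale tensors is
    # routed to the FP8 kernel, even when q itself is BF16.
    return "fp8" if descale_q is not None else "dense"

def select_kernel_new(q: torch.Tensor, descale_q) -> str:
    # Fixed behaviour: inspect the query dtype itself; FP8 dtypes occupy one
    # byte per element, so element_size() == 1 identifies them reliably.
    return "fp8" if q.element_size() == 1 else "dense"

q_bf16 = torch.empty(1, dtype=torch.bfloat16)
descale = torch.ones(1, dtype=torch.float32)  # backend always passes a tensor
print(select_kernel_old(q_bf16, descale))  # "fp8"   -> wrong kernel selected
print(select_kernel_new(q_bf16, descale))  # "dense" -> correct kernel
```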
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.