2 changes: 1 addition & 1 deletion vllm/attention/ops/flashmla.py
@@ -136,7 +136,7 @@ def flash_mla_with_kvcache(
        descale_k is None
    ), "descale_q and descale_k should be both None or both not None"

-   if (descale_q is not None) and (descale_k is not None):
+   if indices is None and q.element_size() == 1:
        out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
            q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
            causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
Comment on lines +139 to 142
critical

This change correctly identifies FP8 tensors using q.element_size() == 1. However, it introduces a potential issue where the FP8 kernel could be called without the necessary scaling factors. If q is an FP8 tensor but descale_q is None, fwd_kvcache_mla_fp8 would be called with None for descale_q and descale_k. This would likely cause a crash or incorrect computation within the underlying C++ kernel, as FP8 operations require these scaling factors.

To make this function more robust, I suggest adding an assertion to ensure that descale_q and descale_k are provided when q is an FP8 tensor.

    if indices is None and q.element_size() == 1:
        assert descale_q is not None, (
            "descale_q and descale_k must be provided for fp8 attention")
        out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
            q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
            causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
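
For reference, a minimal standalone sketch (assuming a PyTorch build with FP8 dtypes such as torch.float8_e4m3fn; the helper check_fp8_descale is hypothetical and not part of vLLM) illustrating why q.element_size() == 1 selects the FP8 path and how the suggested assertion catches missing descale factors before the C++ kernel is reached:

    import torch

    def check_fp8_descale(q: torch.Tensor, descale_q, descale_k) -> str:
        # element_size() is bytes per element: 1 for 8-bit dtypes (FP8),
        # 2 for fp16/bf16, 4 for fp32.
        if q.element_size() == 1:
            assert descale_q is not None and descale_k is not None, (
                "descale_q and descale_k must be provided for fp8 attention")
            return "fp8 path"
        return "fp16/bf16 path"

    # bf16 query: takes the non-FP8 path, no descale factors required.
    q_bf16 = torch.zeros(1, 8, 64, dtype=torch.bfloat16)
    print(check_fp8_descale(q_bf16, None, None))       # fp16/bf16 path

    # FP8 query with per-tensor descale factors: takes the FP8 path.
    q_fp8 = torch.zeros(1, 8, 64, dtype=torch.float8_e4m3fn)
    descale = torch.ones(1, dtype=torch.float32)
    print(check_fp8_descale(q_fp8, descale, descale))  # fp8 path

    # FP8 query without descale factors: the assertion fires here instead of
    # passing None scaling factors down into fwd_kvcache_mla_fp8.
    # check_fp8_descale(q_fp8, None, None)  # AssertionError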
