diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py
index 3cc0e4adfa0a..9654f9f6775a 100644
--- a/vllm/attention/ops/flashmla.py
+++ b/vllm/attention/ops/flashmla.py
@@ -136,7 +136,7 @@ def flash_mla_with_kvcache(
         descale_k is None
     ), "descale_q and descale_k should be both None or both not None"
 
-    if (descale_q is not None) and (descale_k is not None):
+    if indices is None and q.element_size() == 1:
         out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
             q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
             causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
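
Note on the new guard (a minimal sketch, not part of the diff): the old condition dispatched to the FP8 kernel only when both descale tensors were supplied, whereas the new one keys off the query dtype width instead, firing whenever no (presumably sparse-attention) `indices` are passed and `q` has a 1-byte (FP8) element type. The snippet below illustrates the `element_size()` check; the tensor shapes and the `torch.float8_e4m3fn` dtype are illustrative and assume a PyTorch build with FP8 support.

    import torch

    # element_size() returns the number of bytes per element. Every FP8
    # dtype is 1 byte wide, so `q.element_size() == 1` is a compact way
    # to detect FP8 queries without naming each float8 variant.
    q_fp8 = torch.zeros(2, 8, dtype=torch.float8_e4m3fn)
    q_bf16 = torch.zeros(2, 8, dtype=torch.bfloat16)

    assert q_fp8.element_size() == 1   # FP8 query -> fwd_kvcache_mla_fp8 path
    assert q_bf16.element_size() == 2  # BF16 query -> non-FP8 path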