
fix: [https://nvbugspro.nvidia.com/bug/5243482] If FlashMLA is used, the existence of FMHA based MLA kernels should not be checked. #3862

Merged: 5 commits, Apr 30, 2025
20 changes: 12 additions & 8 deletions cpp/tensorrt_llm/common/attentionOp.cpp
@@ -2416,7 +2416,7 @@ int AttentionOp::initialize() noexcept
fmhaParams.numTokensPerBlock = mTokensPerBlock;
fmhaParams.headSize = mHeadSize;
fmhaParams.headSizeV = mHeadSize;
- if (mIsMLAEnabled)
+ if (mIsMLAEnabled && !mIsGenerationMLA)
{
// Context attention of MLA is different
fmhaParams.numKvHeads = mNumHeads;
@@ -2476,10 +2476,9 @@ int AttentionOp::initialize() noexcept
// Instantiate the mTllmGenFMHARunner used for MLA
mTllmGenFMHARunner.reset(new TllmGenFmhaRunner(qDataType, kvDataType, outputDataType));
}
- else
+ else if (mIsGenerationMLA && !mUseGenFlashMLA)
{
- // Construct the fmha runner.
- // FP8 Generation MLA also uses context FMHA.
+ // Construct the fmha runner for generation.
if (mFP8GenerationMLA)
{
data_type = DATA_TYPE_E4M3;
@@ -2524,13 +2523,17 @@ int AttentionOp::initialize() noexcept
"Deepseek should be supported by fmha in generation part.");
}
}

- TLLM_CHECK_WITH_INFO(
- mFmhaDispatcher->isSupported(), "Deepseek should be supported by fmha in context part.");
+ if (!mIsGenerationMLA)
+ {
+ TLLM_CHECK_WITH_INFO(
+ mFmhaDispatcher->isSupported(), "Deepseek should be supported by fmha in context part.");
+ }
}

// Fall back to unfused MHA kernels if not supported.
- mEnableContextFMHA = mFmhaDispatcher->isSupported();
+ // Generation MLA reuses the context FMHA code path so set mEnableContextFMHA to true.
+ // However, do not check mFmhaDispatcher which is not used for generation MLA.
+ mEnableContextFMHA = mIsGenerationMLA || mFmhaDispatcher->isSupported();

// Only FMHA supports custom mask currently.
TLLM_CHECK_WITH_INFO(
@@ -2697,6 +2700,7 @@ std::string AttentionOp::toString() const
ss << "mFMHAForceFP32Acc: " << std::boolalpha << mFMHAForceFP32Acc << std::endl;
ss << "mSM: " << mSM << std::endl;
ss << "mUseTllmGen: " << mUseTllmGen << std::endl;
ss << "mIsGenerationMLA: " << std::boolalpha << mIsGenerationMLA << std::endl;
ss << "mUseGenFlashMLA: " << mUseGenFlashMLA << std::endl;
ss << "mMultiProcessorCount: " << mMultiProcessorCount << std::endl;
ss << "mMaxSharedMemoryPerBlockOptin: " << mMaxSharedMemoryPerBlockOptin << std::endl;
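
To make the intent of the attentionOp.cpp changes easier to follow, here is a small, self-contained C++ illustration of the runner and check selection after this patch. The flag names mirror the AttentionOp members in the diff, but the enum, struct, and selectRunner function are hypothetical, written only for this summary, and are not TensorRT-LLM code.

#include <iostream>
#include <stdexcept>

// Which kernel backend the op would construct; illustrative only.
enum class MlaRunner { ContextFmha, TllmGenFmha, GenerationFmha, FlashMla };

struct MlaFlags
{
    bool isMLAEnabled = true;             // mIsMLAEnabled
    bool isGenerationMLA = false;         // mIsGenerationMLA (new in this PR)
    bool useTllmGen = false;              // mUseTllmGen
    bool useGenFlashMLA = false;          // mUseGenFlashMLA
    bool fmhaDispatcherSupported = false; // stand-in for mFmhaDispatcher->isSupported()
};

// Mirrors the branch ordering of the patched initialize(); throws where the real code
// would trip a TLLM_CHECK_WITH_INFO.
MlaRunner selectRunner(MlaFlags const& f)
{
    MlaRunner runner;
    if (f.useTllmGen)
        runner = MlaRunner::TllmGenFmha;    // mTllmGenFMHARunner
    else if (f.isGenerationMLA && !f.useGenFlashMLA)
        runner = MlaRunner::GenerationFmha; // FMHA-based generation MLA runner
    else if (f.isGenerationMLA)
        runner = MlaRunner::FlashMla;       // FlashMLA needs no FMHA-based MLA kernel
    else
        runner = MlaRunner::ContextFmha;    // context MLA goes through mFmhaDispatcher

    // The existence check is now limited to the context path; this is the actual fix.
    if (!f.isGenerationMLA && !f.fmhaDispatcherSupported)
        throw std::runtime_error("Deepseek should be supported by fmha in context part.");
    return runner;
}

int main()
{
    MlaFlags flashMla;
    flashMla.isGenerationMLA = true;
    flashMla.useGenFlashMLA = true;
    // Before this PR, an unconditional isSupported() check could fire even on this path.
    std::cout << (selectRunner(flashMla) == MlaRunner::FlashMla ? "FlashMLA selected" : "?") << std::endl;
    return 0;
}

The key point is the final check: only a context-MLA op still requires the FMHA-based MLA kernels, so a FlashMLA generation op no longer fails when those kernels are absent.
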
9 changes: 5 additions & 4 deletions cpp/tensorrt_llm/common/attentionOp.h
@@ -379,6 +379,7 @@ class AttentionOp
bool mSpecDecodingIsGenerationLengthVariable = false;
int32_t mSpecDecodingMaxGenerationLength = 1;
bool mIsMLAEnabled = false;
+ bool mIsGenerationMLA = false;
bool mUseGenFlashMLA = false;
tensorrt_llm::kernels::MlaMetaParams mMLAParams;
int mCpSize = 1;
@@ -422,10 +423,10 @@
mUnfuseQkvGemm, (int32_t) mType, mMaxContextLength, mQKVBiasEnabled, mCrossAttention, mMaxDistance,
mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mDenseContextFMHA, mHasFullAttentionMask,
mIsSpecDecodingEnabled, mUseSpecDecoding, mSpecDecodingIsGenerationLengthVariable,
- mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank,
- mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize,
- mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA,
- mUseKVCache, mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores);
+ mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mMLAParams.data(),
+ mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank,
+ mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode,
+ mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores);
};

private:
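
The header change also threads the new flag into the tuple returned by data(). Assuming that tuple is what the calling layer uses to compare or deduplicate op configurations (an assumption on my part, not stated in the diff), the flag has to be part of it so that a context-MLA op and a generation-MLA op with otherwise identical settings are not treated as the same op. A hypothetical, minimal version of that pattern:

#include <cassert>
#include <tuple>

// Illustrative stand-in for AttentionOp: not the real class, just the identity-tuple idea.
struct OpConfig
{
    bool isMLAEnabled = true;
    bool isGenerationMLA = false; // newly included in the identity tuple by this PR
    bool useGenFlashMLA = false;

    auto data() const { return std::make_tuple(isMLAEnabled, isGenerationMLA, useGenFlashMLA); }
};

int main()
{
    OpConfig ctx; // context MLA op
    OpConfig gen; // generation MLA op for the same model
    gen.isGenerationMLA = true;
    // Without the flag in data(), these two would compare equal and could be conflated.
    assert(ctx.data() != gen.data());
    return 0;
}
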
10 changes: 7 additions & 3 deletions cpp/tensorrt_llm/thop/attentionOp.cpp
@@ -452,14 +452,18 @@ torch::Tensor attention(torch::Tensor q, torch::optional<torch::Tensor> k, torch
{
TLLM_CHECK(host_kv_cache_pool_mapping.has_value());
int32_t const layer_num = host_kv_cache_pool_mapping.value().size(0);

op->mIsMLAEnabled = true;
- op->mFP8GenerationMLA = op->mKVCacheQuantMode.hasFp8KvCache();
- // only enable flash mla in the generation phase on sm90 and tokens_per_block == 64
- op->mUseGenFlashMLA = tensorrt_llm::common::getSMVersion() == 90 && tokens_per_block == 64;
op->mMLAParams = {static_cast<int>(q_lora_rank.value()), static_cast<int>(kv_lora_rank.value()),
static_cast<int>(qk_nope_head_dim.value()), static_cast<int>(qk_rope_head_dim.value()),
static_cast<int>(v_head_dim.value()), static_cast<int>(predicted_tokens_per_seq),
static_cast<int>(layer_num)};

+ op->mIsGenerationMLA = head_size == op->mMLAParams.kv_lora_rank + op->mMLAParams.qk_rope_head_dim;
+ op->mFP8GenerationMLA = op->mKVCacheQuantMode.hasFp8KvCache();
+ // only enable flash mla on sm90 and head_size == 576 and tokens_per_block == 64
+ op->mUseGenFlashMLA = tensorrt_llm::common::getSMVersion() == 90 && tokens_per_block == 64;

// The following two parameters are used to compute kvcache related parameters such as kvcache block_size. So
// they need to be set to 1 and 512 + 64 for both context and generation. For MLA attention kernel configs,
// mNumKVHeads/mHeadSize are overwritten in common/attentionOp.cpp.
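
Finally, a short sketch of how the torch binding now derives the two flags. The head-size arithmetic comes from the diff; the concrete numbers (kv_lora_rank 512, qk_rope_head_dim 64, hence a 576 generation head size) are DeepSeek-style MLA dimensions used for illustration, and getSMVersion() below is a local stub, not the real tensorrt_llm::common helper.

#include <iostream>

struct MlaMetaParams
{
    int kv_lora_rank = 512;    // compressed KV dimension
    int qk_rope_head_dim = 64; // RoPE part of the query/key head
};

int getSMVersion() { return 90; } // stand-in for tensorrt_llm::common::getSMVersion()

int main()
{
    MlaMetaParams mla;
    int head_size = 576;        // generation-phase MLA head size = 512 + 64
    int tokens_per_block = 64;

    // head_size == kv_lora_rank + qk_rope_head_dim identifies the generation MLA op;
    // the context MLA op is created with a different head_size and so stays false.
    bool isGenerationMLA = head_size == mla.kv_lora_rank + mla.qk_rope_head_dim;

    // FlashMLA is only taken on SM90 (Hopper) with 64-token KV-cache blocks.
    bool useGenFlashMLA = getSMVersion() == 90 && tokens_per_block == 64;

    std::cout << "isGenerationMLA=" << isGenerationMLA
              << " useGenFlashMLA=" << useGenFlashMLA << std::endl;
    return 0;
}

Because the context MLA op is created with a different head_size, mIsGenerationMLA stays false for it, which is what lets initialize() skip the FMHA kernel-existence check only for the generation op.
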