bugfix: import wrapper of mla decode #1013

Merged
yzh119 merged 1 commit into flashinfer-ai:main on Apr 15, 2025

Conversation

dhy2000 (Contributor) commented on Apr 10, 2025

In tests/test_mla_decode_kernel.py, the code below uses BatchDecodeMlaWithPagedKVCacheWrapper:

wrapper = flashinfer.BatchDecodeMlaWithPagedKVCacheWrapper(
    workspace_buffer,
    use_cuda_graph=True,
    use_tensor_cores=True,
    paged_kv_indptr_buffer=kv_indptr,
    paged_kv_indices_buffer=kv_indices,
    paged_kv_last_page_len_buffer=kv_last_page_len,
)

However, an AttributeError is raised:

AttributeError: module 'flashinfer' has no attribute 'BatchDecodeMlaWithPagedKVCacheWrapper'. Did you mean: 'BatchDecodeWithPagedKVCacheWrapper'?

This is because the flashinfer module does not import BatchDecodeMlaWithPagedKVCacheWrapper from the .decode submodule and re-export it at the package level.

This patch fixes that.
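
For context, the fix amounts to re-exporting the wrapper from the package's top-level namespace. A minimal sketch, assuming the class is defined in flashinfer/decode.py as the error message indicates; the surrounding import list is illustrative, not the exact diff:

# flashinfer/__init__.py (sketch)
from .decode import (
    BatchDecodeMlaWithPagedKVCacheWrapper,  # previously missing from the package namespace
    BatchDecodeWithPagedKVCacheWrapper,     # already exported; shown for context
)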

dhy2000 (Contributor, Author) commented on Apr 10, 2025

The failing Jenkins check does not seem to be caused by a compilation error; the last several lines of the log are:

[2025-04-10T07:30:54.197Z] [49/155] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /workspace/build/temp.linux-x86_64-cpython-312/workspace/csrc/generated/batch_ragged_prefill_head_qk_128_head_vo_128_posenc_0_fp16qkred_0_mask_1_dtypeq_bf16_dtypekv_e5m2_dtypeout_bf16_idtype_i32.o.d -I/workspace/include -I/workspace/3rdparty/cutlass/include -I/workspace/3rdparty/cutlass/tools/util/include -I/opt/conda/envs/py312/lib/python3.12/site-packages/torch/include -I/opt/conda/envs/py312/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/opt/conda/envs/py312/lib/python3.12/site-packages/torch/include/TH -I/opt/conda/envs/py312/lib/python3.12/site-packages/torch/include/THC -I/usr/local/cuda/include -I/opt/conda/envs/py312/lib/python3.12/site-packages/torch/include -I/opt/conda/envs/py312/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/opt/conda/envs/py312/lib/python3.12/site-packages/torch/include/TH -I/opt/conda/envs/py312/lib/python3.12/site-packages/torch/include/THC -I/usr/local/cuda/include -I/opt/conda/envs/py312/include/python3.12 -c -c /workspace/csrc/generated/batch_ragged_prefill_head_qk_128_head_vo_128_posenc_0_fp16qkred_0_mask_1_dtypeq_bf16_dtypekv_e5m2_dtypeout_bf16_idtype_i32.cu -o /workspace/build/temp.linux-x86_64-cpython-312/workspace/csrc/generated/batch_ragged_prefill_head_qk_128_head_vo_128_posenc_0_fp16qkred_0_mask_1_dtypeq_bf16_dtypekv_e5m2_dtypeout_bf16_idtype_i32.o --expt-relaxed-constexpr -DFLASHINFER_ENABLE_F16 -DFLASHINFER_ENABLE_BF16 -DFLASHINFER_ENABLE_FP8_E4M3 -DFLASHINFER_ENABLE_FP8_E5M2 --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++17 --threads=1 -Xfatbin -compress-all -use_fast_math -DNDEBUG -DPy_LIMITED_API=0x03080000 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1016"' -DTORCH_EXTENSION_NAME=flashinfer_kernels -D_GLIBCXX_USE_CXX11_ABI=1 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_89,code=sm_89 -gencode=arch=compute_90,code=compute_90 -gencode=arch=compute_90,code=sm_90

[2025-04-10T07:31:10.582Z] Cannot contact i-083df364dc7ab914e: hudson.remoting.ChannelClosedException: Channel "hudson.remoting.Channel@548be362:i-083df364dc7ab914e": Remote call on i-083df364dc7ab914e failed. The channel is closing down or has closed down

[2025-04-10T07:31:10.582Z] Could not connect to i-083df364dc7ab914e to send interrupt signal to process

Can this check be re-run?

yzh119 (Collaborator) commented on Apr 10, 2025

Hi @dhy2000, we encourage using BatchMLAPagedAttentionWrapper (see https://docs.flashinfer.ai/api/mla.html), which supports both decode and append attention, instead of BatchDecodeMlaWithPagedKVCacheWrapper.
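
For anyone landing on this thread, below is a minimal single-token decode sketch of BatchMLAPagedAttentionWrapper, adapted from the linked documentation. The batch size, head counts, head dimensions, page size, and softmax scale are illustrative assumptions, not values taken from this PR or its tests:

import torch
import flashinfer

batch_size = 4
num_heads = 128        # assumed DeepSeek-style local head count
head_dim_ckv = 512     # compressed (latent) KV dimension, assumed
head_dim_kpe = 64      # rotary positional dimension, assumed
page_size = 1
kv_len = 64            # same KV length for every request, for simplicity

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(workspace)

# One query token per request (decode); contiguous page table with page_size == 1.
qo_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda")
kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda")
kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * kv_len
kv_indices = torch.arange(batch_size * kv_len, dtype=torch.int32, device="cuda")

q_nope = torch.randn(batch_size, num_heads, head_dim_ckv, dtype=torch.half, device="cuda")
q_pe = torch.randn(batch_size, num_heads, head_dim_kpe, dtype=torch.half, device="cuda")
ckv_cache = torch.randn(batch_size * kv_len, page_size, head_dim_ckv, dtype=torch.half, device="cuda")
kpe_cache = torch.randn(batch_size * kv_len, page_size, head_dim_kpe, dtype=torch.half, device="cuda")

sm_scale = 1.0 / ((128 + 64) ** 0.5)  # assumed qk_nope_head_dim + qk_rope_head_dim
mla_wrapper.plan(
    qo_indptr, kv_indptr, kv_indices, kv_lens,
    num_heads, head_dim_ckv, head_dim_kpe, page_size,
    False,  # no causal mask needed for single-token decode
    sm_scale, q_nope.dtype, ckv_cache.dtype,
)
out = mla_wrapper.run(q_nope, q_pe, ckv_cache, kpe_cache)  # [batch_size, num_heads, head_dim_ckv]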


dhy2000 (Contributor, Author) commented on Apr 13, 2025

> Hi @dhy2000, we encourage using BatchMLAPagedAttentionWrapper (see https://docs.flashinfer.ai/api/mla.html), which supports both decode and append attention, instead of BatchDecodeMlaWithPagedKVCacheWrapper.

Thanks, so is it necessary to keep test_mla_decode_kernel.py, which tests BatchDecodeMlaWithPagedKVCacheWrapper?

dhy2000 force-pushed the fix_mla_decode_import branch from adef1b7 to 75b4f64 on April 13, 2025 at 07:40
dhy2000 (Contributor, Author) commented on Apr 13, 2025

For the lint error (https://github.com/flashinfer-ai/flashinfer/actions/runs/14374651614/job/40334288048?pr=1013), which says "Please format your code with pre-commit":

Lint fixed.

dhy2000 force-pushed the fix_mla_decode_import branch from 75b4f64 to 96638ea on April 13, 2025 at 07:44
yzh119 (Collaborator) commented on Apr 15, 2025

> Thanks, so is it necessary to keep test_mla_decode_kernel.py, which tests BatchDecodeMlaWithPagedKVCacheWrapper?

We can keep it until we fully deprecate this function.

yzh119 merged commit 32d44c0 into flashinfer-ai:main on Apr 15, 2025
2 checks passed
dhy2000 deleted the fix_mla_decode_import branch on April 16, 2025 at 07:08