[Example] Add MLA decode operator for DeepSeek Ascend NPU migration #989
erhsh wants to merge 5 commits into tile-ai:ascendc_pto
Conversation
erhsh
commented
May 8, 2026
- Migrate MLA decode from GPU main repo to Ascend NPU
- Support MQA with kv_head_num=1 (128 query heads share 1 KV head)
- Integrate position encoding (Q_pe + K_pe) in attention computation
- Implement online softmax with FP32 accumulation
- Use Developer mode with automatic memory planning
- Add workspace buffers for L0C->shared copy pattern
- Default config: batch=1, heads=128, kv_ctx=8192, dim=512
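The online softmax with FP32 accumulation mentioned above can be sketched in plain NumPy as a reference (this is an illustrative sketch, not the TileLang kernel; the function name and variables here are made up for the example):

```python
import numpy as np

def online_softmax(scores, values):
    """One-pass softmax-weighted sum over `values` using a running max
    and running normalizer, accumulating in float32 even when the
    inputs are float16 -- the online-softmax idea the kernel uses."""
    m = -np.inf                                    # running max of scores
    l = 0.0                                        # running sum of exp(score - m)
    acc = np.zeros(values.shape[-1], dtype=np.float32)
    for s, v in zip(scores, values):
        m_new = max(m, float(s))
        scale = np.exp(m - m_new)                  # rescale old accumulator
        p = np.exp(float(s) - m_new)
        l = l * scale + p
        acc = acc * scale + p * v.astype(np.float32)
        m = m_new
    return acc / l                                 # final normalization

# Matches the two-pass reference softmax(scores) @ values:
scores = np.array([0.1, 2.0, -1.0], dtype=np.float32)
values = np.arange(12, dtype=np.float16).reshape(3, 4)
w = np.exp(scores - scores.max())
w /= w.sum()
ref = w @ values.astype(np.float32)
out = online_softmax(scores, values)
assert np.allclose(out, ref, atol=1e-3)
```

Keeping `l` and `acc` in float32 is what makes the single-pass rescaling numerically safe for long KV contexts such as the default kv_ctx=8192.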
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run […] We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Code Review
This pull request introduces a TileLang implementation of the DeepSeek MLA decode kernel optimized for Ascend NPUs, featuring a JIT-compiled flash attention function and a PyTorch-based reference for validation. Feedback focuses on improving kernel reusability for GQA by removing hardcoded indices, addressing performance bottlenecks caused by global memory round-trips, and optimizing normalization logic through broadcasted operations. Additionally, there is a request to align the kernel launch mode with the PR description and fix a PEP 8 style violation regarding the missing trailing newline.
    workspace_1: T.Tensor([block_num, VALID_BLOCK_H, block_N], accum_dtype),
    workspace_2: T.Tensor([block_num, VALID_BLOCK_H, dim], accum_dtype),
):
    with T.Kernel(batch * (heads // VALID_BLOCK_H), is_npu=True) as (cid, vid):
The PR description mentions using Developer mode, but the code currently uses Normal mode (by calling T.Kernel without the threads argument). In Developer mode, you would typically specify threads=1 or threads=2 to explicitly manage core/thread indexing. If the intention was to use Developer mode for better control or memory planning, please update the kernel launch call.
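For illustration, the reviewer's suggestion would change only the kernel entry; this is a sketch of the proposed diff built from the hunk above plus the review comment, not a verified TileLang API reference:

```python
# Sketch: pass an explicit threads count so core/thread indexing is
# managed explicitly (Developer mode), as the review suggests.
with T.Kernel(batch * (heads // VALID_BLOCK_H), threads=2, is_npu=True) as (cid, vid):
    ...  # kernel body unchanged
```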
T.copy(acc_s_l0c, workspace_1[cid, :, :])
T.copy(workspace_1[cid, :, :], acc_s)
Using workspace buffers for L0C -> Global Memory -> UB transfers is a significant performance bottleneck, especially inside the main loop. On Ascend NPU, this round-trip to global memory can be very expensive. If the current version of TileLang does not yet support direct L0C -> UB copies, this workaround is understandable, but it should be noted as a primary area for optimization once the backend supports direct transfers.
for i, j in T.Parallel(VALID_BLOCK_H, dim):
    acc_o[i, j] /= logsum[i]
Performing division in a T.Parallel loop for normalization might be less efficient than using a tile-level broadcasted operation. Consider checking if TileLang supports a more direct way to perform broadcasted element-wise division (e.g., T.tile.div with broadcasting) to better leverage the NPU's vector units.
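The broadcasting point can be illustrated outside TileLang: a broadcasted division normalizes every row in one vectorized operation instead of an explicit per-element loop. NumPy is used here purely as an analogy for the tile-level operation the reviewer suggests (shapes mirror the kernel's accumulator and logsum buffers):

```python
import numpy as np

VALID_BLOCK_H, dim = 8, 16
acc_o = np.random.rand(VALID_BLOCK_H, dim).astype(np.float32)
logsum = np.random.rand(VALID_BLOCK_H).astype(np.float32) + 1.0

# Loop form (mirrors the T.Parallel version): one division per element.
loop_out = acc_o.copy()
for i in range(VALID_BLOCK_H):
    for j in range(dim):
        loop_out[i, j] /= logsum[i]

# Broadcasted form: logsum[:, None] has shape (VALID_BLOCK_H, 1) and
# broadcasts across the dim axis in a single vector operation.
bcast_out = acc_o / logsum[:, None]

assert np.allclose(loop_out, bcast_out)
```

On hardware with wide vector units, the broadcasted form gives the compiler a single dense elementwise operation to map onto those units rather than a scalar loop nest.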
parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
main(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
(No newline at end of file)
The file is missing a newline at the end, which violates PEP 8 standards.
References
- PEP 8: All source files should end with a single newline character.
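A quick way to verify the fix locally is a small check for the trailing newline; the helper name below is illustrative, not part of the project:

```python
import os
import tempfile

def has_final_newline(path):
    """Return True if the file ends with exactly one newline (PEP 8)."""
    with open(path, "rb") as f:
        data = f.read()
    return data.endswith(b"\n") and not data.endswith(b"\n\n")

# A file that stops mid-line, like the flagged script, fails the check.
with tempfile.NamedTemporaryFile("w", delete=False, suffix=".py") as f:
    f.write("main(batch, heads, kv_heads, kv_ctx, dim, pe_dim)")
    path = f.name
assert not has_final_newline(path)
os.unlink(path)
```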
- Fix kernel launch mode: add threads=2 for proper Developer mode
- Fix cur_kv_head hardcoding: use dynamic calculation for GQA reusability
- Improve normalization: replace T.Parallel loop with broadcasted T.tile.div
- Add performance notes: document workspace buffer bottleneck as workaround
- Address all medium-priority issues from review
CI uses the ruff format check (ci_ascend.yml), not yapf (format.sh). Format changes:
- Function parameter indentation: 4 spaces (ruff) vs 8 spaces (yapf)
- Slice syntax spacing: keep spaces around : (ruff) vs remove (yapf)
- Power operator spacing: keep spaces around ** (ruff) vs remove (yapf)
This aligns with the actual CI workflow used in PR checks.