
[Example] Add MLA decode operator for DeepSeek Ascend NPU migration#989

Open
erhsh wants to merge 5 commits into tile-ai:ascendc_pto from erhsh:ascendc_pto_mla

Conversation

erhsh (Contributor) commented May 8, 2026

  • Migrate MLA decode from GPU main repo to Ascend NPU
  • Support MQA with kv_head_num=1 (128 query heads share 1 KV head)
  • Integrate position encoding (Q_pe + K_pe) in attention computation
  • Implement online softmax with FP32 accumulation
  • Use Developer mode with automatic memory planning
  • Add workspace buffers for L0C->shared copy pattern
  • Default config: batch=1, heads=128, kv_ctx=8192, dim=512

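For context, a minimal PyTorch sketch of the attention math described above, assuming the standard MLA decode formulation in which the shared no-PE latent serves as both K and V (tensor names and shapes are illustrative, not taken from the PR's code):

```python
import torch

batch, heads, kv_ctx, dim, pe_dim = 1, 128, 8192, 512, 64
scale = (dim + pe_dim) ** -0.5

q_nope = torch.randn(batch, heads, dim, dtype=torch.float16)
q_pe   = torch.randn(batch, heads, pe_dim, dtype=torch.float16)
k_nope = torch.randn(batch, kv_ctx, dim, dtype=torch.float16)    # single shared KV head
k_pe   = torch.randn(batch, kv_ctx, pe_dim, dtype=torch.float16)

# Scores combine the no-PE and PE contributions: Q_nope @ K_nope^T + Q_pe @ K_pe^T.
scores = (torch.einsum("bhd,bsd->bhs", q_nope.float(), k_nope.float()) +
          torch.einsum("bhp,bsp->bhs", q_pe.float(), k_pe.float())) * scale

# FP32 softmax over the full KV context; the kernel computes this online per block_N tile.
probs = torch.softmax(scores, dim=-1)

# In MLA decode the no-PE latent also acts as V, so the output feature dim is `dim`.
out = torch.einsum("bhs,bsd->bhd", probs, k_nope.float()).to(torch.float16)
print(out.shape)  # (1, 128, 512)
```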

github-actions Bot commented May 8, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

gemini-code-assist Bot (Contributor) left a comment


Code Review

This pull request introduces a TileLang implementation of the DeepSeek MLA decode kernel optimized for Ascend NPUs, featuring a JIT-compiled flash attention function and a PyTorch-based reference for validation. Feedback focuses on improving kernel reusability for GQA by removing hardcoded indices, addressing performance bottlenecks caused by global memory round-trips, and optimizing normalization logic through broadcasted operations. Additionally, there is a request to align the kernel launch mode with the PR description and fix a PEP 8 style violation regarding the missing trailing newline.

workspace_1: T.Tensor([block_num, VALID_BLOCK_H, block_N], accum_dtype),
workspace_2: T.Tensor([block_num, VALID_BLOCK_H, dim], accum_dtype),
):
with T.Kernel(batch * (heads // VALID_BLOCK_H), is_npu=True) as (cid, vid):

medium

The PR description mentions using Developer mode, but the code currently uses Normal mode (by calling T.Kernel without the threads argument). In Developer mode, you would typically specify threads=1 or threads=2 to explicitly manage core/thread indexing. If the intention was to use Developer mode for better control or memory planning, please update the kernel launch call.
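A minimal sketch of that change, following the reviewer's suggestion (the exact threads value depends on the intended core/thread split; threads=2 mirrors the follow-up commit in this PR rather than a verified TileLang requirement):

```python
# Sketch: pass an explicit threads count so the kernel launches in Developer mode,
# as suggested in this comment.
with T.Kernel(batch * (heads // VALID_BLOCK_H), threads=2, is_npu=True) as (cid, vid):
    ...
```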

Comment thread: examples/deepseek_mla/example_mla_decode.py (Outdated)
Comment on lines +82 to +84
T.copy(acc_s_l0c, workspace_1[cid, :, :])

T.copy(workspace_1[cid, :, :], acc_s)

medium

Using workspace buffers for L0C -> Global Memory -> UB transfers is a significant performance bottleneck, especially inside the main loop. On Ascend NPU, this round-trip to global memory can be very expensive. If the current version of TileLang does not yet support direct L0C -> UB copies, this workaround is understandable, but it should be noted as a primary area for optimization once the backend supports direct transfers.
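For reference, if the backend gains direct L0C -> UB transfers, the staging round-trip shown in the diff could collapse to a single copy. This is a hypothetical sketch, not a currently supported TileLang call:

```python
# Hypothetical: copy the L0C accumulator straight into the UB fragment,
# bypassing the workspace_1 buffer in global memory entirely.
T.copy(acc_s_l0c, acc_s)
```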

Comment on lines +113 to +114
for i, j in T.Parallel(VALID_BLOCK_H, dim):
acc_o[i, j] /= logsum[i]

medium

Performing division in a T.Parallel loop for normalization might be less efficient than using a tile-level broadcasted operation. Consider checking if TileLang supports a more direct way to perform broadcasted element-wise division (e.g., T.tile.div with broadcasting) to better leverage the NPU's vector units.
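For reference, the normalization under discussion is a row-wise broadcast divide; an equivalent in plain PyTorch (illustrative only, the broadcasted TileLang API mentioned above is the reviewer's suggestion and is not shown here):

```python
import torch

VALID_BLOCK_H, dim = 128, 512
acc_o = torch.randn(VALID_BLOCK_H, dim, dtype=torch.float32)    # accumulated attention output
logsum = torch.rand(VALID_BLOCK_H, dtype=torch.float32) + 1.0   # online-softmax denominators

# Equivalent of the per-element T.Parallel loop: one broadcasted divide per row.
acc_o = acc_o / logsum[:, None]
```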

parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
main(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
(No newline at end of file)

medium

The file is missing a newline at the end, which violates PEP 8 standards.

References
  1. PEP 8: All source files should end with a single newline character. (link)

caojian5 added 4 commits May 8, 2026 17:54
- Fix kernel launch mode: Add threads=2 for proper Developer mode
- Fix cur_kv_head hardcoding: Use dynamic calculation for GQA reusability
- Improve normalization: Replace T.Parallel loop with broadcasted T.tile.div
- Add performance notes: Document workspace buffer bottleneck as workaround
- Address all medium priority issues from review

CI uses ruff format check (ci_ascend.yml), not yapf (format.sh).
Format changes:
- Function parameter indentation: 4 spaces (ruff) vs 8 spaces (yapf)
- Slice syntax spacing: keep spaces around : (ruff) vs remove (yapf)
- Power operator spacing: keep spaces around ** (ruff) vs remove (yapf)

This aligns with the actual CI workflow used in PR checks.
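An illustrative snippet of the three differences listed above (made-up code, not taken from the diff):

```python
# ruff / ci_ascend.yml style, as kept in this PR:
def scaled_slice(
    values: list,
    start: int,
    block_n: int,
    dim: int,
    pe_dim: int,
):
    scale = (dim + pe_dim) ** -0.5                                # spaces kept around **
    return [v * scale for v in values[start : start + block_n]]  # spaces kept around : in the slice

# yapf / format.sh would instead indent the parameters by 8 spaces and write:
#     scale = (dim + pe_dim)**-0.5
#     values[start:start + block_n]
```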