PyTorch2FlyDSL: Improve translation performance for Attentions #315

Open: amd-wangfan wants to merge 3 commits into main from feature/pytorch2flydsl-attns-opt

amd-wangfan (Collaborator) commented Jun 29, 2026

Summary

Extends the PyTorch→FlyDSL translation knowledge base with more detailed guidance for translating attention kernels. Also improves kernel category detection and directs the agent to keep optimizing for performance after it reaches a first correct translation.

What Changed

Knowledge base (skills/pytorch2flydsl-translation/)

  • Added guidance for Split-K GEMM (dynamic activations / small-M decode) vs
    fixed-weight preshuffle GEMM.
  • Added a new Decode Attention section (MLA & PagedAttention) with the key
    optimizations and anti-patterns.
  • Updated the op-mapping tables and decision trees accordingly.
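For readers unfamiliar with Split-K, the idea behind the first bullet can be sketched in plain Python. When M is small (as in decode), there are few output tiles to parallelize over, so the K dimension is sliced instead and the partial products are summed at the end. This is only an illustration of the technique, not the FlyDSL kernel or the KB text; `matmul` and `split_k_matmul` are hypothetical names.

```python
def matmul(a, b):
    """Reference GEMM on nested lists: (M x K) @ (K x N)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def split_k_matmul(a, b, splits):
    """Split-K GEMM sketch (hypothetical helper, not FlyDSL code).

    Each slice of the K dimension yields a partial M x N product;
    the partials are accumulated into the output (the reduction step).
    """
    m, k_dim, n = len(a), len(b), len(b[0])
    chunk = -(-k_dim // splits)  # ceiling division
    out = [[0.0] * n for _ in range(m)]
    for start in range(0, k_dim, chunk):
        stop = min(start + chunk, k_dim)
        # On a GPU, each K slice would typically map to its own workgroup.
        for i in range(m):
            for j in range(n):
                out[i][j] += sum(a[i][k] * b[k][j] for k in range(start, stop))
    return out


# Small-M case: a single activation row against a 4 x 2 weight matrix.
a = [[1, 2, 3, 4]]
b = [[1, 2], [3, 4], [5, 6], [7, 8]]
print(split_k_matmul(a, b, splits=2))
```

The result matches the unsplit reference for any number of splits; the trade-off on real hardware is the extra reduction pass versus the added parallelism.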

Agent config (mini_kernel_pytorch_to_flydsl.yaml)

  • Added MLA-decode translation guidance.
  • Changed the flow to keep iterating and submit the best-performing correct
    version, not just the first one that passes.
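The "submit the best-performing correct version" flow can be pictured as a simple selection loop. The sketch below is an assumption about the shape of that logic, not the agent's actual implementation; `pick_best`, `is_correct`, and `measure_ms` are hypothetical names.

```python
def pick_best(candidates, is_correct, measure_ms):
    """Return the fastest candidate that passes the correctness check.

    Hypothetical sketch: incorrect candidates are skipped entirely, so a
    fast but wrong translation can never be submitted.
    """
    best, best_ms = None, float("inf")
    for cand in candidates:
        if not is_correct(cand):
            continue
        ms = measure_ms(cand)
        if ms < best_ms:
            best, best_ms = cand, ms
    return best, best_ms


# Illustrative data only: three attempts with made-up timings.
attempts = [
    {"name": "v1", "ok": True, "ms": 0.60},
    {"name": "v2", "ok": False, "ms": 0.10},  # fastest, but fails the check
    {"name": "v3", "ok": True, "ms": 0.25},
]
best, ms = pick_best(attempts, lambda c: c["ok"], lambda c: c["ms"])
print(best["name"], ms)
```

Stopping at `v1` would have been the old "first one that passes" behavior; continuing lets `v3` win.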

Category detection (translation_registry.py)

  • Detect MLA / PagedAttention kernels and load the right KB docs.
  • Detect hand-written softmax(Q@K^T)@V attention so it loads the attention KB.
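As a rough picture of what source-based category detection might look like, here is a minimal sketch. It is not the code in translation_registry.py; the function name `detect_categories`, the category labels, and the heuristics are all assumptions made for illustration.

```python
import re

# Matches infix "a @ b" plus explicit matmul/bmm calls. Decorators such as
# "@torch.no_grad" are not matched because no whitespace follows the "@".
_MATMUL = re.compile(r"\s@\s|\bmatmul\(|\bbmm\(")


def detect_categories(source: str) -> list[str]:
    """Guess KB categories from kernel source text (hypothetical heuristics)."""
    lowered = source.lower()
    cats = []
    if "latent" in lowered or re.search(r"\bmla\b", lowered):
        cats.append("mla_decode")
    if "paged" in lowered or "block_table" in lowered:
        cats.append("paged_attention")
    matmuls = len(_MATMUL.findall(source))
    # Hand-written attention: a softmax plus at least two matmuls,
    # i.e. the softmax(Q @ K^T) @ V shape.
    if "softmax" in lowered and matmuls >= 2:
        cats.append("attention")
    if not cats and matmuls >= 1:
        cats.append("gemm")
    return cats


manual = (
    "scores = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)\n"
    "out = scores @ v\n"
)
print(detect_categories(manual))
print(detect_categories("def forward(x, w):\n    return x @ w\n"))
```

A text heuristic like this is cheap but can misfire on unusual code; an AST-based check would be a sturdier alternative if false positives become a problem.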

Test

  • Verified category detection returns the expected categories for MLA,
    PagedAttention, manual-softmax attention, and plain GEMM kernels.
  • Ran the existing test suite to confirm no regressions.
  • Ran end-to-end translations with the updated config/KB to confirm correctness
    and measure speedup.
  • Example translation command:
    python -m minisweagent.run.preprocess.translate \
    --kernel-url /home/fanwang2/codes/Attention-Kernels/pytorch/Transformer/MultiHeadLatentAttention.py \
    --target-language flydsl \
    --repo /home/fanwang2/codes/Attention-Kernels \
    --flydsl-repo /home/fanwang2/codes/FlyDSL \
    -o /home/fanwang2/codes/Attention-Kernels/geak_trans_py2flydsl/mla_translate \
    --gpu 0
    

Result

Across 9 attention kernels, the optimized translations reach an average (arithmetic mean) speedup of 8.62x over PyTorch, up from 3.25x before optimization, with a peak of 31.17x on PagedAttention. The new decode-attention work drives the largest gains: PagedAttention (1.13x → 31.17x) and MLA (3.80x → 19.94x).

| Kernel | Speedup (ori) | Speedup (opt) | PyTorch (ms) | FlyDSL (ms) |
|---|---|---|---|---|
| BlockSparseAttention (BSA) | 0.51x | 0.97x | 0.275 | 0.283 |
| FlashAttention (FA) | 19.33x | 20.33x | 8.454 | 0.416 |
| GroupedQueryAttention (GQA) | 1.27x | 1.22x | 0.431 | 0.352 |
| MultiHeadAttention (MHA) | 1.11x | 1.11x | 0.488 | 0.441 |
| MultiHeadLatentAttention (MLA) | 3.80x | 19.94x | 2.542 | 0.127 |
| MultiQueryAttention (MQA) | 1.14x | 1.54x | 0.448 | 0.291 |
| PagedAttentionKVCache (PA) | 1.13x | 31.17x | 4.538 | 0.146 |
| ScaledDotProductAttention (SDPA) | 0.47x | 0.68x | 0.051 | 0.075 |
| LinearAttention (LA) | 0.45x | 0.58x | 0.381 | 0.660 |
  • Highest speedup: 31.17x (PagedAttention)
  • Average speedup: 8.62x optimized vs 3.25x before optimization
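The summary figures can be reproduced from the per-kernel speedups with a few lines of Python. The sketch assumes the averages are plain arithmetic means over the nine kernels, which is consistent with the reported numbers; the variable names are illustrative.

```python
# (original speedup, optimized speedup) per kernel, copied from the results.
speedups = {
    "BSA": (0.51, 0.97),
    "FA": (19.33, 20.33),
    "GQA": (1.27, 1.22),
    "MHA": (1.11, 1.11),
    "MLA": (3.80, 19.94),
    "MQA": (1.14, 1.54),
    "PA": (1.13, 31.17),
    "SDPA": (0.47, 0.68),
    "LA": (0.45, 0.58),
}

ori_mean = sum(ori for ori, _ in speedups.values()) / len(speedups)
opt_mean = sum(opt for _, opt in speedups.values()) / len(speedups)
peak = max(speedups, key=lambda name: speedups[name][1])

print(f"average: {ori_mean:.2f}x -> {opt_mean:.2f}x, peak: {peak}")
```

Because an arithmetic mean is dominated by the largest entries (FA, MLA, PA here), a geometric mean or the per-kernel values give a better sense of typical behavior.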
