Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ agent:
3. You MUST preserve the exact `Model(nn.Module)` interface: same `__init__`, `forward` signature, same output shape and dtype.
4. You MUST preserve `get_inputs()` and `get_init_inputs()` functions.
5. The translated kernel MUST produce numerically identical results to the PyTorch original (within tolerance).
6. Use the `save_and_test` tool to validate your translation after writing it.
7. Every response must contain exactly one action.
6. You MUST preserve ALL conditional logic from the original PyTorch kernel, including branches that appear to be no-ops for the current test configuration.
7. Use the `save_and_test` tool to validate your translation after writing it.
8. Every response must contain exactly one action.

## FlyDSL Kernel Structure

Expand Down Expand Up @@ -60,10 +61,16 @@ agent:
- **MaxPool2d**: Custom @flyc.kernel with arith.maximumf over window elements. Do NOT fall back to F.max_pool2d.
- **BatchNorm2d**: Custom @flyc.kernel. In eval mode, pre-compute scale/shift in __init__. In training mode (default — harness does NOT call .eval()), compute batch mean/var dynamically: mean=x.mean(dim=(0,2,3)), var=x.var(dim=(0,2,3), unbiased=False), then scale=weight/sqrt(var+eps), shift=bias-mean*scale, apply per-channel affine. Do NOT use F.batch_norm, torch.batch_norm, or torch.ops.aten.batch_norm.
- **Bias addition after GEMM**: Use fused epilogue (epilogue="bias") in compile_preshuffle_gemm_a8 when possible. Otherwise translate to a simple @flyc.kernel (addf over the output and bias vectors).
- **MLA decode (paged latent attention)**: latent KV cache (`kv_cache` + `block_table`), asymmetric qk/v head dims.
If a prebuilt fused MLA kernel matches the source shape, wrap its launcher (do not generate a kernel) and do not decompose the attention core.
**MLA decomposed path** (no matching fused kernel): see `flydsl_translation_attention.md` § MLA Decode —
derive M_tot=B*H*Sq from source shapes; batch KV gather when T uniform; pre-scale Q;
dtype-matched softmax; get_default_kwargs(m,n,k); no .item() in hot loops. For long context,
use the page-tiled online softmax. **Matching fused kernel**: wrap/port it instead of decomposing.
- **Residual connections**: x = x + residual is a PyTorch elementwise add — translate to a simple addf @flyc.kernel.
- **Scalar broadcast ops**: x * scale, x / divisor, x + constant — translate to @flyc.kernel using arith.mulf/arith.divf/arith.addf with vector.broadcast.
- **CRITICAL anti-pattern — Python loops over batch/heads**: NEVER write for b in range(B) or for h in range(H) loops calling GEMM or any FlyDSL kernel per iteration. Instead: reshape all batch*head data into a single 2D tensor and call preshuffle GEMM once, or use flash attention.
- **Priority**: A correct translation with simple standalone kernels and zero fallbacks is ALWAYS better than a partially-fused translation that still has PyTorch fallbacks.
- **Priority**: A correct translation with simple standalone kernels and zero fallbacks is ALWAYS better than a partially-fused translation that still has PyTorch fallbacks — EXCEPT for MLA teaching kernels (nheads≠128): follow `flydsl_translation_attention.md` § MLA Decode (decomposed path); a "simple" per-batch loop at ~4× is worse than the batched optimized decomposed path at ~9×.
- **Complex models**: Use FlyDSL for ALL ops; Conv2d via im2col+GEMM, MaxPool2d via custom kernel, BatchNorm2d via custom kernel

CRITICAL: Do NOT use torch.mm, torch.bmm, torch.matmul, F.linear, nn.Linear, F.conv2d, nn.Conv2d, F.batch_norm, nn.BatchNorm2d, F.max_pool2d, F.scaled_dot_product_attention, torch.ops.aten.convolution, torch.ops.aten.batch_norm, torch.batch_norm, or any torch.ops.aten.* compute op. ALL of these have FlyDSL replacements.
Expand All @@ -90,10 +97,22 @@ agent:
- For Conv2d: im2col (F.unfold) + compile_preshuffle_gemm_a8 (fp16 cast) — no PyTorch compute fallback
- For MaxPool2d: custom @flyc.kernel with arith.maximumf — no F.max_pool2d fallback
- For BatchNorm2d: custom @flyc.kernel — in training mode (default), compute batch mean/var dynamically, then apply scale/shift. No F.batch_norm fallback
- For MLA: paged latent decode → `flydsl_translation_attention.md` § MLA Decode (wrap a matching fused kernel, else decomposed path)
3. Write the FlyDSL translation to the specified output path
4. Test it using the harness command provided
5. If the test fails, read the error output and fix the translation
6. Once the test passes, submit: `echo COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT`
6. Once the test PASSES, do NOT submit the first correct version for
performance-sensitive kernels. A correct but under-optimized kernel is not
done — keep iterating to raise the measured speedup:
- Record the current best harness latency. Each iteration, make ONE targeted
change, re-run the harness, and keep it only if latency improves AND
correctness still passes (otherwise revert to the best version).
- For the optimization levers and any performance target, follow the Knowledge
Base guidance for the detected category.
- Stop optimizing only when the speedup plateaus across consecutive iterations,
then keep the best version. Do not stop at the first passing result just
because it is "performant".
7. Submit the BEST-performing correct version: `echo COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT`

## Important

Expand Down
33 changes: 27 additions & 6 deletions src/minisweagent/skills/pytorch2flydsl-translation/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ This skill provides knowledge and strategy for translating PyTorch GPU kernels t

## Translation Strategy (in order of preference)

- **GEMM / Linear**: Use `compile_preshuffle_gemm_a8()` from `kernels.preshuffle_gemm`.
- **GEMM / Linear (fixed weight)**: Use `compile_preshuffle_gemm_a8()` from `kernels.preshuffle_gemm`.
CRITICAL: B-matrix must be preshuffled with `shuffle_weight(B.contiguous(), layout=(16, 16))` from `tests.utils`.
All tensor args must be `.view(-1)`. Scales: `torch.empty(0, device=dev, dtype=torch.float32)` for fp16.
- **GEMM (dynamic activations, small M)**: Use `hgemm_splitk_()` from `kernels.hgemm_splitk`.
For activation×activation matmuls (e.g. decomposed Q@K^T, attn@V, paged KV) where B is NOT a fixed weight.
`C = A @ B^T` with `A:(M,K)`, `B:(N,K)`. No preshuffle. See `flydsl_translation_gemm.md` section Split-K GEMM (hgemm_splitk).
- **Attention / SDPA**: ALWAYS use `build_flash_attn_func_module()` from `kernels.flash_attn_func`
when head_dim>=64, head_dim%32==0, seq_len%128==0. NEVER decompose attention into separate
GEMM+softmax+GEMM calls when flash attention fits — decomposed is 5-10x slower.
Expand All @@ -26,16 +29,34 @@ This skill provides knowledge and strategy for translating PyTorch GPU kernels t
- **Reductions** (sum, mean): Manual block reduction with wave shuffle
- **Conv/Pool/BatchNorm**: Use `torch.nn.functional` (ONLY ops with no FlyDSL equivalent)
- **Complex models**: Use FlyDSL for ALL ops except conv/pool/batchnorm

CRITICAL: Do NOT use torch.matmul, F.linear, nn.Linear, or F.scaled_dot_product_attention.
These ALL have FlyDSL pre-built replacements. PyTorch fallback is ONLY for Conv2d, MaxPool2d, BatchNorm2d.
- **Decode-mode attention** (`seqlen_q=1` with a paged KV cache: `kv_cache` or
`k_cache`+`v_cache`, plus `block_table`/`page_table` and `cache_seqlens`) — covers
MLA, PagedAttention, and any paged-decode kernel. See
`flydsl_translation_attention.md` § Decode Attention.
If a prebuilt fused kernel matches the source shape, wrap its launcher instead of
generating a kernel. Otherwise decompose with `hgemm_splitk_` + `build_softmax_module`
and apply the decode optimizations (these are what actually move FlyDSL kernel perf):
- Stack all rows into `M_tot = B*H*Sq` and run a single stacked softmax — never
loop per-(batch,head) calling GEMM/softmax one at a time.
- Batch the KV gather / paged-cache reconstruction and any GQA expansion into one
indexed gather, not a Python loop.
- Pre-scale Q in fp16 before the QK GEMM; keep softmax dtype matched to the GEMM output.
- Reuse persistent score/output buffers across calls; read `cache_seqlens` once via
`.tolist()` — never `.item()` in hot loops.
- Use `get_default_kwargs` for GEMM tiling; for long context use a page-tiled online softmax.

CRITICAL:
- Do NOT use torch.matmul, F.linear, nn.Linear, or F.scaled_dot_product_attention.
These ALL have FlyDSL pre-built replacements. PyTorch fallback is ONLY for Conv2d, MaxPool2d, BatchNorm2d.
- Do NOT wrap the kernel in a CUDA graph to obtain the speedup: graph capture only
removes host-side launch overhead and is not a FlyDSL kernel improvement.

## Reference Documentation

The `docs/` subdirectory contains detailed API references and translation guides:

- `flydsl_translation_api_reference.md` — FlyDSL compiler API, expression types, kernel patterns
- `flydsl_translation_guide.md` — PyTorch op mapping, structural patterns, common pitfalls
- `flydsl_translation_gemm.md` — GEMM/Linear translation with preshuffle_gemm
- `flydsl_translation_attention.md` — Attention/SDPA translation with flash_attn
- `flydsl_translation_gemm.md` — GEMM/Linear translation: preshuffle_gemm (fixed weight) + split-K hgemm (dynamic activations / small-M decode)
- `flydsl_translation_attention.md` — Attention/SDPA translation with flash_attn, plus decode-mode paged attention (MLA, PagedAttention; wrap a matching fused kernel, else decomposed path)
- `flydsl_translation_reductions.md` — Reduction ops (sum, mean, softmax, layernorm)
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,8 @@ use `BLOCK_THREADS=256` and `VEC_WIDTH=8`.
| `torch.clamp(x, min=a)` | `arith.maximumf(val, a_const)` | Custom kernel |
| `torch.mean(x)` | Parallel reduction (see reduction patterns) | Custom kernel |
| `torch.softmax(x)` | `build_softmax_module()` | **Pre-built** |
| `torch.matmul(A, B)` | `compile_preshuffle_gemm_a8()` | **Pre-built** |
| `torch.matmul(A, B)` (fixed weight B) | `compile_preshuffle_gemm_a8()` | **Pre-built** |
| `torch.matmul` / `torch.bmm` (both activations, small M) | `hgemm_splitk_()` | **Pre-built** |
| `nn.Linear` | `compile_preshuffle_gemm_a8()` | **Pre-built** |
| `F.linear` | `compile_preshuffle_gemm_a8()` | **Pre-built** |
| `F.scaled_dot_product_attention` | `build_flash_attn_func_module()` | **Pre-built** |
Expand Down
Loading