diff --git a/src/minisweagent/run/preprocess/config/mini_kernel_pytorch_to_flydsl.yaml b/src/minisweagent/run/preprocess/config/mini_kernel_pytorch_to_flydsl.yaml index 4bdd9d4d7..c762cfc9b 100644 --- a/src/minisweagent/run/preprocess/config/mini_kernel_pytorch_to_flydsl.yaml +++ b/src/minisweagent/run/preprocess/config/mini_kernel_pytorch_to_flydsl.yaml @@ -23,8 +23,9 @@ agent: 3. You MUST preserve the exact `Model(nn.Module)` interface: same `__init__`, `forward` signature, same output shape and dtype. 4. You MUST preserve `get_inputs()` and `get_init_inputs()` functions. 5. The translated kernel MUST produce numerically identical results to the PyTorch original (within tolerance). - 6. Use the `save_and_test` tool to validate your translation after writing it. - 7. Every response must contain exactly one action. + 6. You MUST preserve ALL conditional logic from the original PyTorch kernel, including branches that appear to be no-ops for the current test configuration. + 7. Use the `save_and_test` tool to validate your translation after writing it. + 8. Every response must contain exactly one action. ## FlyDSL Kernel Structure @@ -60,10 +61,16 @@ agent: - **MaxPool2d**: Custom @flyc.kernel with arith.maximumf over window elements. Do NOT fall back to F.max_pool2d. - **BatchNorm2d**: Custom @flyc.kernel. In eval mode, pre-compute scale/shift in __init__. In training mode (default — harness does NOT call .eval()), compute batch mean/var dynamically: mean=x.mean(dim=(0,2,3)), var=x.var(dim=(0,2,3), unbiased=False), then scale=weight/sqrt(var+eps), shift=bias-mean*scale, apply per-channel affine. Do NOT use F.batch_norm, torch.batch_norm, or torch.ops.aten.batch_norm. - **Bias addition after GEMM**: Use fused epilogue (epilogue="bias") in compile_preshuffle_gemm_a8 when possible. Otherwise translate to a simple @flyc.kernel (addf over the output and bias vectors). + - **MLA decode (paged latent attention)**: latent KV cache (`kv_cache` + `block_table`), asymmetric qk/v head dims. + If a prebuilt fused MLA kernel matches the source shape, wrap its launcher (do not generate a kernel) and do not decompose the attention core. + **MLA decomposed path** (no matching fused kernel): see `flydsl_translation_attention.md` § MLA Decode — + derive M_tot=B*H*Sq from source shapes; batch KV gather when T uniform; pre-scale Q; + dtype-matched softmax; get_default_kwargs(m,n,k); no .item() in hot loops. For long context, + use the page-tiled online softmax. **Matching fused kernel**: wrap/port it instead of decomposing. - **Residual connections**: x = x + residual is a PyTorch elementwise add — translate to a simple addf @flyc.kernel. - **Scalar broadcast ops**: x * scale, x / divisor, x + constant — translate to @flyc.kernel using arith.mulf/arith.divf/arith.addf with vector.broadcast. - **CRITICAL anti-pattern — Python loops over batch/heads**: NEVER write for b in range(B) or for h in range(H) loops calling GEMM or any FlyDSL kernel per iteration. Instead: reshape all batch*head data into a single 2D tensor and call preshuffle GEMM once, or use flash attention. - - **Priority**: A correct translation with simple standalone kernels and zero fallbacks is ALWAYS better than a partially-fused translation that still has PyTorch fallbacks. + - **Priority**: A correct translation with simple standalone kernels and zero fallbacks is ALWAYS better than a partially-fused translation that still has PyTorch fallbacks — EXCEPT for MLA teaching kernels (nheads≠128): follow `flydsl_translation_attention.md` § MLA Decode (decomposed path); a "simple" per-batch loop at ~4× is worse than the batched optimized decomposed path at ~9×. - **Complex models**: Use FlyDSL for ALL ops; Conv2d via im2col+GEMM, MaxPool2d via custom kernel, BatchNorm2d via custom kernel CRITICAL: Do NOT use torch.mm, torch.bmm, torch.matmul, F.linear, nn.Linear, F.conv2d, nn.Conv2d, F.batch_norm, nn.BatchNorm2d, F.max_pool2d, F.scaled_dot_product_attention, torch.ops.aten.convolution, torch.ops.aten.batch_norm, torch.batch_norm, or any torch.ops.aten.* compute op. ALL of these have FlyDSL replacements. @@ -90,10 +97,22 @@ agent: - For Conv2d: im2col (F.unfold) + compile_preshuffle_gemm_a8 (fp16 cast) — no PyTorch compute fallback - For MaxPool2d: custom @flyc.kernel with arith.maximumf — no F.max_pool2d fallback - For BatchNorm2d: custom @flyc.kernel — in training mode (default), compute batch mean/var dynamically, then apply scale/shift. No F.batch_norm fallback + - For MLA: paged latent decode → `flydsl_translation_attention.md` § MLA Decode (wrap a matching fused kernel, else decomposed path) 3. Write the FlyDSL translation to the specified output path 4. Test it using the harness command provided 5. If the test fails, read the error output and fix the translation - 6. Once the test passes, submit: `echo COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT` + 6. Once the test PASSES, do NOT submit the first correct version for + performance-sensitive kernels. A correct but under-optimized kernel is not + done — keep iterating to raise the measured speedup: + - Record the current best harness latency. Each iteration, make ONE targeted + change, re-run the harness, and keep it only if latency improves AND + correctness still passes (otherwise revert to the best version). + - For the optimization levers and any performance target, follow the Knowledge + Base guidance for the detected category. + - Stop optimizing only when the speedup plateaus across consecutive iterations, + then keep the best version. Do not stop at the first passing result just + because it is "performant". + 7. Submit the BEST-performing correct version: `echo COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT` ## Important diff --git a/src/minisweagent/skills/pytorch2flydsl-translation/SKILL.md b/src/minisweagent/skills/pytorch2flydsl-translation/SKILL.md index 0f393c6c0..afd8469df 100644 --- a/src/minisweagent/skills/pytorch2flydsl-translation/SKILL.md +++ b/src/minisweagent/skills/pytorch2flydsl-translation/SKILL.md @@ -9,9 +9,12 @@ This skill provides knowledge and strategy for translating PyTorch GPU kernels t ## Translation Strategy (in order of preference) -- **GEMM / Linear**: Use `compile_preshuffle_gemm_a8()` from `kernels.preshuffle_gemm`. +- **GEMM / Linear (fixed weight)**: Use `compile_preshuffle_gemm_a8()` from `kernels.preshuffle_gemm`. CRITICAL: B-matrix must be preshuffled with `shuffle_weight(B.contiguous(), layout=(16, 16))` from `tests.utils`. All tensor args must be `.view(-1)`. Scales: `torch.empty(0, device=dev, dtype=torch.float32)` for fp16. +- **GEMM (dynamic activations, small M)**: Use `hgemm_splitk_()` from `kernels.hgemm_splitk`. + For activation×activation matmuls (e.g. decomposed Q@K^T, attn@V, paged KV) where B is NOT a fixed weight. + `C = A @ B^T` with `A:(M,K)`, `B:(N,K)`. No preshuffle. See `flydsl_translation_gemm.md` section Split-K GEMM (hgemm_splitk). - **Attention / SDPA**: ALWAYS use `build_flash_attn_func_module()` from `kernels.flash_attn_func` when head_dim>=64, head_dim%32==0, seq_len%128==0. NEVER decompose attention into separate GEMM+softmax+GEMM calls when flash attention fits — decomposed is 5-10x slower. @@ -26,9 +29,27 @@ This skill provides knowledge and strategy for translating PyTorch GPU kernels t - **Reductions** (sum, mean): Manual block reduction with wave shuffle - **Conv/Pool/BatchNorm**: Use `torch.nn.functional` (ONLY ops with no FlyDSL equivalent) - **Complex models**: Use FlyDSL for ALL ops except conv/pool/batchnorm - -CRITICAL: Do NOT use torch.matmul, F.linear, nn.Linear, or F.scaled_dot_product_attention. -These ALL have FlyDSL pre-built replacements. PyTorch fallback is ONLY for Conv2d, MaxPool2d, BatchNorm2d. +- **Decode-mode attention** (`seqlen_q=1` with a paged KV cache: `kv_cache` or + `k_cache`+`v_cache`, plus `block_table`/`page_table` and `cache_seqlens`) — covers + MLA, PagedAttention, and any paged-decode kernel. See + `flydsl_translation_attention.md` § Decode Attention. + If a prebuilt fused kernel matches the source shape, wrap its launcher instead of + generating a kernel. Otherwise decompose with `hgemm_splitk_` + `build_softmax_module` + and apply the decode optimizations (these are what actually move FlyDSL kernel perf): + - Stack all rows into `M_tot = B*H*Sq` and run a single stacked softmax — never + loop per-(batch,head) calling GEMM/softmax one at a time. + - Batch the KV gather / paged-cache reconstruction and any GQA expansion into one + indexed gather, not a Python loop. + - Pre-scale Q in fp16 before the QK GEMM; keep softmax dtype matched to the GEMM output. + - Reuse persistent score/output buffers across calls; read `cache_seqlens` once via + `.tolist()` — never `.item()` in hot loops. + - Use `get_default_kwargs` for GEMM tiling; for long context use a page-tiled online softmax. + +CRITICAL: + - Do NOT use torch.matmul, F.linear, nn.Linear, or F.scaled_dot_product_attention. + These ALL have FlyDSL pre-built replacements. PyTorch fallback is ONLY for Conv2d, MaxPool2d, BatchNorm2d. + - Do NOT wrap the kernel in a CUDA graph to obtain the speedup: graph capture only + removes host-side launch overhead and is not a FlyDSL kernel improvement. ## Reference Documentation @@ -36,6 +57,6 @@ The `docs/` subdirectory contains detailed API references and translation guides - `flydsl_translation_api_reference.md` — FlyDSL compiler API, expression types, kernel patterns - `flydsl_translation_guide.md` — PyTorch op mapping, structural patterns, common pitfalls -- `flydsl_translation_gemm.md` — GEMM/Linear translation with preshuffle_gemm -- `flydsl_translation_attention.md` — Attention/SDPA translation with flash_attn +- `flydsl_translation_gemm.md` — GEMM/Linear translation: preshuffle_gemm (fixed weight) + split-K hgemm (dynamic activations / small-M decode) +- `flydsl_translation_attention.md` — Attention/SDPA translation with flash_attn, plus decode-mode paged attention (MLA, PagedAttention; wrap a matching fused kernel, else decomposed path) - `flydsl_translation_reductions.md` — Reduction ops (sum, mean, softmax, layernorm) diff --git a/src/minisweagent/skills/pytorch2flydsl-translation/docs/flydsl_translation_api_reference.md b/src/minisweagent/skills/pytorch2flydsl-translation/docs/flydsl_translation_api_reference.md index 421305e5a..d6e00902e 100644 --- a/src/minisweagent/skills/pytorch2flydsl-translation/docs/flydsl_translation_api_reference.md +++ b/src/minisweagent/skills/pytorch2flydsl-translation/docs/flydsl_translation_api_reference.md @@ -408,7 +408,8 @@ use `BLOCK_THREADS=256` and `VEC_WIDTH=8`. | `torch.clamp(x, min=a)` | `arith.maximumf(val, a_const)` | Custom kernel | | `torch.mean(x)` | Parallel reduction (see reduction patterns) | Custom kernel | | `torch.softmax(x)` | `build_softmax_module()` | **Pre-built** | -| `torch.matmul(A, B)` | `compile_preshuffle_gemm_a8()` | **Pre-built** | +| `torch.matmul(A, B)` (fixed weight B) | `compile_preshuffle_gemm_a8()` | **Pre-built** | +| `torch.matmul` / `torch.bmm` (both activations, small M) | `hgemm_splitk_()` | **Pre-built** | | `nn.Linear` | `compile_preshuffle_gemm_a8()` | **Pre-built** | | `F.linear` | `compile_preshuffle_gemm_a8()` | **Pre-built** | | `F.scaled_dot_product_attention` | `build_flash_attn_func_module()` | **Pre-built** | diff --git a/src/minisweagent/skills/pytorch2flydsl-translation/docs/flydsl_translation_attention.md b/src/minisweagent/skills/pytorch2flydsl-translation/docs/flydsl_translation_attention.md index 288a7caf3..f48e5a02f 100644 --- a/src/minisweagent/skills/pytorch2flydsl-translation/docs/flydsl_translation_attention.md +++ b/src/minisweagent/skills/pytorch2flydsl-translation/docs/flydsl_translation_attention.md @@ -1,8 +1,8 @@ --- layer: "flydsl" category: "translation" -tags: ["flydsl", "translation", "attention", "transformer", "flash-attention"] -last_updated: 2026-03-23 +tags: ["flydsl", "translation", "attention", "transformer", "flash-attention", "mla", "decode", "paged-kv"] +last_updated: 2026-06-10 --- # FlyDSL Translation: Attention Patterns @@ -135,30 +135,55 @@ class Model(nn.Module): ### Strategy 3: Decomposed Attention with Pre-built Kernels -ONLY when padding is impractical (e.g., very large padding ratios), +ONLY when padding is impractical (e.g., very large padding ratios, paged KV), decompose into FlyDSL pre-built components. NEVER use `F.scaled_dot_product_attention`. +Use a **mixed strategy**: +- **`hgemm_splitk_`** for activation@activation matmuls (`Q@K^T`, `attn@V`). +- **`compile_preshuffle_gemm_a8`** for fixed-weight projections (`x@W_qkv`, `out@W_proj`). +- Do **not** preshuffle dynamic K/V cache tensors each forward. +- For paged decode (MLA latent cache or PagedAttention K/V cache), see § Decode Attention + below: wrap a matching prebuilt fused kernel when one exists, otherwise decompose (batch + KV gather, pre-scale Q, stacked `batch*nheads`, f16 softmax, `SPLIT_K=1` to start; + page-tiled online softmax for long context). Do not preshuffle dynamic K/V. + ```python +import torch +import torch.nn as nn +from kernels.hgemm_splitk import hgemm_splitk_ from kernels.preshuffle_gemm import compile_preshuffle_gemm_a8 from kernels.softmax_kernel import build_softmax_module from tests.utils import shuffle_weight class Model(nn.Module): + def __init__(self, n_embd): + super().__init__() + # Fixed-weight projection path -> preshuffle GEMM + self.w_qkv = nn.Parameter(torch.randn(3 * n_embd, n_embd, dtype=torch.float16)) + self.register_buffer("w_qkv_shuffled", shuffle_weight(self.w_qkv.data.contiguous(), layout=(16, 16))) + self.qkv_gemm = compile_preshuffle_gemm_a8( + M=0, N=3 * n_embd, K=n_embd, + tile_m=64, tile_n=128, tile_k=128, + in_dtype="fp16", out_dtype="fp16", lds_stage=2 + ) + def forward(self, q, k, v): - # QK^T via FlyDSL GEMM (after preshuffling K^T) - scores = ... # use compile_preshuffle_gemm_a8 + # (Optional) projection example: x -> qkv via preshuffle GEMM + # self.qkv_gemm(x_2d.view(-1), self.w_qkv_shuffled.view(-1), ..., stream=stream) + + # QK^T: C = Q @ K^T — q: (M, K), k: (N, K) with N=seq_len + hgemm_splitk_(scores, q_flat, k, hgemm_kwargs=self._gemm_kwargs, stream=stream) # Softmax via FlyDSL - scores_2d = scores.reshape(-1, N) - softmax_fn = build_softmax_module(scores_2d.shape[0], N, "f32") - attn_weights = torch.empty_like(scores_2d) - softmax_fn(scores_2d, attn_weights, scores_2d.shape[0], - stream=torch.cuda.current_stream()) - - # V projection via FlyDSL GEMM - return ... # use compile_preshuffle_gemm_a8 + self._softmax(scores, attn, M, stream=stream) + + # attn @ V: v_t = v.t() — (V_dim, seq_len) + hgemm_splitk_(out, attn, v_t, hgemm_kwargs=self._gemm_kwargs, stream=stream) + return out ``` +See `flydsl_translation_gemm.md` § Split-K GEMM for shapes, tile config, and MLA examples. + ## Causal Masking The FlyDSL flash attention kernel supports causal masking natively via `causal=True` @@ -253,7 +278,8 @@ are dynamic activations that change every forward pass. | Output projection (attn_out @ W_proj) | attn_out=activation, W=fixed weight | `compile_preshuffle_gemm_a8` (preshuffle W once) | | Q @ K^T (attention scores) | Q=activation, K=activation | `build_flash_attn_func_module` (handles Q@K^T + softmax + @V) | | att @ V (attention output) | att=activation, V=activation | `build_flash_attn_func_module` (part of flash attention) | -| Activation @ activation (no flash attn fit) | both dynamic | `torch.bmm` as fallback (acceptable for fp32 or non-standard shapes) | +| Activation @ activation (no flash attn fit) | both dynamic, fp16/bf16 | `hgemm_splitk_` — see `flydsl_translation_gemm.md` § Split-K GEMM | +| Activation @ activation (fp32 or rare shapes) | both dynamic | `torch.bmm` only if FlyDSL path unavailable | ### When torch.bmm is acceptable @@ -278,6 +304,125 @@ preshuffle_gemm(scores, Q, K_shuffled, ...) # defeats the purpose of preshuffli Preshuffling is a heavyweight operation designed to be done **once** at init. Calling it every forward pass adds overhead that far exceeds any GEMM speedup. +## Decode Attention (Paged: MLA & PagedAttention) + +Decode-mode attention reads a **paged KV cache** through a `block_table` / +`page_table` with `seqlen_q == 1`. This covers MLA (latent cache) and PagedAttention +(standard / GQA KV cache); both share the same decomposed FlyDSL strategy and +optimizations. This is **not** standard BSHD flash attention — do **not** use +`build_flash_attn_func_module` for it. + +- **MLA** reads a latent cache where each row is a single compressed vector: K uses the + full `headdim_qk`, and V is the **leading `headdim_v` slice of that same row** + (`headdim_v <= headdim_qk`). Signals: a `MultiHeadLatentAttention`-style module, + `kv_cache` + `block_table` + `cache_seqlens`, asymmetric `headdim_qk` / `headdim_v`. +- **PagedAttention** reads separate `k_cache`/`v_cache` with symmetric `headdim` and may + use GQA (`nheads_q % nheads_kv == 0`). Reconstruct the cache and expand KV heads by the + group count in a single batched gather before the GEMMs. + +Define shape symbols from the source (not a fixed benchmark): `B`=batch, `Sq`=seqlen_q, +`H`=nheads, `Dqk`/`Dv`=QK/V head dims (equal for PagedAttention), `T`=cache length, +`M_row = H*Sq`, `M_tot = B*M_row`. + +**Do NOT wrap the kernel in a CUDA graph to report the speedup.** Graph capture only +removes host-side launch overhead — it is not a FlyDSL kernel improvement, and it makes +the comparison against the (non-captured) PyTorch baseline misleading. Measure the FlyDSL +kernel WITHOUT CUDA graphs so the uplift reflects the translation itself. The real +kernel-level wins come from the optimizations below. + +### Strategy + +1. **Reuse a prebuilt fused MLA kernel when one matches the shape.** Fused MLA + kernels bake head count, head dims, page size, and dtype in at compile time, so + they only apply when the source shape matches. When one fits, wrap its launcher + (build/cache any metadata buffers in `__init__`) instead of writing a kernel. +2. **Otherwise decompose** with `hgemm_splitk_` (for `Q@K^T` and `attn@V`) + + `build_softmax_module`, applying the optimizations below. +3. **For long context / memory pressure**, use a **page-tiled online softmax** so the + full `(rows, seq)` attention matrix is never materialized (see below). + +### Decomposed path: key optimizations + +Decode (`Sq == 1`) MLA is **memory-bound** and **launch-bound** — `M_tot = B*M_row` +is tiny (e.g. 64 rows), so runtime is dominated by the **number of kernel launches** +and KV-cache bytes read, not FLOPs. A naive per-batch decomposition stalls at a few×; +the items below are what take it to ~10×. The first three are the highest-leverage. + +| Optimization | What to do | +|--------------|------------| +| **Stack the softmax (biggest win)** | Write every batch's QK result into **one** `(M_tot, T)` scores buffer and run a **single** `build_softmax_module(M_tot, T)` over all `B*M_row` rows. Even when a per-batch `block_table` forces per-batch `Q@K^T` / `attn@V` GEMMs, the softmax must stay one stacked launch — **never** call `build_softmax_module(M_row, T)` once per batch inside the loop. | +| **Gather + transpose outside the loop** | When `cache_seqlens` are uniform, gather all KV once via `block_table` → `(B, T, Dqk)`, and build the batched `V^T` → `(B, Dv, T)` **before** the GEMM loop. Never `.t().contiguous()` / `.contiguous()` per batch inside the loop. | +| **Fewer launches** | Aim for ~`2B + 1` launches (per-batch QK + 1 stacked softmax + per-batch PV), not `~3B` + extra copies. Per-batch loops are only acceptable for the QK/PV GEMMs (because K differs per request), not for softmax or layout ops. | +| **No host sync** | Index `block_table` / `cache_seqlens` on the GPU; never `.item()` / `.tolist()` inside the loop. | +| **Pre-scale Q once** | Fold `1/sqrt(Dqk)` into Q before `Q@K^T`; avoid per-batch `scores.float().mul_(scale)`. | +| **Right output dtype + no layout copy** | Allocate the output in the GEMM dtype and write GEMM results straight into it; avoid a trailing `output.to(dtype)` cast. When `Sq == 1` the `(B, M_row, Dv)` GEMM output is already the `(B, Sq, H, Dv)` layout — reshape, don't `transpose(...).contiguous()`. | +| **Match dtypes** | Use an f16/bf16 `build_softmax_module` to match `hgemm_splitk_` I/O; reserve fp32 accumulation for the online path. | +| **Tune tiles from shapes** | Derive `(M, N, K)` per GEMM (QK: `(M_tot, T, Dqk)`, PV: `(M_row, Dv, T)`); start `SPLIT_K=1` at small M and profile up. When `N` (=`T` or `Dv`) is a multiple of 128, prefer `TILE_N=128` over 64 — this is typically the ~7×→~10× jump for decode. | +| **Reuse buffers** | One `(M, seq)` buffer for scores→attn (in-place softmax); skip the mask entirely when `Sq == 1`. | +| **Lazy-compile** | Cache compiled softmax / GEMM modules keyed by shape to avoid re-JIT every forward. | + +Concrete decode skeleton (uniform `T`; this stacked structure is what reaches ~10×): + +```python +# Sq == 1, M_row = H, M_tot = B*M_row +q_scaled = (q.half() * scale).reshape(B, M_row, Dqk) # pre-scale once +kv = gather_pages(kv_cache, block_table, T) # (B, T, Dqk) — one gather +vt = kv[:, :, :Dv].transpose(1, 2).contiguous() # (B, Dv, T) — batched once + +scores = empty(M_tot, T, fp16) +for b in range(B): # per-batch QK only (K differs) + hgemm_splitk_(scores[b*M_row:(b+1)*M_row], q_scaled[b], kv[b], hgemm_kwargs=qk_kwargs) + +softmax_fn(scores, attn, M_tot) # ONE stacked softmax launch + +for b in range(B): # per-batch PV only + hgemm_splitk_(out[b*M_row:(b+1)*M_row], attn[b*M_row:(b+1)*M_row], vt[b], hgemm_kwargs=pv_kwargs) +``` + +For prefill (`Sq > 1`), `M_row = H*Sq` grows and the causal mask is live; raise +`TILE_M` and re-profile `SPLIT_K` for the larger row count. + +### Page-tiled online softmax (long context) + +Iterate the KV cache in page-sized tiles and keep a running softmax state per row, so +no full-sequence attention buffer is allocated: + +```python +m = full((M, 1), -inf, f32) # running max +l = zeros((M, 1), f32) # running exp-sum +o = zeros((M, Dv), f32) # running output + +for each page tile t: + K_t, V_t = load_tile_from_paged_cache(...) + scores_t = hgemm_splitk(Q, K_t) * scale + m_new = maximum(m, rowmax(scores_t)) + alpha = exp(m - m_new) + p_t = exp(scores_t - m_new) + l = alpha * l + rowsum(p_t) + o = alpha * o + hgemm_splitk(p_t, V_t^T) + m = m_new + +out = o / l +``` + +Keep `m/l/o` in fp32 here. A single fused `@flyc.kernel` that runs this loop +internally (page loop + online `m/l/o` + PV) is usually the largest win for decode, +since it collapses the whole attention into one launch. + +### Anti-patterns + +1. Using flash attention (`build_flash_attn_func_module`) for paged MLA — wrong API and layout. +2. Per-batch **softmax** or **layout** ops inside the loop (`for b: build_softmax_module(M_row, T)`, `for b: v.t().contiguous()`) — stack the softmax into one `(M_tot, T)` launch and batch the gather/transpose outside the loop. (Per-batch QK/PV GEMMs are fine because K differs per request; softmax and layout are not.) +3. Host syncs in `forward` (`.item()`, `.tolist()`) that serialize host↔device every step. +4. Preshuffling K/V with `shuffle_weight` every forward — preshuffle is for fixed weights only. +5. Standalone scale / mask passes and double KV copies instead of folding them in. +6. f32 softmax feeding an f16 GEMM in the decomposed path (avoidable dtype cast). +7. Forcing a shape-specialized fused kernel onto a different shape by editing its baked-in constants — that is a different kernel; decompose or write a shape-appropriate kernel instead. + +### Reference + +- Split-K GEMM shapes and tile config for the decomposed path: `flydsl_translation_gemm.md` § Split-K GEMM. + ## Decision Summary ``` @@ -290,10 +435,17 @@ Matmul type? │ │ └── build_flash_attn_func_module() [NO F.scaled_dot_product_attention] │ ├── Non-standard dims │ │ └── Pad Q/K/V, run flash attention, slice back +│ ├── Paged decode (seqlen_q=1: MLA kv_cache or PagedAttention k/v_cache + block_table) +│ │ ├── Prebuilt fused kernel matches the shape → wrap its launcher [see § Decode Attention] +│ │ └── No match → decompose: hgemm_splitk_ + build_softmax_module [see § Decode Attention] +│ ├── Flash infeasible (paged KV, non-BSHD) +│ │ ├── Baseline: hgemm_splitk_ + build_softmax_module +│ │ └── Preferred: page-tiled online softmax (`m/l/o` + tile hgemm_splitk_) +│ │ [see § Decode Attention; NO shuffle_weight on K/V] │ └── Non-softmax attention (e.g., ReLU-attention) -│ └── torch.bmm is acceptable for activation-activation matmuls +│ └── hgemm_splitk_ or torch.bmm for activation-activation matmuls ├── Activation @ activation (non-attention, both sides dynamic) -│ └── torch.bmm as fallback (DO NOT preshuffle activations) +│ └── hgemm_splitk_ (fp16/bf16); torch.bmm only if FlyDSL unavailable └── Causal masking └── FlyDSL flash attention supports causal=True natively ``` diff --git a/src/minisweagent/skills/pytorch2flydsl-translation/docs/flydsl_translation_gemm.md b/src/minisweagent/skills/pytorch2flydsl-translation/docs/flydsl_translation_gemm.md index 79bf6b04d..e408f3971 100644 --- a/src/minisweagent/skills/pytorch2flydsl-translation/docs/flydsl_translation_gemm.md +++ b/src/minisweagent/skills/pytorch2flydsl-translation/docs/flydsl_translation_gemm.md @@ -1,8 +1,8 @@ --- layer: "flydsl" category: "translation" -tags: ["flydsl", "translation", "gemm", "matmul", "linear"] -last_updated: 2026-03-23 +tags: ["flydsl", "translation", "gemm", "matmul", "linear", "hgemm", "splitk", "decode"] +last_updated: 2026-06-09 --- # FlyDSL Translation: GEMM / Matrix Multiplication @@ -107,19 +107,25 @@ as separate operations after the GEMM: - **Fused bias+activation**: write a single `@flyc.kernel` that computes `output = max(0, gemm_output + bias)` in one pass -### Alternative: hgemm_splitk (FP16 SplitK GEMM) +### Alternative: hgemm_splitk (FP16/BF16 Split-K GEMM) -For small M (e.g., batch_size=1 decode), standard tile configs may not -fill the GPU. `hgemm_splitk` splits the K dimension across thread blocks: +For **dynamic activation × activation** matmuls (especially **small M**, e.g. +decode with `M = seqlen_q * num_heads`), use `hgemm_splitk_` instead of +preshuffle GEMM. It does **not** require `shuffle_weight`; B can change every +forward (paged KV, attention scores). -```python -from kernels.hgemm_splitk import compile_hgemm_splitk -``` +Use when: +- Both operands are activations (not a fixed weight to preshuffle once) +- `M` is small and standard preshuffle `tile_m` under-fills the GPU +- Flash attention does not apply (paged cache, MLA, non-BSHD layout) -Use when M < tile_m and standard GEMM underperforms. Only available in -newer FlyDSL versions — check availability before using. +**Full API, constraints, tile guide, and attention examples** are documented in +the [§ Split-K GEMM (hgemm_splitk)](#split-k-gemm-hgemm_splitk-dynamic-activations--small-m-decode) +section below. -### Constraints +### Constraints (preshuffle GEMM only) + +The following apply to `compile_preshuffle_gemm_a8`, **not** to `hgemm_splitk_`: - `tile_k * elem_bytes` must be divisible by 64 - `M` and `N` can be 0 (dynamic) — resolved at launch time @@ -180,6 +186,195 @@ def get_init_inputs(): return [4096, 4096] ``` +## Split-K GEMM (hgemm_splitk): Dynamic Activations / Small-M Decode + +Use `hgemm_splitk` from `kernels.hgemm_splitk` when **both operands are dynamic +activations** (not fixed weights) and preshuffle GEMM does not apply. + +| Scenario | Use | +|----------|-----| +| `nn.Linear`, fixed weight `W` | Preshuffle GEMM (`compile_preshuffle_gemm_a8` + `shuffle_weight`, once in `__init__`) — see above | +| Standard SDPA (contiguous Q/K/V, head_dim/seq constraints) | `build_flash_attn_func_module()` | +| **Activation @ activation**, small **M** (decode, few rows) | **`hgemm_splitk_`** | +| **Activation @ activation**, both sides change every forward (paged KV, attention scores) | **`hgemm_splitk_`** | +| Large M, static B weight | Preshuffle GEMM | + +**Do NOT** call `shuffle_weight` on K/V every forward pass to force preshuffle GEMM. +Preshuffle is weight-stationary; per-forward shuffling defeats its purpose. + +Typical shapes: decode attention (`seqlen_q=1`, `M = seqlen_q * num_heads` small), +MLA with paged KV cache, batched matmul where B varies per batch element. + +**Exception:** For paged decode (MLA latent cache with asymmetric qk/v dims, or +PagedAttention k/v cache), see `flydsl_translation_attention.md` § Decode Attention — wrap +a matching prebuilt fused kernel when one exists, otherwise use the decomposed split-K +path described there. + +### Math and Layout + +The kernel computes: + +``` +C = A @ B^T +``` + +| Tensor | Shape | Role | +|--------|-------|------| +| `A` (`a`) | `(M, K)` | Left operand (e.g. Q flattened over heads) | +| `B` (`b`) | `(N, K)` | Right operand stored row-major as `(N, K)` — **not** transposed | +| `C` (`c`) | `(M, N)` | Output (pre-allocated) | + +Equivalent PyTorch: `torch.mm(A, B.T)` or `A @ B.transpose(-2, -1)` when `B` is `(N, K)`. + +For **Q @ K^T** with `K` of shape `(seq_len, K_dim)`, pass `B = K` (already `(N, K)`). + +For **attn @ V** with `V` of shape `(seq_len, V_dim)`, transpose first: + +```python +vt = v.t() # (V_dim, seq_len) — here N=V_dim, K=seq_len +hgemm_splitk_(out, attn, vt, hgemm_kwargs=kwargs, stream=stream) +``` + +### High-Level API (Preferred) + +```python +from kernels.hgemm_splitk import hgemm_splitk_ + +# C, A, B: fp16 or bf16, CUDA. Shapes as above. +hgemm_splitk_( + c, # (M, N) output, pre-allocated + a, # (M, K) + b, # (N, K) + bias=None, # optional (N,) — rarely used in translations + hgemm_kwargs={...}, # tile config; see below + stream=torch.cuda.current_stream(), +) +``` + +- JIT-compiles on first call for each `(dtype, N, K, **hgemm_kwargs)` tuple (cached). +- `M` is dynamic at launch time; **`N` and `K` are fixed at compile time** (from `b.shape`). +- No preshuffling, no scale tensors, no `.view(-1)` requirement (internally reshapes to 2D). +- `get_default_kwargs(m, n, k)` supplies tuned tiles for common LLM shapes; override via `hgemm_kwargs`. + +### Low-Level API + +For repeated launches with the same `(N, K)` and tile config, compile once: + +```python +from kernels.hgemm_splitk import compile_hgemm_kernel, get_semaphore + +launch_fn = compile_hgemm_kernel( + "f16", # or "bf16" + n=N, k=K, # fixed at compile time + TILE_M=16, TILE_N=128, TILE_K=64, + SPLIT_K=1, + BLOCK_M_WARPS=1, BLOCK_N_WARPS=2, BLOCK_K_WARPS=1, + B_TO_LDS=True, + HAS_BIAS=False, +) +semaphore, signal = get_semaphore(stream, device) +launch_fn(c, a, b, bias_placeholder, m, semaphore, signal, stream=stream) +``` + +Prefer `hgemm_splitk_()` unless you need explicit compile caching control. + +### Tile Configuration (`hgemm_kwargs`) + +| Key | Meaning | +|-----|---------| +| `TILE_M` | M-dimension tile (16 for small decode M) | +| `TILE_N` | N-dimension tile; **`N` must be divisible by `TILE_N`** | +| `TILE_K` | K-dimension tile; **`K` must be divisible by `TILE_K * SPLIT_K` logic** | +| `SPLIT_K` | Split K across blocks (>1 improves occupancy for large K, small M) | +| `BLOCK_M_WARPS` | Warps along M (product with N/K warps ≤ 8) | +| `BLOCK_N_WARPS` | Warps along N | +| `BLOCK_K_WARPS` | Warps along K (K-slicing within block) | +| `B_TO_LDS` | Stage B matrix in LDS (often `True`) | + +#### Recommended starting points + +| M range | `TILE_M` | Notes | +|---------|----------|-------| +| 1–16 | 16 | Decode / few query rows | +| 17–64 | 32 or 64 | Medium batch | +| 64+ | 64–128 | May still use splitk; compare vs preshuffle if B is static | + +`TILE_N`: 64 or 128 (must divide `N`). `TILE_K`: 64 or 128/256 for large K. + +**Decode (small M, large K): try `SPLIT_K > 1` when profiling shows benefit.** With +`M ≈ 16` (single batch slice) and `K ≥ 512`, split-K can raise occupancy. With +**stacked** `M = batch * nheads` (e.g. 64) in decomposed MLA teaching kernels, +**start with `SPLIT_K=1`** — semaphore sync often dominates at tiny per-launch M. +See `flydsl_translation_attention.md` § Decode Attention (decomposed path). + +Example (single-batch MLA decode, `M=16`, `N=512`, `K=576` — profile both): + +```python +gemm_kwargs = { + "TILE_M": 16, "TILE_N": 128, "TILE_K": 64, + "SPLIT_K": 2, # >1 for small-M / large-K decode occupancy + "BLOCK_M_WARPS": 1, "BLOCK_N_WARPS": 2, + "BLOCK_K_WARPS": 1, "B_TO_LDS": True, +} +``` + +### Constraints (split-K GEMM only) + +- **Dtypes**: `torch.float16` or `torch.bfloat16` only. +- **`N % TILE_N == 0`** (compile-time `n` from `b.shape[0]`). +- **`K`**: must satisfy divisibility checks in `compile_hgemm_kernel` (`K % TILE_K`, split-K splits, etc.). +- **`M`**: dynamic; partial final M-tile handled in kernel. +- **GPU arch**: tested on `gfx942`, `gfx950` (see FlyDSL `test_hgemm_splitk.py`). +- **No preshuffle**: unlike `compile_preshuffle_gemm_a8`, B is used as-is. +- When `SPLIT_K > 1`, internal semaphore buffer size limits grid (`bm * bn`). + +### Decomposed Attention Pattern + +Use when flash attention does not apply (paged KV, non-standard head dims, MLA, etc.): + +```python +from kernels.hgemm_splitk import hgemm_splitk_ +from kernels.softmax_kernel import build_softmax_module + +# 1) scores = Q @ K^T — A: (M, K_qk), B: K as (seq, K_qk) -> (N=seq, K=K_qk) +hgemm_splitk_(scores, q_flat, k, hgemm_kwargs=gemm_kwargs, stream=stream) + +# 2) scale + mask (element-wise / PyTorch structural ops) + +# 3) softmax +softmax_fn(scores, attn, M, stream=stream) + +# 4) out = attn @ V — V^T as (N=v_dim, K=seq) +vt = v.t().contiguous() +hgemm_splitk_(out, attn, vt, hgemm_kwargs=gemm_kwargs, stream=stream) +``` + +Store `gemm_kwargs` on the module; compile once per `(N, K)` shape if `seq_len` is bounded +(use `max_seq_len` buffers and slice, as in MLA translations). + +### Preshuffle GEMM vs Split-K GEMM + +| | Preshuffle GEMM | hgemm_splitk | +|--|-----------------|--------------| +| B operand | Fixed weight, shuffled once | Any `(N, K)` tensor, no shuffle | +| Best for | Linear, conv GEMM | Dynamic activations, small M | +| M | Dynamic (`M=0` in compile) | Dynamic at launch | +| N, K | Dynamic at launch | **Fixed at compile** from `b.shape` | +| Scales | Required (`empty(0)` for fp16) | Not used | +| Launch args | All `.view(-1)` | 2D tensors OK | + +### Split-K Pitfalls + +1. **Wrong B layout**: Pass `(N, K)`, not `(K, N)`, unless you explicitly transpose to match `C = A @ B^T`. +2. **Recompile every forward**: Changing `N` or `K` (e.g. growing `seq_len` past compile bounds) triggers new JIT. Pre-allocate for `max_seq_len` and slice. +3. **Using preshuffle for K/V**: Do not `shuffle_weight` on cache tensors each step. +4. **Large M + static weight**: Use preshuffle GEMM instead for better throughput. +5. **Flash-eligible SDPA**: If Q/K/V are contiguous BSHD and constraints hold, flash attention beats decomposed splitk. + +### Split-K Reference Implementations + +- FlyDSL tests: `FlyDSL/tests/kernels/test_hgemm_splitk.py` + ## GEMM + Reduction Fusion: Replace GEMM with Custom Kernel When a GEMM is **immediately followed by a reduction** (e.g., `sum`, `mean`), the @@ -269,10 +464,9 @@ handles all GEMM operations. Do NOT use `torch.mm` for fp32 GEMM. handles the full Q@K^T → softmax → @V pipeline natively. - **Shared B-matrix**: reshape `(B, M, K)` to `(B*M, K)`, use single `compile_preshuffle_gemm_a8`, then reshape back. B-matrix is preshuffled once. -- **Varying B-matrix per batch**: reshape both operands so the batch is folded into the - M dimension. For `(B, M, K) @ (B, K, N)`: iterate over batch elements calling - preshuffle GEMM per element (each B-slice is preshuffled separately). - This is acceptable when flash attention does not apply. +- **Varying B-matrix per batch (activations)**: fold batch into M and use `hgemm_splitk_` + (no preshuffle). See § Split-K GEMM (hgemm_splitk) above. + Use preshuffle GEMM only when each B-slice is a **static weight** shuffled once. ### Conv2d internal GEMM diff --git a/src/minisweagent/skills/pytorch2flydsl-translation/docs/flydsl_translation_guide.md b/src/minisweagent/skills/pytorch2flydsl-translation/docs/flydsl_translation_guide.md index 23106995c..d4f7884ea 100644 --- a/src/minisweagent/skills/pytorch2flydsl-translation/docs/flydsl_translation_guide.md +++ b/src/minisweagent/skills/pytorch2flydsl-translation/docs/flydsl_translation_guide.md @@ -27,7 +27,8 @@ Unlike CUDA/Triton where you use `input[idx]`, FlyDSL uses layout algebra: Identify the computational pattern: 1. **Element-wise**: Each output depends only on corresponding input(s) → custom `@flyc.kernel` 2. **Reduction**: Output has fewer elements (sum, mean, softmax) → manual reduction or pre-built kernel -3. **GEMM/Linear**: Matrix multiplication → `compile_preshuffle_gemm_a8` + `shuffle_weight` +3. **GEMM/Linear**: Fixed-weight matmul → `compile_preshuffle_gemm_a8` + `shuffle_weight`; + dynamic activation matmul (small M, decode) → `hgemm_splitk_` (see `flydsl_translation_gemm.md` § Split-K GEMM) 4. **Normalization**: LayerNorm, RMSNorm → `build_layernorm_module` / `build_rmsnorm_module` 5. **Convolution**: Conv2d → im2col (`F.unfold`) + `compile_preshuffle_gemm_a8` (fp16 cast) @@ -59,7 +60,8 @@ See `flydsl_translation_conv_pool_bn.md` for the complete worked example. ### Pre-built FlyDSL kernels -- **GEMM**: `compile_preshuffle_gemm_a8` — replaces `nn.Linear`, `torch.matmul`, `F.linear` +- **GEMM (fixed weight)**: `compile_preshuffle_gemm_a8` — replaces `nn.Linear`, `F.linear` +- **GEMM (dynamic activations)**: `hgemm_splitk_` — Q@K^T, attn@V when B is not preshufflable - **Flash Attention**: `build_flash_attn_func_module` — replaces `F.scaled_dot_product_attention` - **Softmax**: `build_softmax_module` - **LayerNorm/RMSNorm**: `build_layernorm_module` / `build_rmsnorm_module` @@ -354,17 +356,20 @@ What operation type? ├── LayerNorm / RMSNorm │ └── Use build_layernorm_module() / build_rmsnorm_module() ├── GEMM / Linear / torch.matmul -│ ├── fp32 inputs? Cast to fp16, use compile_preshuffle_gemm_a8() [NO torch.mm] -│ └── fp16/bf16 → Use compile_preshuffle_gemm_a8() [NO torch.matmul / F.linear] +│ ├── Fixed weight B → compile_preshuffle_gemm_a8() + shuffle_weight once [NO nn.Linear] +│ ├── Both operands dynamic (activations) → hgemm_splitk_() [NO shuffle_weight per forward] +│ └── fp32 inputs? Cast to fp16/bf16 before FlyDSL GEMM [NO torch.mm] ├── Batched matmul (torch.bmm) -│ ├── Attention pattern (Q@K^T) → Use build_flash_attn_func_module() -│ ├── Shared B-matrix → reshape (B,M,K) to (B*M,K), single preshuffle GEMM -│ └── Varying B → reshape batch into M dim, preshuffle GEMM [NO torch.bmm] +│ ├── Attention pattern (Q@K^T), flash fits → build_flash_attn_func_module() +│ ├── Shared static B-matrix → reshape (B,M,K) to (B*M,K), single preshuffle GEMM +│ └── Varying dynamic B → hgemm_splitk_ per slice or folded M [NO preshuffle on activations] ├── Attention (self-attention, SDPA, Flash) │ ├── Constraints met → Use build_flash_attn_func_module() │ ├── head_dim not %32 or <64 → Pad Q/K/V to next valid head_dim, flash attn, slice back │ ├── seq_len not %128 → Pad Q/K/V along seq dim, flash attn, slice back -│ └── Both unmet → Decompose into preshuffle GEMM + build_softmax_module [NO F.scaled_dot_product_attention] +│ └── Flash / fused decode infeasible → see flydsl_translation_attention.md § Decode Attention +│ (nheads_q=16 decomposed path): batch gather, pre-scale Q, stacked M, f16 softmax +│ then online/fused roadmap for higher tiers ├── Conv2d │ ├── F.unfold (im2col) to get patches (B, K_patch, L) │ ├── Transpose+reshape to (B*L, K_patch) = A matrix diff --git a/src/minisweagent/tools/translation_registry.py b/src/minisweagent/tools/translation_registry.py index 8b5c2a3de..d1cc9f640 100644 --- a/src/minisweagent/tools/translation_registry.py +++ b/src/minisweagent/tools/translation_registry.py @@ -17,8 +17,6 @@ from dataclasses import dataclass, field from pathlib import Path -from minisweagent import get_data_dir - logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- @@ -119,6 +117,58 @@ def _detect_pytorch_module(kernel_path: Path) -> bool: } +def _is_paged_decode_attention(text: str, source_path: Path) -> bool: + """Detect paged-decode attention kernels (MLA & PagedAttention) for KB loading. + + These ``seqlen_q == 1`` paged-KV kernels are written with ``torch.matmul`` + + paged-cache indexing (not SDPA/MHA/``@``-transpose), so they slip past the + generic ``attention`` regexes. Detect them explicitly so the attention KB + (which holds § Decode Attention) and the gemm KB both load. Matches explicit + naming (docstrings, classes, fused APIs), the structural paged-cache combo + (paged table + cache seqlens + a KV cache), and filename stems such as + ``MultiHeadLatentAttention.py`` / ``PagedAttentionKVCache.py``. + """ + if re.search( + r"MultiHeadLatent|Multi-head\s+Latent|multihead\s+latent|" + r"PagedAttention|Paged\s+(KV\s+Cache\s+)?Attention", + text, + re.IGNORECASE, + ): + return True + if re.search( + r"mla_fwd_decode|flydsl_mla_fwd_decode|get_mla_metadata|mla_reduce", + text, + ): + return True + # Structural combo: a paged block/page table + cache seqlens + a KV cache. + # Covers MLA (block_table + kv_cache + headdim_qk/headdim_v) and + # PagedAttention (page_table + k_cache/v_cache + symmetric headdim). + if ( + re.search(r"block_table|page_table", text) + and re.search(r"cache_seqlen", text) + and re.search(r"kv_cache|k_cache|v_cache", text) + ): + return True + if re.search(r"MultiHeadLatent|\bmla\b|PagedAttention", source_path.stem, re.IGNORECASE): + return True + return False + + +def _is_manual_softmax_attention(text: str) -> bool: + """Detect attention written manually with ``matmul`` + ``softmax``. + + Kernels like MHA/SDPA compute ``softmax(Q @ K^T) @ V`` using ``torch.matmul`` + (or ``bmm``) with a transposed operand rather than + ``F.scaled_dot_product_attention``, ``nn.MultiheadAttention``, or the ``@`` + operator, so they slip past the generic ``attention`` regexes. Require BOTH a + matmul against a transposed operand (the ``Q @ K^T`` score step) AND a softmax, + so plain GEMM kernels that merely transpose an operand are not misflagged. + """ + has_qkt = bool(re.search(r"(?:matmul|bmm)\s*\([^)]*\.(?:transpose|mT|permute)", text)) + has_softmax = bool(re.search(r"softmax", text)) + return has_qkt and has_softmax + + def detect_kernel_categories(source_path: Path) -> list[str]: """Detect kernel categories by pattern matching the source file.""" try: @@ -129,6 +179,19 @@ def detect_kernel_categories(source_path: Path) -> list[str]: for cat, patterns in _CATEGORY_PATTERNS.items(): if any(re.search(p, text) for p in patterns): categories.append(cat) + # Paged-decode attention kernels (MLA, PagedAttention) don't match the generic + # ``attention`` regexes (they use ``torch.matmul`` + paged-cache logic, not + # SDPA/MHA/@-transpose), so detect them explicitly and load the attention KB + # (which contains the § Decode Attention section) plus the gemm KB (Split-K + # GEMM, needed for the decomposed path). + if _is_paged_decode_attention(text, source_path): + for cat in ("attention", "gemm"): + if cat not in categories: + categories.append(cat) + # Manually-implemented attention (MHA, SDPA: softmax(Q@K^T)@V via torch.matmul + + # transpose) also misses the generic ``attention`` regexes — force the KB. + if _is_manual_softmax_attention(text) and "attention" not in categories: + categories.append("attention") if "attention" in categories and "reductions" not in categories: categories.append("reductions") return categories @@ -219,7 +282,7 @@ def load_translation_kb( - Translation guide (PyTorch op mapping, structural patterns, pitfalls) - Category-specific guides (reductions, GEMM, attention) """ - kb_root = get_data_dir("skills") / "pytorch2flydsl-translation" / "docs" + kb_root = Path(__file__).resolve().parents[3] / "skills" / "pytorch2flydsl-translation" / "docs" native_pure_root = kb_root / "native-pure" native_root = kb_root / "native" sections: list[str] = []