Commit fbe12b9
gasoonjia
[gemma4_31b][cuda] length-aware bf16 global attention + head_dim-agnostic prefill autotune
- Global (full-attention) bf16 layers: bound SDPA to a runtime kv_len scalar
(CUDA-graph-safe) instead of the full max_seq_len KV buffer -> O(context)
decode; restores decode scaling (was flat ~36.5 t/s at all depths ->
46.5@512, 34.9@127K). (sdpa.py kv_len path + cuda_source_transformations.py
_lenaware_attention_forward; global layers only, sliding + turbo untouched)
- Prefill global full-attention: replace fixed m32/m64 BLOCK_M selection with a
head_dim-keyed autotuned _sdpa_fwd_kernel + register-budget prune
(BLOCK_M*HEAD_DIM <= 4096*num_warps), fixing acc[64,512] fp32 register spill
at head_dim=512. Prefill +24% @8K, +63% @32K, +117% @127K; head_dim-agnostic
(no split-D needed for D<=512). (sdpa.py)
- Runner: add --tokens_file (pre-tokenized input) and --ignore_eos (fixed decode
length) for benchmarking. (main.cpp)
- Validated: output bitwise-identical to prior kernel (cos=1.0, D=64/128/256/512),
no decode regression; non-tq prefill now beats llama.cpp at all 5 cells and
turbo TQ4 at 4/5. Op-level autotune profiling (A100) confirms the config set is
near-optimal (in-set optimum at every regime; only <=1.3% marginal candidates).
1 parent 1c371e2 commit fbe12b9
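
Sketch of the length-aware decode idea from the first bullet, in pure Python rather than the actual Triton/SDPA path. It assumes a KV buffer preallocated to max_seq_len and a single query vector; `decode_attention` and its arguments are illustrative names, not functions from sdpa.py. In the real change kv_len is described as a runtime scalar so the CUDA graph stays valid; here it is a plain int.

```python
import math

def decode_attention(q, k_buf, v_buf, kv_len):
    """Single-query attention over the first kv_len rows of a preallocated KV buffer.

    k_buf / v_buf have max_seq_len rows; rows at index >= kv_len are never read,
    so the cost is O(kv_len) rather than O(max_seq_len).
    """
    scale = 1.0 / math.sqrt(len(q))
    # Scores only for the valid prefix of the buffer.
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k_buf[t]))
              for t in range(kv_len)]
    # Numerically stable softmax over the valid prefix.
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    head_dim = len(v_buf[0])
    out = [0.0] * head_dim
    for t in range(kv_len):
        w = weights[t] / denom
        for d in range(head_dim):
            out[d] += w * v_buf[t][d]
    return out
```

Unused tail rows of the buffer cannot affect the result, which is what lets decode cost track the current context length instead of the buffer size.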
3 files changed
Lines changed: 367 additions & 141 deletions
File tree
- backends/cuda/triton/kernels
- examples/models/gemma4_31b
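
Sketch of the register-budget prune rule quoted in the commit message (BLOCK_M*HEAD_DIM <= 4096*num_warps). With 32 threads per warp, that bound works out to at most 128 fp32 accumulator values per thread, assuming the accumulator has shape [BLOCK_M, HEAD_DIM] as acc[64,512] suggests. The `Config` class and the candidate list are hypothetical; the actual autotune set in sdpa.py is not shown in the message.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Config:
    block_m: int
    block_n: int
    num_warps: int

# Hypothetical candidate grid, for illustration only.
CANDIDATES = [
    Config(bm, bn, w)
    for bm in (16, 32, 64, 128)
    for bn in (32, 64)
    for w in (4, 8)
]

def within_register_budget(cfg, head_dim):
    """Rule from the commit message: BLOCK_M * HEAD_DIM <= 4096 * num_warps."""
    return cfg.block_m * head_dim <= 4096 * cfg.num_warps

def prune(configs, head_dim):
    """Drop configs whose fp32 accumulator tile would exceed the budget."""
    return [c for c in configs if within_register_budget(c, head_dim)]
```

Under this rule a BLOCK_M=64 tile at head_dim=512 is rejected with 4 warps and kept only with 8, while at head_dim=64 every candidate in the grid above survives.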