Date: 2026-03-15 Author: garrick99 Hardware: NVIDIA RTX 4090 (SM89), 64-core CPU, 32GB RAM Codebase: LiteAttention v0.4.0 + SM89 patches Target: Video model inference — outperform FA4
LiteAttention is now fully operational on SM89 (RTX 4090) with both forward skip list and a new backward tile skipping capability. On video inference workloads with temporal redundancy:
- 6.6x faster per-step than baseline FA3 (134us vs 889us)
- ~7.5x faster than FA4 on H100 (~1000us reference, cross-hardware comparison)
- Output quality: cos_sim 0.989 at threshold=-1.0 with 2% inter-step noise
- Convergence: Skip list settles in 2 steps
- Backward tile skipping (new): 2.95x kernel speedup at 87.5% skip rate
LiteAttention's hopper/ package defaults to SM90-only compilation. Getting it running on SM89 required:
LITE_ATTENTION_DISABLE_SM80=FALSE— enables SM80 kernel instantiations (compiled to SM89)LITE_ATTENTION_DISABLE_BACKWARD=FALSE— enables backward kernels (default: disabled)CC=gcc-13 CXX=g++-13— Ubuntu 25.10 ships gcc-15, incompatible with CUDA 12.4's nvccMAX_JOBS=4— CUTLASS template instantiations consume ~6-8GB RAM each; full parallelism OOMs on 32GBNVCC_PREPEND_FLAGS='--compiler-bindir=/usr/bin/gcc-13'— explicit host compiler for nvcc
-
5 missing INT8 skipable SM90 instantiation files —
flash_fwd_hdim{64,96,128,192,256}_int8_skipable_sm90.cu— the forward dispatch generates all dtype×skipable combinations but no.cufiles existed for INT8+skipable. Created and added tosetup.py. -
Autograd gradient count mismatch —
FlashAttnFunc.backwardreturned 17 gradients but the forward accepts 24 parameters. Fixed by adding 7Nonereturns. -
Build environment —
setup.pyalready had SM80→SM89 arch remapping (compute_89,code=sm_89) and SM80/SM89 tile size dispatch inflash_api.cpp. No kernel code changes needed for basic SM89 support.
- 129MB
.sowith SM80 fwd+bwd + SM90 fwd+bwd across all hdims (64/96/128/192/256), bf16, INT8 - Full forward and backward verified on RTX 4090
| Config | Forward Time | TFLOPS |
|---|---|---|
| B=1 S=1024 H=16 D=128 | 77 us | 111.0 |
| B=1 S=2048 H=16 D=128 | 241 us | 142.7 |
| B=1 S=4096 H=16 D=128 | 889 us | 154.6 |
| B=1 S=8192 H=16 D=128 | 3,460 us | 158.9 |
| B=2 S=4096 H=32 D=128 | 3,364 us | 163.4 |
Peak: 163.4 TFLOPS (49.5% of RTX 4090's 330 TFLOPS tensor peak).
These numbers match the original SM80 benchmark report exactly, confirming the port is correct.
Near-diagonal attention pattern (K ≈ Q + noise), 2% Gaussian perturbation per step, simulating temporal redundancy in video denoising.
B=1 S=4096 H=16 D=128, threshold=-1.0, 20 steps:
| Step | Time | Speedup | Cos Sim |
|---|---|---|---|
| 0 (warmup) | 5,788 us | 0.15x | 1.000000 |
| 1 (1st skip) | 871 us | 1.02x | 0.996298 |
| 2 (settled) | 134 us | 6.63x | 0.988814 |
| 3 | 133 us | 6.69x | 0.988801 |
| 4-19 | 132-137 us | 6.5-6.8x | 0.9886-0.9888 |
Settled per-step: 134 us = 6.6x faster than baseline FA3.
Reference: FA4 on H100 is ~1000us for this config. LiteAttention on RTX 4090 at 134us is ~7.5x faster than FA4 on H100. Note: this is a cross-hardware comparison (RTX 4090 vs H100) — the RTX 4090 is a consumer GPU at a fraction of H100's cost and power draw.
| Threshold | Settled Time | Speedup | Compute % |
|---|---|---|---|
| -0.5 | 134 us | 6.66x | ~0% |
| -1.0 | 133 us | 6.67x | ~0% |
| -2.0 | 132 us | 6.72x | ~0% |
| -5.0 | 155 us | 5.75x | ~0% |
| -8.0 | 897 us | 0.99x | 0.6% |
| -10.0 | 1,095 us | 0.81x | 1.0% |
Sweet spot: threshold between -2.0 and -5.0. At -8.0, the skip list detects very few tiles to skip and overhead dominates.
When nothing is skipped (worst case for LiteAttention), the overhead is:
| Config | Base | Skip-enabled | Overhead |
|---|---|---|---|
| S=1024 | 73 us | 91 us | +24.6% |
| S=2048 | 241 us | 284 us | +18.0% |
| S=4096 | 888 us | 1,073 us | +20.9% |
| S=8192 | 3,668 us | 4,169 us | +13.6% |
Overhead amortizes at larger sizes. Break-even: skip rate must exceed overhead % (~14-21%).
Added block-sparsity support to the SM80 backward kernel. This is a new capability — neither LiteAttention nor BackLite had backward skipping on SM80/SM89.
C++ changes (5 files):
mainloop_bwd_sm80.hpp— Addedis_m_block_active()inline bitmask check using__ldg(read-only L2 cache path). All three m_block iteration loops (causal masking, main, local masking) check the mask before callingbwd_step(). If the bit is 0, the entire tile's S=QK^T, dP=dO·V^T, dV, dK, dQ computations are skipped.mainloop_bwd_sm90_tma_gmma_ws.hpp— Interface compatibility (block_mask fields in Args/Params)flash_bwd_launch_template.h— Passes block_mask through to mainloopflash.h—block_mask_ptr+block_mask_num_wordsinFlash_bwd_paramsflash_api.cpp—block_maskoptional parameter onbwd()binding
Python changes:
flash_attn_interface.py—block_maskparameter flows through_flash_attn_backward()→ C++bwd()binding, stored on autograd context for automatic use
B=1 S=4096 H=16 D=128:
| Skip Rate | Time | Speedup |
|---|---|---|
| 0% (baseline) | 2,594 us | 1.00x |
| 25% | 2,279 us | 1.14x |
| 75% | 1,180 us | 2.20x |
| 87.5% | 894 us | 2.90x |
| 95% | 750 us | 3.46x |
B=1 S=8192 H=16 D=128:
| Skip Rate | Time | Speedup |
|---|---|---|
| 0% (baseline) | 8,928 us | 1.00x |
| 75% | 2,611 us | 3.42x |
| 87.5% | 2,436 us | 3.67x |
| 95% | 870 us | 10.27x |
B=2 S=4096 H=32 D=128:
| Skip Rate | Time | Speedup |
|---|---|---|
| 0% (baseline) | 9,278 us | 1.00x |
| 75% | 3,017 us | 3.08x |
| 87.5% | 1,856 us | 5.00x |
| 95% | 1,277 us | 7.27x |
Notable: the all-ones mask (0% skip) is faster than no mask on some configs — the __ldg bitmask check + branch is apparently generating tighter code than the unconditional loop.
B=1 S=4096 H=16 D=128, 87.5% backward skip:
| Metric | Baseline | With Skip | Speedup |
|---|---|---|---|
| Forward + Backward | 3,499 us | 1,654 us | 2.12x |
B=1 S=4096 H=16 D=128, per-step breakdown:
| Stage | Baseline | % of Total |
|---|---|---|
| Tensor clone | 66 us | 1.8% |
| Grad enable | 3 us | 0.1% |
| Forward | 920 us | 24.5% |
| Loss | 19 us | 0.5% |
| Backward | 2,635 us | 70.1% |
| Grad readout | 115 us | 3.1% |
| Total | 3,759 us |
The backward dominates at 70% of total training time. Backward tile skipping directly targets this.
- Forward overhead reduction — 15-20% overhead at 0% skip rate. Sources:
scores_max_prevcopy, separatereduce_max, threshold comparison,__any_synccross-warp coordination. Target: <5% overhead by fusing score tracking into existing softmax reduction. - Real video model integration — validate skip rates and quality on actual video diffusion model (e.g., wrapping existing attention layers with LiteAttention).
- SM120 (Blackwell) support — RTX 5090 available for testing. WSL2 setup pending reboot. SM80 kernels run on SM120 via forward compat; native SM120 kernels would be faster.
- Automatic backward bitmask generation — currently the bitmask is provided manually. Need to generate it from the forward's skip list or tile_stats automatically. Two approaches: (a) convert LiteAttention's write_list to bitmask, (b) add BackLite's tile_stats to the forward kernel.
- Pre-allocated persistent buffers — skip lists, bitmasks, and tile_stats should be allocated once and reused across steps. Zero-alloc hot path like VortexSTARK's ProverCache.
- 6 SM80 skip list bugs from the original code review (all fixed locally, not yet upstreamed)
- BackLite review — 6 bugs found including HIGH severity multi-word bitmask bug
cd LiteAttention
# Full build: SM80+SM90, forward+backward, INT8
LITE_ATTENTION_DISABLE_SM80=FALSE \
LITE_ATTENTION_DISABLE_BACKWARD=FALSE \
MAX_JOBS=4 \
CC=gcc-13 CXX=g++-13 \
NVCC_PREPEND_FLAGS='--compiler-bindir=/usr/bin/gcc-13' \
python setup.py build_ext --inplace
# Install
pip install -e . --no-build-isolation --no-depsBuild time: ~20-30 minutes with MAX_JOBS=4 on a high-core-count CPU.
| File | Change |
|---|---|
setup.py |
Added INT8 skipable SM90 instantiation sources |
hopper/instantiations/flash_fwd_hdim*_int8_skipable_sm90.cu |
5 new files |
hopper/_internal/flash_attn_interface.py |
Autograd gradient count fix (17→25), block_mask parameter, bwd_block_mask support |
hopper/_internal/cpp/mainloop_bwd_sm80.hpp |
is_m_block_active() + skip checks in all 3 m_block loops |
hopper/_internal/cpp/mainloop_bwd_sm90_tma_gmma_ws.hpp |
Interface compatibility (block_mask fields) |
hopper/_internal/cpp/flash_bwd_launch_template.h |
Pass block_mask to mainloop args |
hopper/_internal/cpp/flash.h |
block_mask_ptr + block_mask_num_words in Flash_bwd_params |
hopper/_internal/cpp/flash_api.cpp |
block_mask parameter on bwd() binding |