Skip to content

garrick99/liteattention-sm89

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

1,061 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

LiteAttention SM89 Port & Performance Report

Date: 2026-03-15 Author: garrick99 Hardware: NVIDIA RTX 4090 (SM89), 64-core CPU, 32GB RAM Codebase: LiteAttention v0.4.0 + SM89 patches Target: Video model inference — outperform FA4


Executive Summary

LiteAttention is now fully operational on SM89 (RTX 4090) with both forward skip list and a new backward tile skipping capability. On video inference workloads with temporal redundancy:

  • 6.6x faster per-step than baseline FA3 (134us vs 889us)
  • ~7.5x faster than FA4 on H100 (~1000us reference, cross-hardware comparison)
  • Output quality: cos_sim 0.989 at threshold=-1.0 with 2% inter-step noise
  • Convergence: Skip list settles in 2 steps
  • Backward tile skipping (new): 2.95x kernel speedup at 87.5% skip rate

1. SM89 Port — What Was Required

LiteAttention's hopper/ package defaults to SM90-only compilation. Getting it running on SM89 required:

Build Configuration

  • LITE_ATTENTION_DISABLE_SM80=FALSE — enables SM80 kernel instantiations (compiled to SM89)
  • LITE_ATTENTION_DISABLE_BACKWARD=FALSE — enables backward kernels (default: disabled)
  • CC=gcc-13 CXX=g++-13 — Ubuntu 25.10 ships gcc-15, incompatible with CUDA 12.4's nvcc
  • MAX_JOBS=4 — CUTLASS template instantiations consume ~6-8GB RAM each; full parallelism OOMs on 32GB
  • NVCC_PREPEND_FLAGS='--compiler-bindir=/usr/bin/gcc-13' — explicit host compiler for nvcc

Code Fixes Required

  1. 5 missing INT8 skipable SM90 instantiation filesflash_fwd_hdim{64,96,128,192,256}_int8_skipable_sm90.cu — the forward dispatch generates all dtype×skipable combinations but no .cu files existed for INT8+skipable. Created and added to setup.py.

  2. Autograd gradient count mismatchFlashAttnFunc.backward returned 17 gradients but the forward accepts 24 parameters. Fixed by adding 7 None returns.

  3. Build environmentsetup.py already had SM80→SM89 arch remapping (compute_89,code=sm_89) and SM80/SM89 tile size dispatch in flash_api.cpp. No kernel code changes needed for basic SM89 support.

Build Output

  • 129MB .so with SM80 fwd+bwd + SM90 fwd+bwd across all hdims (64/96/128/192/256), bf16, INT8
  • Full forward and backward verified on RTX 4090

2. Inference Performance (Forward Only)

Baseline — FA3 (no skip list)

Config Forward Time TFLOPS
B=1 S=1024 H=16 D=128 77 us 111.0
B=1 S=2048 H=16 D=128 241 us 142.7
B=1 S=4096 H=16 D=128 889 us 154.6
B=1 S=8192 H=16 D=128 3,460 us 158.9
B=2 S=4096 H=32 D=128 3,364 us 163.4

Peak: 163.4 TFLOPS (49.5% of RTX 4090's 330 TFLOPS tensor peak).

These numbers match the original SM80 benchmark report exactly, confirming the port is correct.

Skip List — Video Inference Simulation

Near-diagonal attention pattern (K ≈ Q + noise), 2% Gaussian perturbation per step, simulating temporal redundancy in video denoising.

B=1 S=4096 H=16 D=128, threshold=-1.0, 20 steps:

Step Time Speedup Cos Sim
0 (warmup) 5,788 us 0.15x 1.000000
1 (1st skip) 871 us 1.02x 0.996298
2 (settled) 134 us 6.63x 0.988814
3 133 us 6.69x 0.988801
4-19 132-137 us 6.5-6.8x 0.9886-0.9888

Settled per-step: 134 us = 6.6x faster than baseline FA3.

Reference: FA4 on H100 is ~1000us for this config. LiteAttention on RTX 4090 at 134us is ~7.5x faster than FA4 on H100. Note: this is a cross-hardware comparison (RTX 4090 vs H100) — the RTX 4090 is a consumer GPU at a fraction of H100's cost and power draw.

Threshold Sweep

Threshold Settled Time Speedup Compute %
-0.5 134 us 6.66x ~0%
-1.0 133 us 6.67x ~0%
-2.0 132 us 6.72x ~0%
-5.0 155 us 5.75x ~0%
-8.0 897 us 0.99x 0.6%
-10.0 1,095 us 0.81x 1.0%

Sweet spot: threshold between -2.0 and -5.0. At -8.0, the skip list detects very few tiles to skip and overhead dominates.

Skip List Overhead (0% skip rate)

When nothing is skipped (worst case for LiteAttention), the overhead is:

Config Base Skip-enabled Overhead
S=1024 73 us 91 us +24.6%
S=2048 241 us 284 us +18.0%
S=4096 888 us 1,073 us +20.9%
S=8192 3,668 us 4,169 us +13.6%

Overhead amortizes at larger sizes. Break-even: skip rate must exceed overhead % (~14-21%).


3. Backward Tile Skipping (New Feature)

Added block-sparsity support to the SM80 backward kernel. This is a new capability — neither LiteAttention nor BackLite had backward skipping on SM80/SM89.

Implementation

C++ changes (5 files):

  • mainloop_bwd_sm80.hpp — Added is_m_block_active() inline bitmask check using __ldg (read-only L2 cache path). All three m_block iteration loops (causal masking, main, local masking) check the mask before calling bwd_step(). If the bit is 0, the entire tile's S=QK^T, dP=dO·V^T, dV, dK, dQ computations are skipped.
  • mainloop_bwd_sm90_tma_gmma_ws.hpp — Interface compatibility (block_mask fields in Args/Params)
  • flash_bwd_launch_template.h — Passes block_mask through to mainloop
  • flash.hblock_mask_ptr + block_mask_num_words in Flash_bwd_params
  • flash_api.cppblock_mask optional parameter on bwd() binding

Python changes:

  • flash_attn_interface.pyblock_mask parameter flows through _flash_attn_backward() → C++ bwd() binding, stored on autograd context for automatic use

Raw Backward Kernel Performance

B=1 S=4096 H=16 D=128:

Skip Rate Time Speedup
0% (baseline) 2,594 us 1.00x
25% 2,279 us 1.14x
75% 1,180 us 2.20x
87.5% 894 us 2.90x
95% 750 us 3.46x

B=1 S=8192 H=16 D=128:

Skip Rate Time Speedup
0% (baseline) 8,928 us 1.00x
75% 2,611 us 3.42x
87.5% 2,436 us 3.67x
95% 870 us 10.27x

B=2 S=4096 H=32 D=128:

Skip Rate Time Speedup
0% (baseline) 9,278 us 1.00x
75% 3,017 us 3.08x
87.5% 1,856 us 5.00x
95% 1,277 us 7.27x

Notable: the all-ones mask (0% skip) is faster than no mask on some configs — the __ldg bitmask check + branch is apparently generating tighter code than the unconditional loop.

End-to-End Autograd

B=1 S=4096 H=16 D=128, 87.5% backward skip:

Metric Baseline With Skip Speedup
Forward + Backward 3,499 us 1,654 us 2.12x

4. Pipeline Profiling — Where Time Is Spent

B=1 S=4096 H=16 D=128, per-step breakdown:

Stage Baseline % of Total
Tensor clone 66 us 1.8%
Grad enable 3 us 0.1%
Forward 920 us 24.5%
Loss 19 us 0.5%
Backward 2,635 us 70.1%
Grad readout 115 us 3.1%
Total 3,759 us

The backward dominates at 70% of total training time. Backward tile skipping directly targets this.


5. Remaining Work

Inference (Omer's priority)

  1. Forward overhead reduction — 15-20% overhead at 0% skip rate. Sources: scores_max_prev copy, separate reduce_max, threshold comparison, __any_sync cross-warp coordination. Target: <5% overhead by fusing score tracking into existing softmax reduction.
  2. Real video model integration — validate skip rates and quality on actual video diffusion model (e.g., wrapping existing attention layers with LiteAttention).
  3. SM120 (Blackwell) support — RTX 5090 available for testing. WSL2 setup pending reboot. SM80 kernels run on SM120 via forward compat; native SM120 kernels would be faster.

Training

  1. Automatic backward bitmask generation — currently the bitmask is provided manually. Need to generate it from the forward's skip list or tile_stats automatically. Two approaches: (a) convert LiteAttention's write_list to bitmask, (b) add BackLite's tile_stats to the forward kernel.
  2. Pre-allocated persistent buffers — skip lists, bitmasks, and tile_stats should be allocated once and reused across steps. Zero-alloc hot path like VortexSTARK's ProverCache.

Bug fixes (upstream)

  1. 6 SM80 skip list bugs from the original code review (all fixed locally, not yet upstreamed)
  2. BackLite review — 6 bugs found including HIGH severity multi-word bitmask bug

6. Build Instructions

cd LiteAttention

# Full build: SM80+SM90, forward+backward, INT8
LITE_ATTENTION_DISABLE_SM80=FALSE \
LITE_ATTENTION_DISABLE_BACKWARD=FALSE \
MAX_JOBS=4 \
CC=gcc-13 CXX=g++-13 \
NVCC_PREPEND_FLAGS='--compiler-bindir=/usr/bin/gcc-13' \
python setup.py build_ext --inplace

# Install
pip install -e . --no-build-isolation --no-deps

Build time: ~20-30 minutes with MAX_JOBS=4 on a high-core-count CPU.


7. Files Modified (vs upstream v0.4.0)

File Change
setup.py Added INT8 skipable SM90 instantiation sources
hopper/instantiations/flash_fwd_hdim*_int8_skipable_sm90.cu 5 new files
hopper/_internal/flash_attn_interface.py Autograd gradient count fix (17→25), block_mask parameter, bwd_block_mask support
hopper/_internal/cpp/mainloop_bwd_sm80.hpp is_m_block_active() + skip checks in all 3 m_block loops
hopper/_internal/cpp/mainloop_bwd_sm90_tma_gmma_ws.hpp Interface compatibility (block_mask fields)
hopper/_internal/cpp/flash_bwd_launch_template.h Pass block_mask to mainloop args
hopper/_internal/cpp/flash.h block_mask_ptr + block_mask_num_words in Flash_bwd_params
hopper/_internal/cpp/flash_api.cpp block_mask parameter on bwd() binding

About

LiteAttention SM89 port: 6.6x inference speedup + backward tile skipping

Resources

License

Unknown and 2 other licenses found

Licenses found

Unknown
LICENSE
BSD-3-Clause
LICENSE-BSD
MIT
LICENSE-MIT

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors