
Add SDPA fallback for sliding-window attention #161

Open
fayerman-source wants to merge 1 commit into karpathy:master from fayerman-source:upstream/sdpa-mask-cache

Conversation

@fayerman-source

Summary

  • use Flash Attention 3 only on Hopper and fall back to PyTorch SDPA elsewhere
  • precompute and reuse SDPA sliding-window masks instead of rebuilding them every forward pass
  • keep the full-context path on is_causal=True when no explicit sliding-window mask is needed
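The mask-caching and full-context behavior described above can be sketched roughly as follows. This is an illustrative sketch, not the PR's actual code: the function and cache names are hypothetical, and the real change lives inside the model's attention layer.

```python
import torch
import torch.nn.functional as F

# Hypothetical module-level cache, keyed by (seq len, window, device).
_mask_cache = {}

def sliding_window_mask(T, window, device):
    # Reuse a precomputed T x T boolean mask instead of rebuilding it
    # on every forward pass (True = position may be attended to).
    key = (T, window, str(device))
    if key not in _mask_cache:
        i = torch.arange(T, device=device)
        # Allow attention to positions j with j <= i and i - j < window.
        mask = (i[:, None] >= i[None, :]) & (i[:, None] - i[None, :] < window)
        _mask_cache[key] = mask
    return _mask_cache[key]

def attend(q, k, v, window=None):
    T = q.size(-2)
    if window is None or window >= T:
        # Full-context path: no explicit mask needed, keep is_causal=True.
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
    mask = sliding_window_mask(T, window, q.device)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

When the window covers the whole sequence, the cached mask reduces to an ordinary causal mask, which is why the full-context path can skip it entirely and rely on `is_causal=True`.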

Why

The current non-Hopper path still routes through FA3, which fails on RTX 5070-class hardware. This keeps the Hopper path unchanged, but makes the same model runnable on non-Hopper GPUs and avoids repeated T x T mask construction on the SDPA path.

This was motivated by the non-Hopper discussion in #36 and the SDPA-focused follow-up in #108.
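The backend routing the paragraph above argues for can be sketched as a capability check; the function name here is hypothetical, and the real dispatch in the PR is inline in the model code:

```python
import torch

def pick_attention_backend():
    # Use FA3 only on Hopper (compute capability 9.0); everything else,
    # including RTX 5070-class GPUs, falls back to PyTorch SDPA.
    if torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0):
        return "fa3"
    return "sdpa"
```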

Benchmark note

RTX 5070, CUDA SDPA microbenchmark, batch 8, seq 2048, 4-layer test model:

| variant | tok/s | ms/fwd |
| --- | --- | --- |
| cached SDPA mask | 1,248,774 | 13.12 |
| rebuild mask each forward | 1,222,867 | 13.40 |

I also smoke-tested a tiny GPU forward pass on the SDPA path after this change.

@karpathy
Owner

i don't think i'll merge this (too bloating), but i will leave the PR up.

```python
import torch
from kernels import get_kernel

cap = torch.cuda.get_device_capability()
# varunneal's FA3 is Hopper only, use kernels-community on non-Hopper GPUs
repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3"
```
Collaborator


For Ada/Ampere, it's still nice to use "kernels-community/flash-attn3". I had some improvements using it on an A100.



3 participants