Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions docs/fa4.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# FlashAttention-4 (FA4) Support

## Overview

This PR adds FlashAttention-4 support to torchtitan across all model architectures (Llama3, Llama4, Qwen3, DeepSeek V3).

FA4 is a drop-in replacement for SDPA with the same input/output interface. It provides throughput gains that scale with sequence length, particularly beneficial for long-context training.

## Usage

Install FA4:

```bash
pip install fa4
```

Set `attn_type = "fa4"` in your model flavor or config. For example, Qwen3 30B-A3B:

```python
"30B-A3B-fa4": Qwen3ModelArgs(
...
attn_type="fa4",
)
```

No other changes are needed — FA4 handles causal masking and GQA automatically.

## Benchmark Results

**Qwen3 30B-A3B MoE, 8x NVIDIA B200 (180GB), bf16, TP=8, 20 steps on c4_test**

| Seq Len | Batch Size | SDPA TPS | FA4 TPS | Speedup | SDPA TFLOPS | FA4 TFLOPS |
|---------|------------|----------|----------|-----------|-------------|------------|
| 4k | 2 | 6,611 | 6,612 | +0.0% | — | — |
| 8k | 2 | 7,718 | 7,872 | +2.0% | 290.0 | 295.8 |
| 16k | 1 | 7,146 | 7,257 | +1.6% | 406.7 | 413.0 |
| 32k | 1 | 6,400 | 6,685 | **+4.5%** | 611.6 | 638.8 |

Memory usage is identical between SDPA and FA4 at all sequence lengths.

FA4's advantage grows with sequence length — attention becomes a larger fraction of total compute at longer sequences, so a faster attention kernel has more impact.

## Limitations

- **Context Parallel (CP):** FA4 is not compatible with CP. CP works by intercepting the attention kernel dispatch for ring-attention coordination. FA4 uses its own CUDA kernel that PyTorch's CP dispatcher doesn't support. FA4 + CP will raise `NotImplementedError`.
- **Requires `fa4` package:** The import is lazy (inside `forward()`), so users without FA4 installed won't see errors unless they set `attn_type="fa4"`.

## Implementation Details

- `FlashAttention4Wrapper` in `torchtitan/models/attention.py` handles the tensor layout conversion between torchtitan's `(batch, nheads, seqlen, headdim)` and FA4's `(batch, seqlen, nheads, headdim)` format.
- FA4 is wired into all model architectures via `case "fa4"` in `__init__` and shares the same forward branch as SDPA (`case "sdpa" | "fa4"`).
- CP guards in all model families explicitly block `fa4` alongside `varlen`.
44 changes: 44 additions & 0 deletions torchtitan/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@


__all__ = [
"FlashAttention4Wrapper",
"FlexAttentionWrapper",
"ScaledDotProductAttentionWrapper",
"VarlenAttentionWrapper",
Expand All @@ -51,6 +52,49 @@ class VarlenMetadata(NamedTuple):
max_k: Number


class FlashAttention4Wrapper(torch.nn.Module):
"""Wrapper around FlashAttention-4 (fa4) to make it compatible with torchtitan.

FA4 expects tensors in (batch, seqlen, nheads, headdim) layout, while
torchtitan models produce (batch, nheads, seqlen, headdim) after the
standard transpose. This wrapper handles the layout conversion.

Install FA4 with: pip install fa4
"""

def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
*,
scale: float | None = None,
enable_gqa: bool = False,
is_casual: bool = True,
deterministic: bool = False,
) -> torch.Tensor:
from flash_attn.cute import flash_attn_func

# Convert from (bs, n_heads, seqlen, head_dim) to (bs, seqlen, n_heads, head_dim)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

out = flash_attn_func(
q,
k,
v,
softmax_scale=scale,
causal=is_casual,
deterministic=deterministic,
)
if isinstance(out, tuple):
out = out[0]

# Convert back to (bs, n_heads, seqlen, head_dim)
return out.transpose(1, 2)


class VarlenAttentionWrapper(torch.nn.Module):
_compiled_varlen_attn: ClassVar[Callable] = torch.compile(
varlen_attn, mode="max-autotune-no-cudagraphs"
Expand Down
6 changes: 3 additions & 3 deletions torchtitan/models/deepseek_v3/infra/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ def parallelize_deepseekv3(
"""

attn_type = getattr(model.model_args, "attn_type", "sdpa")
if job_config.parallelism.context_parallel_degree > 1 and attn_type == "varlen":
if job_config.parallelism.context_parallel_degree > 1 and attn_type in ("varlen", "fa4"):
raise NotImplementedError(
f"Context Parallel only supports SDPA and FlexAttention."
f"Context Parallel only supports SDPA and FlexAttention. "
f"Got attn_type='{attn_type}'. "
f"Varlen attention is not supported with CP."
f"Varlen and FA4 attention are not supported with CP."
)

if parallel_dims.tp_enabled:
Expand Down
3 changes: 3 additions & 0 deletions torchtitan/models/deepseek_v3/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torchtitan.config.job_config import PEFT
from torchtitan.models.attention import (
create_attention_mask,
FlashAttention4Wrapper,
FlexAttentionWrapper,
get_causal_mask_mod,
get_document_mask_mod,
Expand Down Expand Up @@ -289,6 +290,8 @@ def yarn_get_mscale(scale: float, mscale: float) -> float:
self.inner_attention = FlexAttentionWrapper()
case "sdpa":
self.inner_attention = ScaledDotProductAttentionWrapper()
case "fa4":
self.inner_attention = FlashAttention4Wrapper()
case "varlen":
raise ValueError("Varlen attention is not supported with Deepseek V3.")
case _:
Expand Down
6 changes: 3 additions & 3 deletions torchtitan/models/llama3/model/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:

if (
job_config.parallelism.context_parallel_degree > 1
and self.attn_type == "varlen"
and self.attn_type in ("varlen", "fa4")
):
raise NotImplementedError(
f"Context Parallel only supports SDPA and FlexAttention."
f"Context Parallel only supports SDPA and FlexAttention. "
f"Got attn_type='{self.attn_type}'. "
f"Varlen attention is not supported with CP."
f"Varlen and FA4 attention are not supported with CP."
)

def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
Expand Down
5 changes: 4 additions & 1 deletion torchtitan/models/llama3/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
create_attention_mask,
create_varlen_metadata_for_document,
create_varlen_metadata_from_sequence_lengths,
FlashAttention4Wrapper,
FlexAttentionWrapper,
get_causal_mask_mod,
get_document_mask_mod,
Expand Down Expand Up @@ -253,6 +254,8 @@ def __init__(self, model_args: TransformerModelArgs, peft_config: PEFT):
self.inner_attention = VarlenAttentionWrapper()
case "sdpa":
self.inner_attention = ScaledDotProductAttentionWrapper()
case "fa4":
self.inner_attention = FlashAttention4Wrapper()
case _:
raise ValueError(f"Unknown attention type: {self.attn_type}")

Expand Down Expand Up @@ -329,7 +332,7 @@ def forward(
xv, # (bs, n_kv_heads, seqlen, head_dim)
attention_masks,
)
case "sdpa":
case "sdpa" | "fa4":
assert attention_masks is None
output = (
self.inner_attention(
Expand Down
5 changes: 4 additions & 1 deletion torchtitan/models/llama4/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from torchtitan.config.job_config import PEFT
from torchtitan.models.attention import (
create_attention_mask,
FlashAttention4Wrapper,
FlexAttentionWrapper,
get_causal_mask_mod,
get_document_mask_mod,
Expand Down Expand Up @@ -244,6 +245,8 @@ def __init__(
self.inner_attention = FlexAttentionWrapper()
case "sdpa":
self.inner_attention = ScaledDotProductAttentionWrapper()
case "fa4":
self.inner_attention = FlashAttention4Wrapper()
case "varlen":
raise ValueError("Varlen attention is not supported with Llama 4.")
case _:
Expand Down Expand Up @@ -303,7 +306,7 @@ def forward(
enable_gqa=self.enable_gqa,
)
else:
assert attention_masks is None and self.attn_type == "sdpa"
assert attention_masks is None and self.attn_type in ("sdpa", "fa4")
output = self.inner_attention(
xq, # (bs, n_local_heads, seqlen, head_dim)
xk, # (bs, n_kv_heads, seqlen, head_dim)
Expand Down
49 changes: 49 additions & 0 deletions torchtitan/models/qwen3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,55 @@
attn_type="varlen",
attn_mask_type="block_causal",
),
"30B-A3B-fa4": Qwen3ModelArgs(
vocab_size=151936,
max_seq_len=262144,
head_dim=128,
dim=2048,
n_layers=48,
n_heads=32,
n_kv_heads=4,
qk_norm=True,
hidden_dim=6144,
rope_theta=1000000,
moe_enabled=True,
moe_inter_dim=768,
moe_args=MoEArgs(
num_experts=128,
num_shared_experts=0,
top_k=8,
score_func="softmax",
route_norm=True,
route_scale=1.0,
score_before_experts=False,
),
attn_type="fa4",
),
"30B-A3B-flex-causal": Qwen3ModelArgs(
vocab_size=151936,
max_seq_len=262144,
head_dim=128,
dim=2048,
n_layers=48,
n_heads=32,
n_kv_heads=4,
qk_norm=True,
hidden_dim=6144,
rope_theta=1000000,
moe_enabled=True,
moe_inter_dim=768,
moe_args=MoEArgs(
num_experts=128,
num_shared_experts=0,
top_k=8,
score_func="softmax",
route_norm=True,
route_scale=1.0,
score_before_experts=False,
),
attn_type="flex",
attn_mask_type="causal",
),
"30B-A3B-flex": Qwen3ModelArgs(
vocab_size=151936,
max_seq_len=262144,
Expand Down
6 changes: 3 additions & 3 deletions torchtitan/models/qwen3/infra/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ def parallelize_qwen3(
"""

attn_type = getattr(model.model_args, "attn_type", "sdpa")
if job_config.parallelism.context_parallel_degree > 1 and attn_type == "varlen":
if job_config.parallelism.context_parallel_degree > 1 and attn_type in ("varlen", "fa4"):
raise NotImplementedError(
f"Context Parallel only supports SDPA and FlexAttention."
f"Context Parallel only supports SDPA and FlexAttention. "
f"Got attn_type='{attn_type}'. "
f"Varlen attention is not supported with CP."
f"Varlen and FA4 attention are not supported with CP."
)

model_compile_enabled = (
Expand Down
5 changes: 4 additions & 1 deletion torchtitan/models/qwen3/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
create_attention_mask,
create_varlen_metadata_for_document,
create_varlen_metadata_from_sequence_lengths,
FlashAttention4Wrapper,
FlexAttentionWrapper,
get_causal_mask_mod,
get_document_mask_mod,
Expand Down Expand Up @@ -221,6 +222,8 @@ def __init__(self, model_args: Qwen3ModelArgs, peft_config: PEFT):
self.inner_attention = VarlenAttentionWrapper()
case "sdpa":
self.inner_attention = ScaledDotProductAttentionWrapper()
case "fa4":
self.inner_attention = FlashAttention4Wrapper()
case _:
raise ValueError(f"Unknown attention type: {self.attn_type}")

Expand Down Expand Up @@ -302,7 +305,7 @@ def forward(
attention_masks,
scale=self.scaling,
)
case "sdpa":
case "sdpa" | "fa4":
assert attention_masks is None
output = (
self.inner_attention(
Expand Down