diff --git a/docs/fa4.md b/docs/fa4.md new file mode 100644 index 0000000000..6478678a7c --- /dev/null +++ b/docs/fa4.md @@ -0,0 +1,52 @@ +# FlashAttention-4 (FA4) Support + +## Overview + +This PR adds FlashAttention-4 support to torchtitan across all model architectures (Llama3, Llama4, Qwen3, DeepSeek V3). + +FA4 is a drop-in replacement for SDPA with the same input/output interface. It provides throughput gains that scale with sequence length, particularly beneficial for long-context training. + +## Usage + +Install FA4: + +```bash +pip install fa4 +``` + +Set `attn_type = "fa4"` in your model flavor or config. For example, Qwen3 30B-A3B: + +```python +"30B-A3B-fa4": Qwen3ModelArgs( + ... + attn_type="fa4", +) +``` + +No other changes are needed — FA4 handles causal masking and GQA automatically. + +## Benchmark Results + +**Qwen3 30B-A3B MoE, 8x NVIDIA B200 (180GB), bf16, TP=8, 20 steps on c4_test** + +| Seq Len | Batch Size | SDPA TPS | FA4 TPS | Speedup | SDPA TFLOPS | FA4 TFLOPS | +|---------|------------|----------|----------|-----------|-------------|------------| +| 4k | 2 | 6,611 | 6,612 | +0.0% | — | — | +| 8k | 2 | 7,718 | 7,872 | +2.0% | 290.0 | 295.8 | +| 16k | 1 | 7,146 | 7,257 | +1.6% | 406.7 | 413.0 | +| 32k | 1 | 6,400 | 6,685 | **+4.5%** | 611.6 | 638.8 | + +Memory usage is identical between SDPA and FA4 at all sequence lengths. + +FA4's advantage grows with sequence length — attention becomes a larger fraction of total compute at longer sequences, so a faster attention kernel has more impact. + +## Limitations + +- **Context Parallel (CP):** FA4 is not compatible with CP. CP works by intercepting the attention kernel dispatch for ring-attention coordination. FA4 uses its own CUDA kernel that PyTorch's CP dispatcher doesn't support. FA4 + CP will raise `NotImplementedError`. +- **Requires `fa4` package:** The import is lazy (inside `forward()`), so users without FA4 installed won't see errors unless they set `attn_type="fa4"`. 
+ +## Implementation Details + +- `FlashAttention4Wrapper` in `torchtitan/models/attention.py` handles the tensor layout conversion between torchtitan's `(batch, nheads, seqlen, headdim)` and FA4's `(batch, seqlen, nheads, headdim)` format. +- FA4 is wired into all model architectures via `case "fa4"` in `__init__` and shares the same forward branch as SDPA (`case "sdpa" | "fa4"`). +- CP guards in all model families explicitly block `fa4` alongside `varlen`. diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 152da82530..8e6f432d86 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -25,6 +25,7 @@ __all__ = [ + "FlashAttention4Wrapper", "FlexAttentionWrapper", "ScaledDotProductAttentionWrapper", "VarlenAttentionWrapper", @@ -51,6 +52,49 @@ class VarlenMetadata(NamedTuple): max_k: Number +class FlashAttention4Wrapper(torch.nn.Module): + """Wrapper around FlashAttention-4 (fa4) to make it compatible with torchtitan. + + FA4 expects tensors in (batch, seqlen, nheads, headdim) layout, while + torchtitan models produce (batch, nheads, seqlen, headdim) after the + standard transpose. This wrapper handles the layout conversion. 
+ + Install FA4 with: pip install fa4 + """ + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + scale: float | None = None, + enable_gqa: bool = False, + is_causal: bool = True, + deterministic: bool = False, + ) -> torch.Tensor: + from flash_attn.cute import flash_attn_func + + # Convert from (bs, n_heads, seqlen, head_dim) to (bs, seqlen, n_heads, head_dim) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + out = flash_attn_func( + q, + k, + v, + softmax_scale=scale, + causal=is_causal, + deterministic=deterministic, + ) + if isinstance(out, tuple): + out = out[0] + + # Convert back to (bs, n_heads, seqlen, head_dim) + return out.transpose(1, 2) + + class VarlenAttentionWrapper(torch.nn.Module): _compiled_varlen_attn: ClassVar[Callable] = torch.compile( varlen_attn, mode="max-autotune-no-cudagraphs" diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 05c345adab..e28e132d36 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -65,11 +65,11 @@ def parallelize_deepseekv3( """ attn_type = getattr(model.model_args, "attn_type", "sdpa") - if job_config.parallelism.context_parallel_degree > 1 and attn_type == "varlen": + if job_config.parallelism.context_parallel_degree > 1 and attn_type in ("varlen", "fa4"): raise NotImplementedError( - f"Context Parallel only supports SDPA and FlexAttention." + f"Context Parallel only supports SDPA and FlexAttention. " f"Got attn_type='{attn_type}'. " - f"Varlen attention is not supported with CP." + f"Varlen and FA4 attention are not supported with CP." 
) if parallel_dims.tp_enabled: diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 98b32a16a3..917cab18f9 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -16,6 +16,7 @@ from torchtitan.config.job_config import PEFT from torchtitan.models.attention import ( create_attention_mask, + FlashAttention4Wrapper, FlexAttentionWrapper, get_causal_mask_mod, get_document_mask_mod, @@ -289,6 +290,8 @@ def yarn_get_mscale(scale: float, mscale: float) -> float: self.inner_attention = FlexAttentionWrapper() case "sdpa": self.inner_attention = ScaledDotProductAttentionWrapper() + case "fa4": + self.inner_attention = FlashAttention4Wrapper() case "varlen": raise ValueError("Varlen attention is not supported with Deepseek V3.") case _: diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py index 43f5c69dea..5bf819de44 100644 --- a/torchtitan/models/llama3/model/args.py +++ b/torchtitan/models/llama3/model/args.py @@ -64,12 +64,12 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: if ( job_config.parallelism.context_parallel_degree > 1 - and self.attn_type == "varlen" + and self.attn_type in ("varlen", "fa4") ): raise NotImplementedError( - f"Context Parallel only supports SDPA and FlexAttention." + f"Context Parallel only supports SDPA and FlexAttention. " f"Got attn_type='{self.attn_type}'. " - f"Varlen attention is not supported with CP." + f"Varlen and FA4 attention are not supported with CP." 
) def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index e5b5fb2273..aafad1d430 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -21,6 +21,7 @@ create_attention_mask, create_varlen_metadata_for_document, create_varlen_metadata_from_sequence_lengths, + FlashAttention4Wrapper, FlexAttentionWrapper, get_causal_mask_mod, get_document_mask_mod, @@ -253,6 +254,8 @@ def __init__(self, model_args: TransformerModelArgs, peft_config: PEFT): self.inner_attention = VarlenAttentionWrapper() case "sdpa": self.inner_attention = ScaledDotProductAttentionWrapper() + case "fa4": + self.inner_attention = FlashAttention4Wrapper() case _: raise ValueError(f"Unknown attention type: {self.attn_type}") @@ -329,7 +332,7 @@ def forward( xv, # (bs, n_kv_heads, seqlen, head_dim) attention_masks, ) - case "sdpa": + case "sdpa" | "fa4": assert attention_masks is None output = ( self.inner_attention( diff --git a/torchtitan/models/llama4/model/model.py b/torchtitan/models/llama4/model/model.py index 4b4b56251e..0c52580c9a 100644 --- a/torchtitan/models/llama4/model/model.py +++ b/torchtitan/models/llama4/model/model.py @@ -17,6 +17,7 @@ from torchtitan.config.job_config import PEFT from torchtitan.models.attention import ( create_attention_mask, + FlashAttention4Wrapper, FlexAttentionWrapper, get_causal_mask_mod, get_document_mask_mod, @@ -244,6 +245,8 @@ def __init__( self.inner_attention = FlexAttentionWrapper() case "sdpa": self.inner_attention = ScaledDotProductAttentionWrapper() + case "fa4": + self.inner_attention = FlashAttention4Wrapper() case "varlen": raise ValueError("Varlen attention is not supported with Llama 4.") case _: @@ -303,7 +306,7 @@ def forward( enable_gqa=self.enable_gqa, ) else: - assert attention_masks is None and self.attn_type == "sdpa" + assert attention_masks is None and self.attn_type 
in ("sdpa", "fa4") output = self.inner_attention( xq, # (bs, n_local_heads, seqlen, head_dim) xk, # (bs, n_kv_heads, seqlen, head_dim) diff --git a/torchtitan/models/qwen3/__init__.py b/torchtitan/models/qwen3/__init__.py index a359c40a03..dc2a9dd608 100644 --- a/torchtitan/models/qwen3/__init__.py +++ b/torchtitan/models/qwen3/__init__.py @@ -318,6 +318,55 @@ attn_type="varlen", attn_mask_type="block_causal", ), + "30B-A3B-fa4": Qwen3ModelArgs( + vocab_size=151936, + max_seq_len=262144, + head_dim=128, + dim=2048, + n_layers=48, + n_heads=32, + n_kv_heads=4, + qk_norm=True, + hidden_dim=6144, + rope_theta=1000000, + moe_enabled=True, + moe_inter_dim=768, + moe_args=MoEArgs( + num_experts=128, + num_shared_experts=0, + top_k=8, + score_func="softmax", + route_norm=True, + route_scale=1.0, + score_before_experts=False, + ), + attn_type="fa4", + ), + "30B-A3B-flex-causal": Qwen3ModelArgs( + vocab_size=151936, + max_seq_len=262144, + head_dim=128, + dim=2048, + n_layers=48, + n_heads=32, + n_kv_heads=4, + qk_norm=True, + hidden_dim=6144, + rope_theta=1000000, + moe_enabled=True, + moe_inter_dim=768, + moe_args=MoEArgs( + num_experts=128, + num_shared_experts=0, + top_k=8, + score_func="softmax", + route_norm=True, + route_scale=1.0, + score_before_experts=False, + ), + attn_type="flex", + attn_mask_type="causal", + ), "30B-A3B-flex": Qwen3ModelArgs( vocab_size=151936, max_seq_len=262144, diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index cb93887e3e..0224a4b70d 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -65,11 +65,11 @@ def parallelize_qwen3( """ attn_type = getattr(model.model_args, "attn_type", "sdpa") - if job_config.parallelism.context_parallel_degree > 1 and attn_type == "varlen": + if job_config.parallelism.context_parallel_degree > 1 and attn_type in ("varlen", "fa4"): raise NotImplementedError( - f"Context Parallel only supports SDPA 
and FlexAttention." + f"Context Parallel only supports SDPA and FlexAttention. " f"Got attn_type='{attn_type}'. " - f"Varlen attention is not supported with CP." + f"Varlen and FA4 attention are not supported with CP." ) model_compile_enabled = ( diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index a9b188039b..63f7e72adc 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -20,6 +20,7 @@ create_attention_mask, create_varlen_metadata_for_document, create_varlen_metadata_from_sequence_lengths, + FlashAttention4Wrapper, FlexAttentionWrapper, get_causal_mask_mod, get_document_mask_mod, @@ -221,6 +222,8 @@ def __init__(self, model_args: Qwen3ModelArgs, peft_config: PEFT): self.inner_attention = VarlenAttentionWrapper() case "sdpa": self.inner_attention = ScaledDotProductAttentionWrapper() + case "fa4": + self.inner_attention = FlashAttention4Wrapper() case _: raise ValueError(f"Unknown attention type: {self.attn_type}") @@ -302,7 +305,7 @@ def forward( attention_masks, scale=self.scaling, ) - case "sdpa": + case "sdpa" | "fa4": assert attention_masks is None output = ( self.inner_attention(