Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import logging

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

import transformer_engine_torch as tex
Expand Down Expand Up @@ -55,6 +56,7 @@
import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
AttentionLogging as attn_log,
FlashAttentionUtils,
)

from transformer_engine.pytorch.attention.dot_product_attention.backends import (
Expand Down Expand Up @@ -180,6 +182,26 @@
__all__ = ["DotProductAttention"]


def _pad_qkv_head_dim(query_layer, key_layer, value_layer):
    """Zero-pad Q/K or V along the last axis so all three share one head dim.

    The target size is the larger of the query's and the value's last
    dimension. Whichever side is smaller is right-padded with zeros via
    ``F.pad``; the other side is returned untouched.

    Returns:
        A 5-tuple ``(query_layer, key_layer, value_layer, orig_head_dim_qk,
        orig_head_dim_v)`` where the last two entries are the last-axis sizes
        of the query and value tensors *before* any padding, so the caller can
        undo the padding on the attention output.
    """
    orig_head_dim_qk = query_layer.shape[-1]
    orig_head_dim_v = value_layer.shape[-1]
    target_dim = max(orig_head_dim_qk, orig_head_dim_v)

    # The target is the max of the two sizes, so at most one side is short.
    if orig_head_dim_qk < target_dim:
        # Q and K are padded together; K's gap is measured from K's own
        # last-axis size rather than assumed equal to Q's.
        qk_gap = target_dim - orig_head_dim_qk
        k_gap = target_dim - key_layer.shape[-1]
        query_layer = F.pad(query_layer, (0, qk_gap))
        key_layer = F.pad(key_layer, (0, k_gap))
    elif orig_head_dim_v < target_dim:
        v_gap = target_dim - orig_head_dim_v
        value_layer = F.pad(value_layer, (0, v_gap))

    return query_layer, key_layer, value_layer, orig_head_dim_qk, orig_head_dim_v


def _trim_output(attn_out, num_attention_heads, padded_head_dim_v, orig_head_dim_v):
    """Drop the padded tail of each head from a flattened attention output.

    The last axis of ``attn_out`` is reinterpreted as
    ``(num_attention_heads, padded_head_dim_v)``, the first
    ``orig_head_dim_v`` entries of every head are kept, and the heads are
    flattened back into a single trailing axis. All leading axes are
    preserved as-is.
    """
    leading_dims = tuple(attn_out.shape[:-1])
    # Expose the per-head axis so the padding can be sliced off head by head.
    per_head = attn_out.reshape(leading_dims + (num_attention_heads, padded_head_dim_v))
    kept = per_head[..., :orig_head_dim_v]
    # Merge (heads, head_dim) back into one trailing axis.
    return kept.reshape(leading_dims + (-1,))


class DotProductAttention(TransformerEngineBaseModule):
r"""Allows the model to jointly attend to information from different
representation subspaces as described in the paper:
Expand Down Expand Up @@ -1630,6 +1652,21 @@ def forward(
)

if use_flash_attention:
orig_qk_dim = None
orig_v_dim = None
if (
flash_attention_backend == FlashAttentionUtils.version
and not isinstance(value_layer, Float8TensorStorage)
and head_dim_qk != head_dim_v
):
(
query_layer,
key_layer,
value_layer,
orig_qk_dim,
orig_v_dim,
) = _pad_qkv_head_dim(query_layer, key_layer, value_layer)

Comment thread
HollowMan6 marked this conversation as resolved.
if core_attention_bias_type == "alibi":
alibi_slopes, _ = dpa_utils.get_alibi(
_alibi_cache,
Expand All @@ -1638,7 +1675,7 @@ def forward(
max_seqlen_kv,
alibi_slopes=alibi_slopes,
)
return self.flash_attention(
attn_out = self.flash_attention(
query_layer,
key_layer,
value_layer,
Expand Down Expand Up @@ -1666,6 +1703,9 @@ def forward(
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
)
if orig_qk_dim is not None and orig_qk_dim > orig_v_dim:
return _trim_output(attn_out, num_attention_heads, orig_qk_dim, orig_v_dim)
return attn_out

if use_fused_attention:
fu_core_attention_bias_type = core_attention_bias_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -722,10 +722,6 @@ def get_attention_backend(

# Filter: Head dimension
if head_dim_qk != head_dim_v:
if use_flash_attention_2 and FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention 2 as it does not support MLA.")
use_flash_attention_2 = False

qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
if use_fused_attention and qkv_layout_group != "hd_hd_hd":
logger.debug(
Comment thread
HollowMan6 marked this conversation as resolved.
Expand All @@ -747,18 +743,27 @@ def get_attention_backend(
)
use_fused_attention = False

fa2_padded_head_dim = max(head_dim_qk, head_dim_v)
if ( # pylint: disable=too-many-boolean-expressions
use_flash_attention_2
and FlashAttentionUtils.is_installed
and (head_dim_qk > 256 or head_dim_qk % 8 != 0)
and (
fa2_padded_head_dim > 256
or fa2_padded_head_dim % 8 != 0
or (
fa2_padded_head_dim > 192
and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0))
)
)
):
logger.debug(
"Disabling FlashAttention 2 due to unsupported head_dim_qk and head_dim_v. "
"Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
"head_dim_qk <= 256. "
"Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
"Supported after padding: padded head_dim %%8 = 0, padded head_dim <= 256 "
"(>192 requires sm80/90/100+). "
"Found: head_dim_qk = %s, head_dim_v = %s, padded head_dim = %s, on sm%s.",
head_dim_qk,
head_dim_v,
fa2_padded_head_dim,
".".join([str(i) for i in device_compute_capability]),
)
use_flash_attention_2 = False
Expand Down
Loading