Skip to content

Commit c3273cb

Browse files
committed
[PyTorch] Pad V when Q/V head dims differ (MLA) for THD
Signed-off-by: Hollow Man <hollowman@opensuse.org>
1 parent 82ace62 commit c3273cb

2 files changed

Lines changed: 27 additions & 5 deletions

File tree

transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import logging
1212

1313
import torch
14+
import torch.nn.functional as F
1415
from torch.nn.parameter import Parameter
1516

1617
import transformer_engine_torch as tex
@@ -179,6 +180,18 @@
179180
__all__ = ["DotProductAttention"]
180181

181182

183+
def _pad_thd_value_layer(value_layer, head_dim_qk):
    """Zero-pad the last dimension of V up to the Q/K head dimension.

    Used for THD FlashAttention when Q/K and V head dimensions differ.

    Args:
        value_layer: Tensor whose last dimension is the V head dimension.
        head_dim_qk: Size the last dimension should have after padding.

    Returns:
        A tuple ``(padded_value_layer, orig_head_dim_v)`` where
        ``orig_head_dim_v`` is the size of the last dimension before padding,
        so the caller can trim the attention output back afterwards.

    Raises:
        ValueError: If ``head_dim_qk`` is smaller than the last dimension of
            ``value_layer``.
    """
    orig_head_dim_v = value_layer.shape[-1]
    pad_width = head_dim_qk - orig_head_dim_v
    if pad_width < 0:
        # A negative width makes F.pad crop instead of pad, which would
        # silently discard part of V; fail loudly instead.
        raise ValueError(
            f"Cannot pad V head dim {orig_head_dim_v} to a smaller Q/K head dim {head_dim_qk}."
        )
    # The pad tuple is (left, right) for the last dimension: append zeros only.
    return F.pad(value_layer, (0, pad_width)), orig_head_dim_v
187+
188+
189+
def _trim_thd_output(attn_out, num_attention_heads, padded_head_dim_v, orig_head_dim_v):
    """Strip the padded V columns from a THD FlashAttention output.

    The output is viewed per head as
    ``(attn_out.shape[0], num_attention_heads, padded_head_dim_v)``, the first
    ``orig_head_dim_v`` entries of every head are kept, and the heads are
    flattened back into a 2-D tensor whose first dimension is unchanged.
    """
    leading_dim = attn_out.shape[0]
    # Expose the per-head axis so the padding can be cut from each head separately.
    per_head = attn_out.reshape(leading_dim, num_attention_heads, padded_head_dim_v)
    kept = per_head[:, :, :orig_head_dim_v]
    # Merge heads and head_dim again; reshape copies here since the slice is strided.
    return kept.reshape(leading_dim, -1)
193+
194+
182195
class DotProductAttention(TransformerEngineBaseModule):
183196
r"""Allows the model to jointly attend to information from different
184197
representation subspaces as described in the paper:
@@ -1508,6 +1521,16 @@ def forward(
15081521
)
15091522

15101523
if use_flash_attention:
1524+
orig_v_dim = None
1525+
if (
1526+
q_format == "thd"
1527+
and kv_format == "thd"
1528+
and not isinstance(value_layer, Float8TensorStorage)
1529+
and head_dim_qk != head_dim_v
1530+
and value_layer.shape[-1] < head_dim_qk
1531+
):
1532+
value_layer, orig_v_dim = _pad_thd_value_layer(value_layer, head_dim_qk)
1533+
15111534
if core_attention_bias_type == "alibi":
15121535
alibi_slopes, _ = dpa_utils.get_alibi(
15131536
_alibi_cache,
@@ -1516,7 +1539,7 @@ def forward(
15161539
max_seqlen_kv,
15171540
alibi_slopes=alibi_slopes,
15181541
)
1519-
return self.flash_attention(
1542+
attn_out = self.flash_attention(
15201543
query_layer,
15211544
key_layer,
15221545
value_layer,
@@ -1541,6 +1564,9 @@ def forward(
15411564
fp8_output=fp8_output,
15421565
num_splits=num_splits,
15431566
)
1567+
if orig_v_dim is not None:
1568+
return _trim_thd_output(attn_out, num_attention_heads, head_dim_qk, orig_v_dim)
1569+
return attn_out
15441570

15451571
if use_fused_attention:
15461572
fu_core_attention_bias_type = core_attention_bias_type

transformer_engine/pytorch/attention/dot_product_attention/utils.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -709,10 +709,6 @@ def get_attention_backend(
709709

710710
# Filter: Head dimension
711711
if head_dim_qk != head_dim_v:
712-
if use_flash_attention_2 and FlashAttentionUtils.is_installed:
713-
logger.debug("Disabling FlashAttention 2 as it does not support MLA.")
714-
use_flash_attention_2 = False
715-
716712
qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
717713
if use_fused_attention and qkv_layout_group != "hd_hd_hd":
718714
logger.debug(

0 commit comments

Comments
 (0)