1111import logging
1212
1313import torch
14+ import torch .nn .functional as F
1415from torch .nn .parameter import Parameter
1516
1617import transformer_engine_torch as tex
179180__all__ = ["DotProductAttention" ]
180181
181182
183+ def _pad_thd_value_layer (value_layer , head_dim_qk ):
184+ """Pad V for THD FlashAttention when Q/K and V head dimensions differ."""
185+ orig_head_dim_v = value_layer .shape [- 1 ]
186+ return F .pad (value_layer , (0 , head_dim_qk - orig_head_dim_v )), orig_head_dim_v
187+
188+
189+ def _trim_thd_output (attn_out , num_attention_heads , padded_head_dim_v , orig_head_dim_v ):
190+ """Trim FlashAttention THD output after padding V to the Q/K head dimension."""
191+ attn_out = attn_out .reshape (attn_out .shape [0 ], num_attention_heads , padded_head_dim_v )
192+ return attn_out [..., :orig_head_dim_v ].reshape (attn_out .shape [0 ], - 1 )
193+
194+
182195class DotProductAttention (TransformerEngineBaseModule ):
183196 r"""Allows the model to jointly attend to information from different
184197 representation subspaces as described in the paper:
@@ -1508,6 +1521,16 @@ def forward(
15081521 )
15091522
15101523 if use_flash_attention :
1524+ orig_v_dim = None
1525+ if (
1526+ q_format == "thd"
1527+ and kv_format == "thd"
1528+ and not isinstance (value_layer , Float8TensorStorage )
1529+ and head_dim_qk != head_dim_v
1530+ and value_layer .shape [- 1 ] < head_dim_qk
1531+ ):
1532+ value_layer , orig_v_dim = _pad_thd_value_layer (value_layer , head_dim_qk )
1533+
15111534 if core_attention_bias_type == "alibi" :
15121535 alibi_slopes , _ = dpa_utils .get_alibi (
15131536 _alibi_cache ,
@@ -1516,7 +1539,7 @@ def forward(
15161539 max_seqlen_kv ,
15171540 alibi_slopes = alibi_slopes ,
15181541 )
1519- return self .flash_attention (
1542+ attn_out = self .flash_attention (
15201543 query_layer ,
15211544 key_layer ,
15221545 value_layer ,
@@ -1541,6 +1564,9 @@ def forward(
15411564 fp8_output = fp8_output ,
15421565 num_splits = num_splits ,
15431566 )
1567+ if orig_v_dim is not None :
1568+ return _trim_thd_output (attn_out , num_attention_heads , head_dim_qk , orig_v_dim )
1569+ return attn_out
15441570
15451571 if use_fused_attention :
15461572 fu_core_attention_bias_type = core_attention_bias_type
0 commit comments