Commit bad0028

fix(transformerblock): conditionally initialize cross attention components
Initialize cross attention layers only when with_cross_attention is True to avoid unnecessary computation and memory usage.
1 parent 1e6c661 commit bad0028

monai/networks/blocks/transformerblock.py

Lines changed: 10 additions & 10 deletions
@@ -79,16 +79,16 @@ def __init__(
         )
         self.norm2 = nn.LayerNorm(hidden_size)
         self.with_cross_attention = with_cross_attention
-
-        self.norm_cross_attn = nn.LayerNorm(hidden_size)
-        self.cross_attn = CrossAttentionBlock(
-            hidden_size=hidden_size,
-            num_heads=num_heads,
-            dropout_rate=dropout_rate,
-            qkv_bias=qkv_bias,
-            causal=False,
-            use_flash_attention=use_flash_attention,
-        )
+        if with_cross_attention:
+            self.norm_cross_attn = nn.LayerNorm(hidden_size)
+            self.cross_attn = CrossAttentionBlock(
+                hidden_size=hidden_size,
+                num_heads=num_heads,
+                dropout_rate=dropout_rate,
+                qkv_bias=qkv_bias,
+                causal=False,
+                use_flash_attention=use_flash_attention,
+            )
 
     def forward(
         self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None
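
Below is a minimal usage sketch (not part of the commit) illustrating the effect of the guard: a TransformerBlock built with with_cross_attention=False no longer carries unused cross-attention parameters, while a block built with with_cross_attention=True behaves as before. Constructor arguments beyond those visible in the diff (hidden_size, num_heads, dropout_rate, qkv_bias, with_cross_attention) are assumed from recent MONAI releases and may differ between versions; shapes are illustrative only.

import torch
from monai.networks.blocks.transformerblock import TransformerBlock

# Self-attention-only block: after this commit, norm_cross_attn and cross_attn
# are never created, so the module registers fewer parameters.
self_attn_only = TransformerBlock(hidden_size=64, mlp_dim=256, num_heads=4)

# Cross-attention block: unchanged behaviour, the layers are still created.
cross_attn_block = TransformerBlock(
    hidden_size=64, mlp_dim=256, num_heads=4, with_cross_attention=True
)

def n_params(module: torch.nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

print("self-attention only :", n_params(self_attn_only))
print("with cross-attention:", n_params(cross_attn_block))

# Forward pass follows the signature shown in the diff context above.
x = torch.randn(2, 16, 64)    # (batch, tokens, hidden_size)
ctx = torch.randn(2, 8, 64)   # context sequence for cross-attention
out = cross_attn_block(x, context=ctx)
print(out.shape)              # torch.Size([2, 16, 64])

Because the cross-attention submodules are only registered when requested, they also no longer appear in the state_dict of a self-attention-only block.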
