
Commit

the pairwise representation biases the single representation attention matrix
lucidrains committed May 13, 2024
1 parent b541b7f commit e74b28d
Showing 1 changed file with 15 additions and 2 deletions.
alphafold3_pytorch/attention.py: 15 additions & 2 deletions (17 changed lines)
@@ -86,6 +86,7 @@ def forward(
         self,
         seq: Float['b i d'],
         mask: Bool['b n'] | None = None,
+        attn_bias: Float['b h i j'] | None = None,
         context: Float['b j d'] | None = None
     ) -> Float['b i d']:

@@ -96,7 +97,11 @@ def forward(

         q, k, v = tuple(self.split_heads(t) for t in (q, k, v))

-        out = self.attend(q, k, v, mask = mask)
+        out = self.attend(
+            q, k, v,
+            attn_bias = attn_bias,
+            mask = mask
+        )

         out = self.merge_heads(out)

@@ -165,10 +170,13 @@ def forward(
         q: Float['b h i d'],
         k: Float['b h j d'],
         v: Float['b h j d'],
+        attn_bias: Float['b h i j'] | None = None,
         mask: Bool['b j'] | None = None
     ) -> Float['b h i d']:

-        if self.flash:
+        can_use_flash = self.flash and not exists(attn_bias) # flash attention does not support attention bias with gradients
+
+        if can_use_flash:
             return self.flash_attn(q, k, v, mask = mask)

         scale = default(self.scale, q.shape[-1] ** -0.5)
@@ -179,6 +187,11 @@

         sim = einsum(q, k, "b h i d, b h j d -> b h i j")

+        # attn bias
+
+        if exists(attn_bias):
+            sim = sim + attn_bias
+
         # masking

         if exists(mask):
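For context, a minimal sketch (not from this repository) of where the new `attn_bias: Float['b h i j']` argument might come from and how it enters the attention logits, mirroring the `sim = sim + attn_bias` line added above. The projection layer, tensor names, and dimensions below are illustrative assumptions only.

```python
import torch
from torch import nn

# hypothetical sizes: batch, sequence length, heads, feature dims
b, n, heads = 2, 16, 8
dim_pairwise, dim_head = 32, 16

# hypothetical projection: pairwise representation -> one bias logit per head
to_attn_bias = nn.Linear(dim_pairwise, heads, bias = False)

pairwise_repr = torch.randn(b, n, n, dim_pairwise)            # Float['b i j dp']
attn_bias = to_attn_bias(pairwise_repr).permute(0, 3, 1, 2)   # Float['b h i j']

# inside Attend.forward, the bias is simply added to the similarity logits
q = torch.randn(b, heads, n, dim_head)
k = torch.randn(b, heads, n, dim_head)

scale = dim_head ** -0.5
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
sim = sim + attn_bias        # the line this commit adds
attn = sim.softmax(dim = -1)
```

Note that when a bias is passed, the commit deliberately falls back from the flash attention path, per the comment that flash attention does not support an attention bias with gradients.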
