
Commit

just allow Attention class to handle projection of pairwise representation into attention bias
lucidrains committed May 13, 2024
1 parent 07c8a28 commit 7f923b1
Showing 1 changed file with 55 additions and 8 deletions.
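
For orientation, a minimal usage sketch of the new pairwise-bias path (not part of the commit itself; constructor arguments other than dim_pairwise_repr and max_seq_len, such as dim, dim_head and heads, are assumed from the surrounding code and may differ):

import torch
from alphafold3_pytorch.attention import Attention

# hypothetical usage: dim / dim_head / heads are assumed names, not shown in this diff
attn = Attention(
    dim = 64,
    dim_head = 16,
    heads = 8,
    dim_pairwise_repr = 32   # enables projecting a pairwise representation into the attention bias
)

seq = torch.randn(1, 128, 64)                 # (b, i, d)
pairwise_rep = torch.randn(1, 128, 128, 32)   # (b, i, j, e), with e = dim_pairwise_repr

out = attn(seq, input_to_bias_attn = pairwise_rep)   # (b, i, d)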
63 changes: 55 additions & 8 deletions alphafold3_pytorch/attention.py
@@ -32,6 +32,7 @@ def max_neg_value(t):
# multi-head attention

class Attention(Module):
@typecheck
def __init__(
self,
*,
@@ -42,7 +43,9 @@ def __init__(
gate_output = True,
query_bias = True,
flash = True,
efficient_attn_config: Config = Config(True, True, True)
efficient_attn_config: Config = Config(True, True, True),
dim_pairwise_repr: int | None = None,
max_seq_len: int = 8192
):
super().__init__()
"""
@@ -52,6 +55,7 @@ def __init__(
h - heads
n - sequence
d - dimension
e - dimension (pairwise rep)
i - source sequence
j - context sequence
"""
@@ -71,7 +75,8 @@ def __init__(
self.to_kv = nn.Linear(dim, dim_inner * 2, bias = False)
self.to_out = nn.Linear(dim_inner, dim, bias = False)

# used in alphafold2
# gating of value
# allows attention to attend to nothing

self.to_gates = None

@@ -82,13 +87,32 @@ def __init__(

self.to_gates = gate_linear

# for projecting features to attn bias

self.accept_feature_to_bias_attn = exists(dim_pairwise_repr)

if self.accept_feature_to_bias_attn:
self.max_seq_len = max_seq_len

# line 8 of Algorithm 24

self.to_attn_bias = nn.Sequential(
nn.LayerNorm(dim_pairwise_repr),
nn.Linear(dim_pairwise_repr, heads, bias = False),
Rearrange('... i j h -> ... h i j')
)

self.attn_bias_bias = nn.Parameter(torch.zeros(max_seq_len, max_seq_len))

@typecheck
def forward(
self,
seq: Float['b i d'],
mask: Bool['b n'] | None = None,
attn_bias: Float['b h i j'] | None = None,
context: Float['b j d'] | None = None
context: Float['b j d'] | None = None,
attn_bias: Float['... i j'] | None = None,
input_to_bias_attn: Float['... i j e'] | None = None,

) -> Float['b i d']:

q = self.to_q(seq)
@@ -98,18 +122,38 @@ def forward(

q, k, v = tuple(self.split_heads(t) for t in (q, k, v))

# inputs to project into attn bias - for alphafold3, pairwise rep

assert not (exists(input_to_bias_attn) ^ self.accept_feature_to_bias_attn), 'if passing in pairwise representation, must set dim_pairwise_repr on Attention.__init__'

if self.accept_feature_to_bias_attn:
i, j = q.shape[-2], k.shape[-2]

assert not exists(attn_bias)
assert i <= self.max_seq_len and j <= self.max_seq_len

attn_bias = self.to_attn_bias(input_to_bias_attn) + self.attn_bias_bias[:i, :j]

# attention

out = self.attend(
q, k, v,
attn_bias = attn_bias,
mask = mask
)

# merge heads

out = self.merge_heads(out)

# gate output

if exists(self.to_gates):
gates = self.to_gates(seq)
out = out * gates.sigmoid()

# combine heads

return self.to_out(out)

# attending, both vanilla as well as in-built flash attention
@@ -124,12 +168,15 @@ def __init__(
):
super().__init__()
"""
ein notation
ein notation:
b - batch
h - heads
n - sequence
d - dimension
n, i, j - sequence (base sequence length, source, target)
e - dimension (pairwise rep)
i - source sequence
j - context sequence
"""

self.scale = scale
@@ -171,8 +218,8 @@ def forward(
q: Float['b h i d'],
k: Float['b h j d'],
v: Float['b h j d'],
attn_bias: Float['b h i j'] | None = None,
mask: Bool['b j'] | None = None
mask: Bool['b j'] | None = None,
attn_bias: Float['... i j'] | None = None,
) -> Float['b h i d']:

can_use_flash = self.flash and not exists(attn_bias) # flash attention does not support attention bias with gradients
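
For reference, a minimal standalone sketch of the bias projection added above: layernorm the pairwise representation, project its feature dimension to one scalar per head, rearrange to an (h, i, j) layout, and add a learned per-position bias. Names and sizes here are illustrative only, not the module's API.

import torch
from torch import nn
from einops.layers.torch import Rearrange

heads, dim_pairwise, max_seq_len = 8, 32, 8192

to_attn_bias = nn.Sequential(
    nn.LayerNorm(dim_pairwise),                     # normalize pairwise features
    nn.Linear(dim_pairwise, heads, bias = False),   # one bias scalar per head
    Rearrange('... i j h -> ... h i j')             # heads in front of the (i, j) grid
)

attn_bias_bias = nn.Parameter(torch.zeros(max_seq_len, max_seq_len))

pairwise = torch.randn(1, 16, 16, dim_pairwise)     # (b, i, j, e)
i, j = pairwise.shape[1], pairwise.shape[2]

bias = to_attn_bias(pairwise) + attn_bias_bias[:i, :j]   # (b, h, i, j), added to the attention logits
print(bias.shape)   # torch.Size([1, 8, 16, 16])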
