refactor and complete AttentionPairBias and TriangleAttention
lucidrains committed May 13, 2024
1 parent 4c3979d commit a1dc28f
Showing 3 changed files with 116 additions and 40 deletions.
4 changes: 4 additions & 0 deletions alphafold3_pytorch/__init__.py
@@ -7,6 +7,8 @@
     PreLayerNorm,
     AdaptiveLayerNorm,
     ConditionWrapper,
+    AttentionPairBias,
+    TriangleAttention,
     Transition,
     Alphafold3
 )
@@ -17,6 +19,8 @@
     PreLayerNorm,
     AdaptiveLayerNorm,
     ConditionWrapper,
+    AttentionPairBias,
+    TriangleAttention,
     Transition,
     Alphafold3
 ]
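
With AttentionPairBias and TriangleAttention added to the package imports and __all__ above, both modules can be pulled in directly from the package root. A one-line check, assuming the package is installed from this commit:

from alphafold3_pytorch import AttentionPairBias, TriangleAttention
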
113 changes: 110 additions & 3 deletions alphafold3_pytorch/alphafold3.py
@@ -18,7 +18,8 @@
 
 from alphafold3_pytorch.attention import Attention
 
-from einops import rearrange, einsum
+from einops import rearrange, repeat, einsum, pack, unpack
+from einops.layers.torch import Rearrange
 
 # constants
 
@@ -77,7 +78,7 @@ class PreLayerNorm(Module):
     @typecheck
     def __init__(
         self,
-        fn: Attention | Transition,
+        fn: Attention | Transition | TriangleAttention | AttentionPairBias,
         *,
         dim,
     ):
@@ -135,7 +136,7 @@ class ConditionWrapper(Module):
     @typecheck
     def __init__(
         self,
-        fn: Attention | Transition,
+        fn: Attention | Transition | TriangleAttention | AttentionPairBias,
         *,
         dim,
         dim_cond,
@@ -240,6 +241,112 @@ def forward(
         out = out * out_gate
         return self.to_out(out)
 
+# there are two types of attention in this paper, triangle and attention-pair-bias
+# they differ by how the attention bias is computed
+# triangle is axial attention w/ itself projected for bias
+
+class AttentionPairBias(Module):
+    def __init__(
+        self,
+        *,
+        heads,
+        dim_pairwise_repr,
+        max_seq_len = 16384,
+        **attn_kwargs
+    ):
+        super().__init__()
+        self.attn = Attention(heads = heads, **attn_kwargs)
+
+        # line 8 of Algorithm 24
+
+        to_attn_bias_linear = nn.Linear(dim_pairwise_repr, heads, bias = False)
+        nn.init.zeros_(to_attn_bias_linear.weight)
+
+        self.to_attn_bias = nn.Sequential(
+            nn.LayerNorm(dim_pairwise_repr),
+            to_attn_bias_linear,
+            Rearrange('... i j h -> ... h i j')
+        )
+
+        self.max_seq_len = max_seq_len
+        self.attn_bias_bias = nn.Parameter(torch.zeros(max_seq_len, max_seq_len))
+
+    @typecheck
+    def forward(
+        self,
+        single_repr: Float['b n ds'],
+        *,
+        pairwise_repr: Float['b n n dp'],
+        **kwargs
+    ) -> Float['b n ds']:
+
+        seq = single_repr.shape[1]
+        assert seq <= self.max_seq_len
+
+        attn_bias = self.to_attn_bias(pairwise_repr) + self.attn_bias_bias[:seq, :seq]
+
+        out = self.attn(
+            single_repr,
+            attn_bias = attn_bias,
+            **kwargs
+        )
+
+        return out
+
+class TriangleAttention(Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        heads,
+        node_type: Literal['starting', 'ending'],
+        **attn_kwargs
+    ):
+        super().__init__()
+        self.need_transpose = node_type == 'ending'
+
+        self.attn = Attention(dim = dim, heads = heads, **attn_kwargs)
+
+        self.to_attn_bias = nn.Sequential(
+            nn.Linear(dim, heads, bias = False),
+            Rearrange('... i j h -> ... h i j')
+        )
+
+    @typecheck
+    def forward(
+        self,
+        pairwise_repr: Float['b n n d'],
+        mask: Bool['b n'] | None = None,
+        **kwargs
+    ) -> Float['b n n d']:
+
+        if self.need_transpose:
+            pairwise_repr = rearrange(pairwise_repr, 'b i j d -> b j i d')
+
+        attn_bias = self.to_attn_bias(pairwise_repr)
+
+        batch_repeat = pairwise_repr.shape[1]
+        attn_bias = repeat(attn_bias, 'b ... -> (b r) ...', r = batch_repeat)
+
+        if exists(mask):
+            mask = repeat(mask, 'b ... -> (b r) ...', r = batch_repeat)
+
+        pairwise_repr, packed_shape = pack([pairwise_repr], '* n d')
+
+        out = self.attn(
+            pairwise_repr,
+            mask = mask,
+            attn_bias = attn_bias,
+            **kwargs
+        )
+
+        out, = unpack(out, packed_shape, '* n d')
+
+        if self.need_transpose:
+            out = rearrange(out, 'b j i d -> b i j d')
+
+        return out
+
 # main class
 
 class Alphafold3(Module):
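
The two new modules differ mainly in what attends and where the bias comes from: AttentionPairBias attends over the single representation and derives its bias from a projection of the pairwise representation plus a learned bias table, while TriangleAttention attends axially over the pairwise representation and projects that same representation for its bias. A minimal usage sketch against the definitions above; the dimensions, tensor values, and the reduced max_seq_len are illustrative choices, not part of this commit (dim reaches the underlying Attention through **attn_kwargs):

import torch
from alphafold3_pytorch import AttentionPairBias, TriangleAttention

b, n = 2, 16    # batch size and sequence length for the sketch

# attention over the single representation, biased by the pairwise representation

attn_pair_bias = AttentionPairBias(
    heads = 8,
    dim = 64,                  # forwarded to Attention through **attn_kwargs
    dim_pairwise_repr = 32,
    max_seq_len = 256          # small bias table so the sketch stays lightweight
)

single = torch.randn(b, n, 64)         # 'b n ds'
pairwise = torch.randn(b, n, n, 32)    # 'b n n dp'

single_out = attn_pair_bias(single, pairwise_repr = pairwise)    # (b, n, 64)

# axial attention over the pairwise representation, biased by its own projection

triangle_attn = TriangleAttention(dim = 32, heads = 8, node_type = 'starting')

pairwise_out = triangle_attn(pairwise)    # (b, n, n, 32)
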
39 changes: 2 additions & 37 deletions alphafold3_pytorch/attention.py
@@ -43,9 +43,7 @@ def __init__(
         gate_output = True,
         query_bias = True,
         flash = True,
-        efficient_attn_config: Config = Config(True, True, True),
-        dim_pairwise_repr: int | None = None,
-        max_seq_len: int = 8192
+        efficient_attn_config: Config = Config(True, True, True)
     ):
         super().__init__()
         """
@@ -87,34 +85,13 @@
 
         self.to_gates = gate_linear
 
-        # for projecting features to attn bias
-
-        self.accept_feature_to_bias_attn = exists(dim_pairwise_repr)
-
-        if self.accept_feature_to_bias_attn:
-            self.max_seq_len = max_seq_len
-
-            # line 8 of Algorithm 24
-
-            to_attn_bias_linear = nn.Linear(dim_pairwise_repr, heads, bias = False)
-            nn.init.zeros_(to_attn_bias_linear.weight)
-
-            self.to_attn_bias = nn.Sequential(
-                nn.LayerNorm(dim_pairwise_repr),
-                to_attn_bias_linear,
-                Rearrange('... i j h -> ... h i j')
-            )
-
-            self.attn_bias_bias = nn.Parameter(torch.zeros(max_seq_len, max_seq_len))
-
     @typecheck
     def forward(
         self,
         seq: Float['b i d'],
         mask: Bool['b n']| None = None,
         context: Float['b j d'] | None = None,
-        attn_bias: Float['... i j'] | None = None,
-        input_to_bias_attn: Float['b i j e'] | None = None,
+        attn_bias: Float['... i j'] | None = None
 
     ) -> Float['b i d']:
 
@@ -125,18 +102,6 @@ def forward(
 
         q, k, v = tuple(self.split_heads(t) for t in (q, k, v))
 
-        # inputs to project into attn bias - for alphafold3, pairwise rep
-
-        assert not (exists(input_to_bias_attn) ^ self.accept_feature_to_bias_attn), 'if passing in pairwise representation, must set dim_pairwise_repr on Attention.__init__'
-
-        if self.accept_feature_to_bias_attn:
-            i, j = q.shape[-2], k.shape[-2]
-
-            assert not exists(attn_bias)
-            assert i <= self.max_seq_len and j <= self.max_seq_len
-
-            attn_bias = self.to_attn_bias(input_to_bias_attn) + self.attn_bias_bias[:i, :j]
-
         # attention
 
         out = self.attend(
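
With the projection moved out, Attention itself no longer knows about the pairwise representation: callers such as AttentionPairBias and TriangleAttention compute the bias on their side and hand it in through the existing attn_bias argument. A sketch of the slimmed-down interface; the shapes, the values, and the assumption that the remaining constructor defaults work out of the box are illustrative:

import torch
from alphafold3_pytorch.attention import Attention

attn = Attention(dim = 64, heads = 8)

seq = torch.randn(2, 16, 64)             # 'b i d'
attn_bias = torch.randn(2, 8, 16, 16)    # '... i j' - here one bias map per head

out = attn(seq, attn_bias = attn_bias)   # (2, 16, 64)
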
