complete the msa pair weighted averaging module
lucidrains committed May 14, 2024
1 parent 53867a0 commit 8de21e3
Showing 2 changed files with 71 additions and 0 deletions.
2 changes: 2 additions & 0 deletions alphafold3_pytorch/__init__.py
@@ -8,6 +8,7 @@
    AdaptiveLayerNorm,
    ConditionWrapper,
    OuterProductMean,
    MSAPairWeightedAveraging,
    TriangleMultiplication,
    AttentionPairBias,
    TriangleAttention,
@@ -23,6 +24,7 @@
    AdaptiveLayerNorm,
    ConditionWrapper,
    OuterProductMean,
    MSAPairWeightedAveraging,
    TriangleMultiplication,
    AttentionPairBias,
    TriangleAttention,
69 changes: 69 additions & 0 deletions alphafold3_pytorch/alphafold3.py
@@ -40,6 +40,9 @@ def exists(v):
def default(v, d):
    return v if exists(v) else d

def max_neg_value(t: Tensor):
    return -torch.finfo(t.dtype).max

def pack_one(t, pattern):
    return pack([t], pattern)

@@ -430,6 +433,72 @@ def forward(
        pairwise_repr = self.to_pairwise_repr(outer_product_mean)
        return pairwise_repr


class MSAPairWeightedAveraging(Module):
    """ Algorithm 10 """

    def __init__(
        self,
        *,
        dim_msa,
        dim_pairwise_repr,
        dim_head = 32,
        heads = 8
    ):
        super().__init__()
        dim_inner = dim_head * heads

        self.msa_to_values_and_gates = nn.Sequential(
            nn.LayerNorm(dim_msa),
            LinearNoBias(dim_msa, dim_inner * 2),
            Rearrange('b s n (gv h d) -> gv b h s n d', gv = 2, h = heads)
        )

        self.pairwise_repr_to_attn = nn.Sequential(
            nn.LayerNorm(dim_pairwise_repr),
            LinearNoBias(dim_pairwise_repr, heads),
            Rearrange('b i j h -> b h i j')
        )

        self.to_out = nn.Sequential(
            Rearrange('b h s n d -> b s n (h d)'),
            LinearNoBias(dim_inner, dim_msa)
        )

    @typecheck
    def forward(
        self,
        *,
        msa: Float['b s n d'],
        pairwise_repr: Float['b n n dp'],
        mask: Bool['b n'] | None = None,
    ) -> Float['b s n d']:

        # values and gates from the msa

        values, gates = self.msa_to_values_and_gates(msa)
        gates = gates.sigmoid()

        # line 3 - attention logits from the pairwise representation

        b = self.pairwise_repr_to_attn(pairwise_repr)

        if exists(mask):
            # mask out padded positions before the softmax
            mask = rearrange(mask, 'b j -> b 1 1 j')
            b = b.masked_fill(~mask, max_neg_value(b))

        # line 5 - attention weights from a softmax over the pairwise (j) dimension

        weights = b.softmax(dim = -1)

        # line 6 - pair weighted average of the msa values, then gating

        out = einsum(weights, values, 'b h i j, b h s j d -> b h s i d')

        out = out * gates

        # combine heads and project back to the msa dimension

        return self.to_out(out)

class MSAModule(Module):
    def __init__(
        self,

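For reference, a minimal usage sketch of the new module. This is only a sketch with made-up tensor sizes; it assumes alphafold3_pytorch is installed and exposes MSAPairWeightedAveraging as exported in __init__.py above.

import torch
from alphafold3_pytorch import MSAPairWeightedAveraging

# hypothetical sizes for illustration only
batch, num_msa_seqs, seq_len = 1, 4, 16
dim_msa, dim_pairwise = 64, 128

msa_pwa = MSAPairWeightedAveraging(
    dim_msa = dim_msa,
    dim_pairwise_repr = dim_pairwise,
    dim_head = 32,
    heads = 8
)

msa = torch.randn(batch, num_msa_seqs, seq_len, dim_msa)            # Float['b s n d']
pairwise_repr = torch.randn(batch, seq_len, seq_len, dim_pairwise)  # Float['b n n dp']
mask = torch.ones(batch, seq_len, dtype = torch.bool)               # Bool['b n']

out = msa_pwa(msa = msa, pairwise_repr = pairwise_repr, mask = mask)
assert out.shape == msa.shape  # the msa update keeps the input shape (b, s, n, dim_msa)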