copy in triangular multiplicative module from another personal repo and add typechecks
lucidrains committed May 13, 2024
1 parent 7b99949 commit 06f8690
Showing 1 changed file with 74 additions and 0 deletions.
alphafold3_pytorch/alphafold3.py (74 additions, 0 deletions)
@@ -7,6 +7,8 @@
import torch.nn.functional as F
from torch.nn import Module, Linear, Sequential

from typing import Literal

from alphafold3_pytorch.typing import (
    Float,
    Int,
@@ -16,6 +18,8 @@

from alphafold3_pytorch.attention import Attention

from einops import rearrange, einsum

# constants

LinearNoBias = partial(Linear, bias = False)
@@ -166,6 +170,76 @@ def forward(
        beta = self.to_adaln_zero_beta(cond)
        return out * gamma + beta

# triangle multiplicative module
# seems to be unchanged from alphafold2
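# in the notation of the code below, with z the layer-normed pair representation (masking aside):
#   left_ij  = sigmoid(left_gate(z_ij))  * left_proj(z_ij)
#   right_ij = sigmoid(right_gate(z_ij)) * right_proj(z_ij)
#   out_ij   = to_out(sigmoid(out_gate(z_ij)) * LayerNorm(sum_k left_ik * right_jk))   # outgoing mix
# the ingoing variant contracts over the leading axis instead: sum_k left_kj * right_ki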

class TriangleMultiplicativeModule(Module):

    @typecheck
    def __init__(
        self,
        *,
        dim,
        dim_hidden = None,
        mix: Literal["ingoing", "outgoing"] = 'ingoing'
    ):
        super().__init__()

        dim_hidden = default(dim_hidden, dim)
        self.norm = nn.LayerNorm(dim)

        self.left_proj = nn.Linear(dim, dim_hidden)
        self.right_proj = nn.Linear(dim, dim_hidden)

        self.left_gate = nn.Linear(dim, dim_hidden)
        self.right_gate = nn.Linear(dim, dim_hidden)
        self.out_gate = nn.Linear(dim, dim_hidden)

        # initialize all gating to be identity

        for gate in (self.left_gate, self.right_gate, self.out_gate):
            nn.init.constant_(gate.weight, 0.)
            nn.init.constant_(gate.bias, 1.)

        if mix == 'outgoing':
            self.mix_einsum_eq = '... i k d, ... j k d -> ... i j d'
        elif mix == 'ingoing':
            self.mix_einsum_eq = '... k j d, ... k i d -> ... i j d'

        self.to_out_norm = nn.LayerNorm(dim_hidden)
        self.to_out = nn.Linear(dim_hidden, dim)

    @typecheck
    def forward(
        self,
        x: Float['b n n d'],
        mask: Float['b n n'] | None = None
    ):
        if exists(mask):
            mask = rearrange(mask, '... -> ... 1')

        x = self.norm(x)

        left = self.left_proj(x)
        right = self.right_proj(x)

        if exists(mask):
            left = left * mask
            right = right * mask

        left_gate = self.left_gate(x).sigmoid()
        right_gate = self.right_gate(x).sigmoid()
        out_gate = self.out_gate(x).sigmoid()

        left = left * left_gate
        right = right * right_gate

        out = einsum(left, right, self.mix_einsum_eq)

        out = self.to_out_norm(out)
        out = out * out_gate
        return self.to_out(out)

# main class

class Alphafold3(Module):

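For reference, a minimal usage sketch of the new module. The import path is the file this commit touches; the dim = 64 and sequence-length values are illustrative, not from the commit.

import torch
from alphafold3_pytorch.alphafold3 import TriangleMultiplicativeModule

block = TriangleMultiplicativeModule(dim = 64, mix = 'outgoing')

pairwise = torch.randn(1, 16, 16, 64)  # (batch, seq, seq, dim) pair representation
mask = torch.ones(1, 16, 16)           # optional float pair mask, 1. = keep

out = block(pairwise, mask = mask)
assert out.shape == pairwise.shape     # shape-preserving, so it composes residually

The outgoing mix is an ordinary contraction over the shared index k, applied independently per channel d; a quick sanity check of that reading against torch.einsum (assumes a recent einops, >= 0.5, for einops.einsum):

import torch
from einops import einsum

left = torch.randn(2, 8, 8, 4)
right = torch.randn(2, 8, 8, 4)

# outgoing mix: out[i, j, d] = sum_k left[i, k, d] * right[j, k, d]
out = einsum(left, right, '... i k d, ... j k d -> ... i j d')
alt = torch.einsum('bikd,bjkd->bijd', left, right)

assert torch.allclose(out, alt, atol = 1e-6)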