
Commit c1fda49

complete the pairformer stack
lucidrains committed May 13, 2024
1 parent 715327c commit c1fda49
Showing 4 changed files with 166 additions and 26 deletions.
4 changes: 4 additions & 0 deletions alphafold3_pytorch/__init__.py
@@ -7,9 +7,11 @@
PreLayerNorm,
AdaptiveLayerNorm,
ConditionWrapper,
TriangleMultiplication,
AttentionPairBias,
TriangleAttention,
Transition,
PairformerStack,
Alphafold3
)

@@ -19,8 +21,10 @@
PreLayerNorm,
AdaptiveLayerNorm,
ConditionWrapper,
TriangleMultiplication,
AttentionPairBias,
TriangleAttention,
Transition,
PairformerStack,
Alphafold3
]
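
With both new modules exported above, they are importable straight from the package root:

from alphafold3_pytorch import TriangleMultiplication, PairformerStack
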
156 changes: 132 additions & 24 deletions alphafold3_pytorch/alphafold3.py
@@ -5,9 +5,15 @@
from torch import nn
from torch import Tensor
import torch.nn.functional as F
- from torch.nn import Module, Linear, Sequential
-
- from typing import Literal
+ from torch.nn import (
+     Module,
+     ModuleList,
+     Linear,
+     Sequential,
+ )
+
+ from typing import Literal, Tuple

from alphafold3_pytorch.typing import (
Float,
@@ -33,6 +39,12 @@ def exists(v):
def default(v, d):
return v if exists(v) else d

def pack_one(t, pattern):
return pack([t], pattern)

def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
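
These are thin convenience wrappers around einops' pack and unpack for the common single-tensor case. A quick round trip (shapes are illustrative only):

import torch
from einops import pack, unpack

t = torch.randn(2, 4, 16, 64)               # e.g. (batch, heads, seq, dim)
packed, ps = pack([t], '* n d')             # leading axes folded -> (8, 16, 64)
restored = unpack(packed, ps, '* n d')[0]   # unfolded -> (2, 4, 16, 64)
assert restored.shape == t.shape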

# classic feedforward, SwiGLU variant
# they name this 'transition' in their paper
# Algorithm 11
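
The Transition body itself is collapsed in this diff. For reference, a minimal SwiGLU feedforward consistent with the comment could look like the sketch below; the class name, expansion factor, and bias choices are assumptions, not the repository's exact code.

import torch
import torch.nn.functional as F
from torch import nn

class TransitionSketch(nn.Module):
    # sketch: norm -> up-project into two branches -> silu-gate one branch with the other -> down-project
    def __init__(self, dim, expansion_factor = 4):
        super().__init__()
        dim_inner = dim * expansion_factor
        self.norm = nn.LayerNorm(dim)
        self.to_hidden = nn.Linear(dim, dim_inner * 2, bias = False)
        self.to_out = nn.Linear(dim_inner, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)
        x, gate = self.to_hidden(x).chunk(2, dim = -1)
        return self.to_out(F.silu(gate) * x)

ff = TransitionSketch(dim = 64)
out = ff(torch.randn(1, 16, 64))   # (1, 16, 64)
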
@@ -78,7 +90,7 @@ class PreLayerNorm(Module):
@typecheck
def __init__(
self,
- fn: Attention | Transition | TriangleAttention | AttentionPairBias,
+ fn: Attention | Transition | TriangleAttention | TriangleMultiplication | AttentionPairBias,
*,
dim,
):
@@ -110,7 +122,7 @@ def __init__(
self.norm_cond = nn.LayerNorm(dim_cond, bias = False)

self.to_gamma = nn.Sequential(
- nn.Linear(dim_cond, dim),
+ Linear(dim_cond, dim),
nn.Sigmoid()
)

@@ -119,9 +131,9 @@ def __init__(
@typecheck
def forward(
self,
- x: Float['... n d'],
- cond: Tensor
- ) -> Float['... n d']:
+ x: Float['b ... n d'],
+ cond: Float['b ... dc']
+ ) -> Float['b ... n d']:

normed = self.norm(x)
normed_cond = self.norm_cond(cond)
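
The remainder of this forward is collapsed, but the idea of adaptive layernorm is that the condition, rather than a learned affine, rescales the normalized input. A standalone sketch of that mechanism (the non-affine norm and the omitted beta/bias path are assumptions):

import torch
from torch import nn

dim, dim_cond = 8, 16
norm = nn.LayerNorm(dim, elementwise_affine = False)
norm_cond = nn.LayerNorm(dim_cond, bias = False)
to_gamma = nn.Sequential(nn.Linear(dim_cond, dim), nn.Sigmoid())

x = torch.randn(2, 4, dim)
cond = torch.randn(2, dim_cond)
gamma = to_gamma(norm_cond(cond))        # (2, 8), one gate per channel
out = norm(x) * gamma.unsqueeze(1)       # gate broadcast over the sequence axis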
@@ -146,7 +158,7 @@ def __init__(
self.fn = fn
self.adaptive_norm = AdaptiveLayerNorm(dim = dim, dim_cond = dim_cond)

- adaln_zero_gamma_linear = nn.Linear(dim_cond, dim)
+ adaln_zero_gamma_linear = Linear(dim_cond, dim)
nn.init.zeros_(adaln_zero_gamma_linear.weight)
nn.init.constant_(adaln_zero_gamma_linear.bias, adaln_zero_bias_init_value)

@@ -158,11 +170,11 @@ def __init__(
@typecheck
def forward(
self,
- x: Float['... n d'],
+ x: Float['b ... n d'],
*,
- cond: Tensor,
+ cond: Float['b ... dc'],
**kwargs
- ) -> Float['... n d']:
+ ) -> Float['b ... n d']:
x = self.adaptive_norm(x, cond = cond)

out = self.fn(x, **kwargs)
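
The gating tail of this forward is collapsed; what the zero-initialized gamma linear above sets up is the adaLN-Zero trick, where each wrapped block starts out close to an identity mapping. A standalone sketch (the -2. bias is illustrative; the default of adaln_zero_bias_init_value is not shown in this diff):

import torch
from torch import nn

dim, dim_cond = 8, 16
to_gamma = nn.Sequential(nn.Linear(dim_cond, dim), nn.Sigmoid())
nn.init.zeros_(to_gamma[0].weight)
nn.init.constant_(to_gamma[0].bias, -2.)   # sigmoid(-2) ~ 0.12, so the gate opens near zero

cond = torch.randn(2, dim_cond)
branch_out = torch.randn(2, 4, dim)
gated = branch_out * to_gamma(cond).unsqueeze(1)   # branch contributes little at init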
@@ -173,27 +185,27 @@ def forward(
# triangle multiplicative module
# seems to be unchanged from alphafold2

- class TriangleMultiplicativeModule(Module):
+ class TriangleMultiplication(Module):

@typecheck
def __init__(
self,
*,
dim,
dim_hidden = None,
mix: Literal["ingoing", "outgoing"] = 'ingoing'
mix: Literal["incoming", "outgoing"] = 'incoming'
):
super().__init__()

dim_hidden = default(dim_hidden, dim)
self.norm = nn.LayerNorm(dim)

- self.left_proj = nn.Linear(dim, dim_hidden)
- self.right_proj = nn.Linear(dim, dim_hidden)
+ self.left_proj = Linear(dim, dim_hidden)
+ self.right_proj = Linear(dim, dim_hidden)

- self.left_gate = nn.Linear(dim, dim_hidden)
- self.right_gate = nn.Linear(dim, dim_hidden)
- self.out_gate = nn.Linear(dim, dim_hidden)
+ self.left_gate = Linear(dim, dim_hidden)
+ self.right_gate = Linear(dim, dim_hidden)
+ self.out_gate = Linear(dim, dim_hidden)

# initialize all gating to be identity

@@ -203,21 +215,21 @@ def __init__(

if mix == 'outgoing':
self.mix_einsum_eq = '... i k d, ... j k d -> ... i j d'
- elif mix == 'ingoing':
+ elif mix == 'incoming':
self.mix_einsum_eq = '... k j d, ... k i d -> ... i j d'

self.to_out_norm = nn.LayerNorm(dim_hidden)
- self.to_out = nn.Linear(dim_hidden, dim)
+ self.to_out = Linear(dim_hidden, dim)

@typecheck
def forward(
self,
x: Float['b n n d'],
- mask: Float['b n n'] | None = None
+ mask: Bool['b n'] | None = None
) -> Float['b n n d']:

if exists(mask):
- mask = rearrange(mask, '... -> ... 1')
+ mask = rearrange(mask, 'b i -> b i 1 1') & rearrange(mask, 'b j -> b 1 j 1')

x = self.norm(x)

@@ -331,7 +343,7 @@ def forward(
if exists(mask):
mask = repeat(mask, 'b ... -> (b r) ...', r = batch_repeat)

- pairwise_repr, packed_shape = pack([pairwise_repr], '* n d')
+ pairwise_repr, packed_shape = pack_one(pairwise_repr, '* n d')

out = self.attn(
pairwise_repr,
Expand All @@ -340,13 +352,109 @@ def forward(
**kwargs
)

- out, = unpack(out, packed_shape, '* n d')
+ out = unpack_one(out, packed_shape, '* n d')

if self.need_transpose:
out = rearrange(out, 'b j i d -> b i j d')

return out
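
A usage sketch for the attention over starting/ending nodes; the constructor arguments are inferred from how PairformerStack instantiates TriangleAttention below, so treat them as read off this diff rather than documented API:

import torch
from alphafold3_pytorch import TriangleAttention

tri_attn = TriangleAttention(dim = 256, heads = 4, dim_head = 32, node_type = 'ending')
pairwise = torch.randn(2, 16, 16, 256)
mask = torch.ones(2, 16).bool()
out = tri_attn(pairwise, mask = mask)   # (2, 16, 16, 256); 'ending' transposes internally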

# pairformer stack

class PairformerStack(Module):
def __init__(
self,
*,
dim_single,
dim_pairwise,
depth = 48,
tri_mult_dim_hidden = None,
tri_attn_dim_head = 32,
tri_attn_heads = 4,
pair_bias_attn_dim_head = 64,
pair_bias_attn_heads = 16,
dropout_row_prob = 0.25,
dropout_col_prob = 0.25
):
super().__init__()
layers = ModuleList([])

tri_mult_kwargs = dict(
dim = dim_pairwise,
dim_hidden = tri_mult_dim_hidden
)

tri_attn_kwargs = dict(
dim = dim_pairwise,
heads = tri_attn_heads,
dim_head = tri_attn_dim_head
)

pair_bias_attn_kwargs = dict(
dim = dim_single,
dim_pairwise_repr = dim_pairwise,
heads = pair_bias_attn_heads,
dim_head = pair_bias_attn_dim_head
)

for _ in range(depth):
dropout_row = nn.Dropout(dropout_row_prob)
dropout_col = nn.Dropout(dropout_col_prob)

pairwise_pre_ln = partial(PreLayerNorm, dim = dim_pairwise)

tri_mult_outgoing = TriangleMultiplication(mix = 'outgoing', **tri_mult_kwargs)
tri_mult_incoming = TriangleMultiplication(mix = 'incoming', **tri_mult_kwargs)
tri_attn_starting = TriangleAttention(node_type = 'starting', **tri_attn_kwargs)
tri_attn_ending = TriangleAttention(node_type = 'ending', **tri_attn_kwargs)

pairwise_transition = Transition(dim = dim_pairwise)

single_pre_ln = partial(PreLayerNorm, dim = dim_single)

pair_bias_attn = AttentionPairBias(**pair_bias_attn_kwargs)
single_transition = Transition(dim = dim_single)

layers.append(ModuleList([
dropout_row,
dropout_col,
pairwise_pre_ln(tri_mult_outgoing),
pairwise_pre_ln(tri_mult_incoming),
pairwise_pre_ln(tri_attn_starting),
pairwise_pre_ln(tri_attn_ending),
pairwise_pre_ln(pairwise_transition),
single_pre_ln(pair_bias_attn),
single_pre_ln(single_transition),
]))

self.layers = layers

@typecheck
def forward(
self,
*,
single_repr: Float['b n ds'],
pairwise_repr: Float['b n n dp'],
mask: Bool['b n'] | None = None

) -> Tuple[Float['b n ds'], Float['b n n dp']]:

for dropout_row, dropout_col, tri_mult_outgoing, tri_mult_incoming, tri_attn_starting, tri_attn_ending, pairwise_transition, pair_bias_attn, single_transition in self.layers:

pairwise_repr = dropout_row(tri_mult_outgoing(pairwise_repr, mask = mask)) + pairwise_repr
pairwise_repr = dropout_row(tri_mult_incoming(pairwise_repr, mask = mask)) + pairwise_repr
pairwise_repr = dropout_row(tri_attn_starting(pairwise_repr, mask = mask)) + pairwise_repr
pairwise_repr = dropout_col(tri_attn_ending(pairwise_repr, mask = mask)) + pairwise_repr

pairwise_repr, packed_shape = pack_one(pairwise_repr, 'b * d')
pairwise_repr = pairwise_transition(pairwise_repr) + pairwise_repr
pairwise_repr = unpack_one(pairwise_repr, packed_shape, 'b * d')

single_repr = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
single_repr = single_transition(single_repr) + single_repr

return single_repr, pairwise_repr
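
One note on regularization: dropout_row and dropout_col above are plain nn.Dropout modules, whereas the AlphaFold papers describe row- and column-wise dropout that shares a single mask along one residue axis of the pairwise representation. A hedged sketch of that shared-mask variant, in case the distinction matters:

import torch
from torch import nn

class RowwiseDropout(nn.Module):
    # sketch: one mask shared across the row axis of a (b, i, j, d) pairwise tensor
    def __init__(self, prob):
        super().__init__()
        self.prob = prob

    def forward(self, x):
        if not self.training or self.prob == 0.:
            return x
        b, _, j, d = x.shape
        keep = torch.rand(b, 1, j, d, device = x.device) > self.prob
        return x * keep / (1. - self.prob)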

# main class

class Alphafold3(Module):
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -44,6 +44,11 @@ test = [
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.pytest.ini_options]
pythonpath = [
"."
]
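
(pytest's pythonpath ini option, available since pytest 7.0, prepends the listed paths to sys.path, so the tests below can import alphafold3_pytorch straight from the repository checkout without installing the package.)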

[tool.ruff]
line-length = 1000

27 changes: 25 additions & 2 deletions tests/test_readme.py
@@ -1,3 +1,26 @@
import torch
import pytest

- def test_readme():
-     assert True
+ from alphafold3_pytorch import (
+     PairformerStack
+ )

def test_pairformer():
single = torch.randn(2, 16, 512)
pairwise = torch.randn(2, 16, 16, 256)
mask = torch.randint(0, 2, (2, 16)).bool()

pairformer = PairformerStack(
depth = 4,
dim_single = 512,
dim_pairwise = 256
)

single_out, pairwise_out = pairformer(
single_repr = single,
pairwise_repr = pairwise,
mask = mask
)

assert single.shape == single_out.shape
assert pairwise.shape == pairwise_out.shape
