complete basic prelayernorm + adaptive layernorm as well as condition wrapper for both attention and feedforward with the adaln-zero
lucidrains committed May 13, 2024
1 parent adaa8a2 commit 7b99949
Showing 2 changed files with 111 additions and 0 deletions.
6 changes: 6 additions & 0 deletions alphafold3_pytorch/__init__.py
@@ -4,13 +4,19 @@
)

from alphafold3_pytorch.alphafold3 import (
    PreLayerNorm,
    AdaptiveLayerNorm,
    ConditionWrapper,
    Transition,
    Alphafold3
)

__all__ = [
    Attention,
    Attend,
    PreLayerNorm,
    AdaptiveLayerNorm,
    ConditionWrapper,
    Transition,
    Alphafold3
]
105 changes: 105 additions & 0 deletions alphafold3_pytorch/alphafold3.py
@@ -1,6 +1,9 @@
from __future__ import annotations
from functools import partial

import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Module, Linear, Sequential

@@ -11,6 +14,8 @@
    typecheck
)

from alphafold3_pytorch.attention import Attention

# constants

LinearNoBias = partial(Linear, bias = False)
@@ -61,6 +66,106 @@ def forward(

        return self.ff(x)

# normalization
# both pre layernorm as well as adaptive layernorm wrappers

class PreLayerNorm(Module):
    @typecheck
    def __init__(
        self,
        fn: Attention | Transition,
        *,
        dim,
    ):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    @typecheck
    def forward(
        self,
        x: Tensor,
        **kwargs
    ):
        x = self.norm(x)
        return self.fn(x, **kwargs)

class AdaptiveLayerNorm(Module):
    """ Algorithm 26 """

    def __init__(
        self,
        *,
        dim,
        dim_cond
    ):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine = False)
        self.norm_cond = nn.LayerNorm(dim_cond, bias = False)

        # scale (gamma) and shift (beta) are predicted from the conditioning
        self.to_gamma = nn.Sequential(
            nn.Linear(dim_cond, dim),
            nn.Sigmoid()
        )

        self.to_beta = nn.Linear(dim_cond, dim, bias = False)

    @typecheck
    def forward(
        self,
        x: Tensor,
        cond: Tensor
    ):
        normed = self.norm(x)
        normed_cond = self.norm_cond(cond)

        gamma = self.to_gamma(normed_cond)
        beta = self.to_beta(normed_cond)
        return normed * gamma + beta

class ConditionWrapper(Module):
    """ Algorithm 25 """

    @typecheck
    def __init__(
        self,
        fn: Attention | Transition,
        *,
        dim,
        dim_cond,
        adaln_zero_bias_init_value = -2.
    ):
        super().__init__()
        self.fn = fn
        self.adaptive_norm = AdaptiveLayerNorm(dim = dim, dim_cond = dim_cond)

        # adaln-zero output gate: zero weights plus a negative bias keep the
        # sigmoid gate small at init, so the wrapped block contributes little early in training
        adaln_zero_gamma_linear = nn.Linear(dim_cond, dim)
        nn.init.zeros_(adaln_zero_gamma_linear.weight)
        nn.init.constant_(adaln_zero_gamma_linear.bias, adaln_zero_bias_init_value)

        self.to_adaln_zero_gamma = nn.Sequential(
            adaln_zero_gamma_linear,
            nn.Sigmoid()
        )

        self.to_adaln_zero_beta = nn.Linear(dim_cond, dim, bias = False)

    @typecheck
    def forward(
        self,
        x: Tensor,
        *,
        cond: Tensor,
        **kwargs
    ):
        x = self.adaptive_norm(x, cond = cond)

        out = self.fn(x, **kwargs)

        gamma = self.to_adaln_zero_gamma(cond)
        beta = self.to_adaln_zero_beta(cond)
        return out * gamma + beta

# main class

class Alphafold3(Module):

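For orientation, a minimal usage sketch of the new wrappers follows. It is not part of the commit: the Transition constructor signature (assumed here to accept a dim keyword) and the tensor sizes are illustrative assumptions.

import torch
from alphafold3_pytorch import (
    PreLayerNorm,
    ConditionWrapper,
    Transition
)

# illustrative sizes only
dim, dim_cond = 64, 32

# pre-layernorm wrapper: layernorm the input, then run the wrapped feedforward block
prenorm_transition = PreLayerNorm(Transition(dim = dim), dim = dim)

# condition wrapper: adaptive layernorm on the way in, adaln-zero gating on the way out
cond_transition = ConditionWrapper(Transition(dim = dim), dim = dim, dim_cond = dim_cond)

x = torch.randn(2, 16, dim)
cond = torch.randn(2, 16, dim_cond)

out = prenorm_transition(x)            # assumed shape-preserving: (2, 16, dim)
out = cond_transition(x, cond = cond)  # (2, 16, dim)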