add hyper connections to diffusion transformer
lucidrains committed Dec 28, 2024
1 parent 85c0de3 commit dcc4e82
Showing 4 changed files with 64 additions and 21 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -523,3 +523,14 @@ docker run -v .:/data --gpus all -it af3
url = {https://api.semanticscholar.org/CorpusID:273849947}
}
```

```bibtex
@article{Zhu2024HyperConnections,
title = {Hyper-Connections},
author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
journal = {ArXiv},
year = {2024},
volume = {abs/2409.19606},
url = {https://api.semanticscholar.org/CorpusID:272987528}
}
```
64 changes: 46 additions & 18 deletions alphafold3_pytorch/alphafold3.py
@@ -111,18 +111,19 @@

from alphafold3_pytorch.utils.model_utils import distance_to_dgram

# personal libraries

from frame_averaging_pytorch import FrameAverage

from taylor_series_linear_attention import TaylorSeriesLinearAttn

from colt5_attention import ConditionalRoutedAttention

import einx
from einops import rearrange, repeat, reduce, einsum, pack, unpack
from einops.layers.torch import Rearrange
from hyper_connections import HyperConnections

from tqdm import tqdm
# other external libs

from tqdm import tqdm
from loguru import logger

from importlib.metadata import version
@@ -132,6 +133,12 @@
from Bio.PDB.Structure import Structure
from Bio.PDB.StructureBuilder import StructureBuilder

# einstein notation related

import einx
from einops import rearrange, repeat, reduce, einsum, pack, unpack
from einops.layers.torch import Rearrange

"""
global ein notation:
@@ -2008,6 +2015,7 @@ def __init__(
use_linear_attn = False,
checkpoint = False,
add_value_residual = False,
num_residual_streams = 1,
linear_attn_kwargs = dict(
heads = 8,
dim_head = 16
@@ -2026,6 +2034,12 @@

dim_single_cond = default(dim_single_cond, dim)

# hyper connections

init_hyper_conn, self.expand_streams, self.reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)

# layers

layers = ModuleList([])

for i in range(depth):
@@ -2042,6 +2056,8 @@
**linear_attn_kwargs
)

linear_attn = init_hyper_conn(dim = dim, branch = linear_attn)

colt5_attn = None

if use_colt5_attn:
@@ -2051,6 +2067,8 @@
**colt5_attn_kwargs
)

colt5_attn = init_hyper_conn(dim = dim, branch = colt5_attn)

accept_value_residual = add_value_residual and not is_first

pair_bias_attn = AttentionPairBias(
@@ -2083,8 +2101,8 @@
layers.append(ModuleList([
linear_attn,
colt5_attn,
conditionable_pair_bias,
conditionable_transition
init_hyper_conn(dim = dim, branch = conditionable_pair_bias),
init_hyper_conn(dim = dim, branch = conditionable_transition)
]))

self.checkpoint = checkpoint
@@ -2112,24 +2130,21 @@ def to_checkpointed_serial_layers(
windowed_mask: Bool['b nw w (w*2)'] | None = None
):

inputs = (noised_repr, single_repr, pairwise_repr, mask, windowed_mask, None)

wrapped_layers = []

def efficient_attn_wrapper(fn):
@wraps(fn)
def inner(inputs):
noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
noised_repr = fn(noised_repr, mask = mask) + noised_repr
noised_repr = fn(noised_repr, mask = mask)
return noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual
return inner

def attn_wrapper(fn):
@wraps(fn)
def inner(inputs):
noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
attn_out, attn_values = fn(noised_repr, cond = single_repr, pairwise_repr = pairwise_repr, mask = mask, windowed_mask = windowed_mask, value_residual = maybe_value_residual, return_values = True)
noised_repr = attn_out + noised_repr
noised_repr, attn_values = fn(noised_repr, cond = single_repr, pairwise_repr = pairwise_repr, mask = mask, windowed_mask = windowed_mask, value_residual = maybe_value_residual, return_values = True)

if self.add_value_residual:
maybe_value_residual = default(maybe_value_residual, attn_values)
@@ -2141,10 +2156,12 @@ def transition_wrapper(fn):
@wraps(fn)
def inner(inputs):
noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
noised_repr = fn(noised_repr, cond = single_repr) + noised_repr
noised_repr = fn(noised_repr, cond = single_repr)
return noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual
return inner

# wrap layers

for linear_attn, colt5_attn, attn, transition in self.layers:

if exists(linear_attn):
@@ -2156,10 +2173,19 @@ def inner(inputs):
wrapped_layers.append(attn_wrapper(attn))
wrapped_layers.append(transition_wrapper(transition))

# forward

noised_repr = self.expand_streams(noised_repr)

inputs = (noised_repr, single_repr, pairwise_repr, mask, windowed_mask, None)

for layer in wrapped_layers:
inputs = checkpoint(layer, inputs)

noised_repr, *_ = inputs

noised_repr = self.reduce_streams(noised_repr)

return noised_repr

@typecheck
@@ -2175,15 +2201,17 @@ def to_serial_layers(

value_residual = None

noised_repr = self.expand_streams(noised_repr)

for linear_attn, colt5_attn, attn, transition in self.layers:

if exists(linear_attn):
noised_repr = linear_attn(noised_repr, mask = mask) + noised_repr
noised_repr = linear_attn(noised_repr, mask = mask)

if exists(colt5_attn):
noised_repr = colt5_attn(noised_repr, mask = mask) + noised_repr
noised_repr = colt5_attn(noised_repr, mask = mask)

attn_out, attn_values = attn(
noised_repr, attn_values = attn(
noised_repr,
cond = single_repr,
pairwise_repr = pairwise_repr,
@@ -2193,15 +2221,15 @@ def to_serial_layers(
value_residual = value_residual
)

noised_repr = noised_repr + attn_out

if self.add_value_residual:
value_residual = default(value_residual, attn_values)

noised_repr = transition(
noised_repr,
cond = single_repr
) + noised_repr
)

noised_repr = self.reduce_streams(noised_repr)

return noised_repr

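A note on the pattern introduced above: rather than writing `noised_repr = branch(noised_repr) + noised_repr` by hand, each branch (linear attention, CoLT5 attention, pair-bias attention, transition) is wrapped by `init_hyper_conn`, which owns the residual bookkeeping across `num_residual_streams` parallel streams; `expand_streams` widens the input before the layer stack and `reduce_streams` collapses it afterwards. Below is a minimal, self-contained sketch of that pattern with a toy feedforward branch; `dim`, `depth` and the branch itself are illustrative assumptions rather than code from this repository, while the `hyper_connections` calls mirror the usage shown in the diff.

```python
import torch
from torch import nn

from hyper_connections import HyperConnections

dim, depth, num_residual_streams = 64, 2, 4  # illustrative values only

# the factory returns the branch wrapper plus the expand / reduce helpers;
# disable = True (used when num_residual_streams == 1) falls back to plain residuals
init_hyper_conn, expand_streams, reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(
    num_residual_streams,
    disable = num_residual_streams == 1
)

# toy branch standing in for linear_attn / colt5_attn / pair-bias attention / transition
def feedforward(dim):
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim * 4),
        nn.GELU(),
        nn.Linear(dim * 4, dim)
    )

# each branch is wrapped once; the wrapper owns the residual connection,
# which is why the diff drops the explicit `+ noised_repr` terms
layers = nn.ModuleList([
    init_hyper_conn(dim = dim, branch = feedforward(dim))
    for _ in range(depth)
])

x = torch.randn(1, 16, dim)

x = expand_streams(x)      # replicate the single residual into `num_residual_streams` streams

for layer in layers:
    x = layer(x)           # branch output is mixed back into the streams internally

x = reduce_streams(x)      # collapse the streams back into one representation

assert x.shape == (1, 16, dim)
```

With the default `num_residual_streams = 1`, the factory is called with `disable = True`, so the wrapper behaves like a plain residual branch and the expand/reduce calls pass the tensor through unchanged.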
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.7.2"
version = "0.7.3"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" },
@@ -41,6 +41,7 @@ dependencies = [
"fair-esm",
"fastapi",
"frame-averaging-pytorch>=0.0.18",
"hyper-connections>=0.0.14",
"gradio",
"gradio_molecule3d",
"huggingface_hub>=0.21.4",
7 changes: 5 additions & 2 deletions tests/test_af3.py
@@ -372,11 +372,13 @@ def test_msa_module(
@pytest.mark.parametrize('use_linear_attn', (False, True))
@pytest.mark.parametrize('use_colt5_attn', (False, True))
@pytest.mark.parametrize('add_value_residual', (False, True))
@pytest.mark.parametrize('num_residual_streams', (1, 4))
def test_diffusion_transformer(
checkpoint,
use_linear_attn,
use_colt5_attn,
add_value_residual
add_value_residual,
num_residual_streams
):

single = torch.randn(2, 16, 384).requires_grad_()
@@ -389,7 +391,8 @@ def test_diffusion_transformer(
checkpoint = checkpoint,
use_linear_attn = use_linear_attn,
use_colt5_attn = use_colt5_attn,
add_value_residual = add_value_residual
add_value_residual = add_value_residual,
num_residual_streams = num_residual_streams
)

single_out = diffusion_transformer(
