Skip to content

Commit

Permalink
complete hyper connected alphafold3
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 29, 2024
1 parent e18a330 commit c42884e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
16 changes: 10 additions & 6 deletions alphafold3_pytorch/alphafold3.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@

from colt5_attention import ConditionalRoutedAttention

from hyper_connections import HyperConnections
from hyper_connections.hyper_connections_with_multi_input_streams import HyperConnections

# other external libs

Expand Down Expand Up @@ -995,8 +995,8 @@ def __init__(
@typecheck
def forward(
self,
*,
pairwise_repr: Float['b n n d'],
*,
mask: Bool['b n'] | None = None,
value_residuals: tuple[Tensor, Tensor] | None = None,
return_values = False,
Expand Down Expand Up @@ -1470,8 +1470,8 @@ def __init__(
single_transition = Transition(dim = dim_single)

layers.append(ModuleList([
pairwise_block,
init_hyper_conn(dim = dim_single, branch = single_pre_ln(pair_bias_attn)),
init_hyper_conn(dim = dim_pairwise, branch = pairwise_block),
init_hyper_conn(dim = dim_single, additional_input_paths = [('pairwise_repr', dim_pairwise)], branch = single_pre_ln(pair_bias_attn)),
init_hyper_conn(dim = dim_single, branch = single_pre_ln(single_transition)),
]))

Expand Down Expand Up @@ -1508,6 +1508,7 @@ def to_layers(
) -> Tuple[Float['b n ds'], Float['b n n dp']]:

single_repr = self.expand_streams(single_repr)
pairwise_repr = self.expand_streams(pairwise_repr)

for _ in range(self.recurrent_depth):

Expand All @@ -1520,7 +1521,7 @@ def to_layers(
single_transition
) in self.layers:

pairwise_repr, pairwise_attn_values = pairwise_block(pairwise_repr = pairwise_repr, mask = mask, value_residuals = pairwise_value_residuals, return_values = True)
pairwise_repr, pairwise_attn_values = pairwise_block(pairwise_repr, mask = mask, value_residuals = pairwise_value_residuals, return_values = True)

single_repr, attn_values = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask, return_values = True, value_residual = value_residual)

Expand All @@ -1531,6 +1532,7 @@ def to_layers(
single_repr = single_transition(single_repr)

single_repr = self.reduce_streams(single_repr)
pairwise_repr = self.reduce_streams(pairwise_repr)

return single_repr, pairwise_repr

Expand All @@ -1548,7 +1550,7 @@ def pairwise_block_wrapper(layer):
@wraps(layer)
def inner(inputs, *args, **kwargs):
single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals = inputs
pairwise_repr, pairwise_attn_values = layer(pairwise_repr = pairwise_repr, mask = mask, value_residuals = maybe_pairwise_value_residuals, return_values = True)
pairwise_repr, pairwise_attn_values = layer(pairwise_repr, mask = mask, value_residuals = maybe_pairwise_value_residuals, return_values = True)

if self.add_value_residual:
maybe_pairwise_value_residuals = default(maybe_pairwise_value_residuals, pairwise_attn_values)
Expand Down Expand Up @@ -1589,6 +1591,7 @@ def inner(inputs, *args, **kwargs):
wrapped_layers.append(single_transition_wrapper(single_transition))

single_repr = self.expand_streams(single_repr)
pairwise_repr = self.expand_streams(pairwise_repr)

for _ in range(self.recurrent_depth):
inputs = (single_repr, pairwise_repr, mask, None, None)
Expand All @@ -1599,6 +1602,7 @@ def inner(inputs, *args, **kwargs):
single_repr, pairwise_repr, *_ = inputs

single_repr = self.reduce_streams(single_repr)
pairwise_repr = self.reduce_streams(pairwise_repr)

return single_repr, pairwise_repr

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.7.7"
version = "0.7.8"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" },
Expand Down Expand Up @@ -41,7 +41,7 @@ dependencies = [
"fair-esm",
"fastapi",
"frame-averaging-pytorch>=0.0.18",
"hyper-connections>=0.0.21",
"hyper-connections>=0.0.23",
"gradio",
"gradio_molecule3d",
"huggingface_hub>=0.21.4",
Expand Down

0 comments on commit c42884e

Please sign in to comment.