upgrade value residual to learnt mixing per token / head
lucidrains committed Dec 28, 2024
1 parent af6f972 commit 85c0de3
Showing 3 changed files with 71 additions and 39 deletions.
76 changes: 46 additions & 30 deletions alphafold3_pytorch/alphafold3.py
@@ -683,7 +683,7 @@ def forward(
**kwargs
) -> (
Float['b n d'] |
tuple[Float['b n d'], Float['b _ _']]
tuple[Float['b n d'], Float['b _ _ _']]
):
x = self.adaptive_norm(x, cond = cond)

@@ -797,11 +797,11 @@ def forward(
pairwise_repr: Float["b n n dp"] | Float["b nw w (w*2) dp"], # type: ignore
attn_bias: Float["b n n"] | Float["b nw w (w*2)"] | None = None, # type: ignore
return_values: bool = False,
value_residual: Float['b _ _'] | None = None,
value_residual: Float['b _ _ _'] | None = None,
**kwargs,
) -> (
Float['b n ds'] |
tuple[Float['b n ds'], Float['b _ _']]
tuple[Float['b n ds'], Float['b _ _ _']]
): # type: ignore

"""Perform the forward pass.
@@ -961,6 +961,7 @@ def __init__(
tri_attn_heads = 4,
dropout_row_prob = 0.25,
dropout_col_prob = 0.25,
accept_value_residual = False
):
super().__init__()

@@ -974,7 +975,8 @@ def __init__(
tri_attn_kwargs = dict(
dim = dim_pairwise,
heads = tri_attn_heads,
dim_head = tri_attn_dim_head
dim_head = tri_attn_dim_head,
accept_value_residual = accept_value_residual
)

self.tri_mult_outgoing = pre_ln(TriangleMultiplication(mix = 'outgoing', dropout = dropout_row_prob, dropout_type = 'row', **tri_mult_kwargs))
@@ -1436,16 +1438,20 @@ def __init__(
**pair_bias_attn_kwargs
)

for _ in range(depth):
for i in range(depth):

is_first = i == 0
accept_value_residual = add_value_residual and not is_first

single_pre_ln = partial(PreLayerNorm, dim = dim_single)

pairwise_block = PairwiseBlock(
dim_pairwise = dim_pairwise,
accept_value_residual = accept_value_residual,
**pairwise_block_kwargs
)

pair_bias_attn = AttentionPairBias(**pair_bias_attn_kwargs)
pair_bias_attn = AttentionPairBias(accept_value_residual = accept_value_residual, **pair_bias_attn_kwargs)
single_transition = Transition(dim = dim_single)

layers.append(ModuleList([
@@ -1486,10 +1492,11 @@ def to_layers(

) -> Tuple[Float['b n ds'], Float['b n n dp']]:

value_residual = None
pairwise_value_residuals = None

for _ in range(self.recurrent_depth):

value_residual = None
pairwise_value_residuals = None

for (
pairwise_block,
pair_bias_attn,
@@ -1520,54 +1527,59 @@ def to_checkpointed_layers(

) -> Tuple[Float['b n ds'], Float['b n n dp']]:

inputs = (single_repr, pairwise_repr, mask, None)

def pairwise_block_wrapper(layer):
@wraps(layer)
def inner(inputs, *args, **kwargs):
single_repr, pairwise_repr, mask, maybe_value_residual = inputs
pairwise_repr = layer(pairwise_repr = pairwise_repr, mask = mask)
return single_repr, pairwise_repr, mask, maybe_value_residual
single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals = inputs
pairwise_repr, pairwise_attn_values = layer(pairwise_repr = pairwise_repr, mask = mask, value_residuals = maybe_pairwise_value_residuals, return_values = True)

if self.add_value_residual:
maybe_pairwise_value_residuals = default(maybe_pairwise_value_residuals, pairwise_attn_values)

return single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals
return inner

def pair_bias_attn_wrapper(layer):
@wraps(layer)
def inner(inputs, *args, **kwargs):
single_repr, pairwise_repr, mask, maybe_value_residual = inputs
single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals = inputs
attn_out, attn_values = layer(single_repr, pairwise_repr = pairwise_repr, mask = mask, return_values = True, value_residual = maybe_value_residual)
single_repr = single_repr + attn_out

if self.add_value_residual:
maybe_value_residual = default(maybe_value_residual, attn_values)

return single_repr, pairwise_repr, mask, maybe_value_residual
return single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals
return inner

def single_transition_wrapper(layer):
@wraps(layer)
def inner(inputs, *args, **kwargs):
single_repr, pairwise_repr, mask, maybe_value_residual = inputs
single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals = inputs
single_repr = layer(single_repr) + single_repr
return single_repr, pairwise_repr, mask, maybe_value_residual
return single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals
return inner

wrapped_layers = []

for (
pairwise_block,
pair_bias_attn,
single_transition
) in self.layers:

wrapped_layers.append(pairwise_block_wrapper(pairwise_block))
wrapped_layers.append(pair_bias_attn_wrapper(pair_bias_attn))
wrapped_layers.append(single_transition_wrapper(single_transition))

for _ in range(self.recurrent_depth):
for (
pairwise_block,
pair_bias_attn,
single_transition
) in self.layers:
inputs = (single_repr, pairwise_repr, mask, None, None)

wrapped_layers.append(pairwise_block_wrapper(pairwise_block))
wrapped_layers.append(pair_bias_attn_wrapper(pair_bias_attn))
wrapped_layers.append(single_transition_wrapper(single_transition))
for layer in wrapped_layers:
inputs = checkpoint(layer, inputs)

for layer in wrapped_layers:
inputs = checkpoint(layer, inputs)
single_repr, pairwise_repr, *_ = inputs

single_repr, pairwise_repr, *_ = inputs
return single_repr, pairwise_repr

@typecheck
@@ -2016,7 +2028,8 @@ def __init__(

layers = ModuleList([])

for _ in range(depth):
for i in range(depth):
is_first = i == 0

linear_attn = None

@@ -2038,12 +2051,15 @@
**colt5_attn_kwargs
)

accept_value_residual = add_value_residual and not is_first

pair_bias_attn = AttentionPairBias(
dim = dim,
dim_pairwise = dim_pairwise,
heads = heads,
window_size = attn_window_size,
num_memory_kv = attn_num_memory_kv,
accept_value_residual = accept_value_residual,
**attn_pair_bias_kwargs
)

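The alphafold3.py changes above follow one consistent wiring: within a stack, only the first layer computes fresh attention values (is_first), every later layer is constructed with accept_value_residual = True, and the forward loop captures the first layer's values and hands them to the layers that follow, resetting to None at the start of each recurrent pass in to_layers. Below is a minimal, self-contained sketch of that wiring; ToyAttn and its dimensions are illustrative stand-ins for the repo's AttentionPairBias / PairwiseBlock, not the library's API.

import torch
from torch import nn

class ToyAttn(nn.Module):
    # stand-in for the repo's attention blocks: returns its output plus its own (pre-mix) values
    def __init__(self, dim, accept_value_residual = False):
        super().__init__()
        self.to_v = nn.Linear(dim, dim, bias = False)
        self.accept_value_residual = accept_value_residual
        if accept_value_residual:
            self.to_mix = nn.Sequential(nn.Linear(dim, 1, bias = False), nn.Sigmoid())

    def forward(self, x, value_residual = None, return_values = False):
        values = self.to_v(x)                              # this layer's own values
        v = values
        if value_residual is not None:
            v = v.lerp(value_residual, self.to_mix(x))     # learned mixing, see attention.py below
        out = v                                            # stand-in for attending over v
        return (out, values) if return_values else out

depth, dim = 4, 32
add_value_residual = True

# only the first layer computes fresh values; every later layer is built to accept them
layers = nn.ModuleList([
    ToyAttn(dim, accept_value_residual = add_value_residual and i > 0)
    for i in range(depth)
])

x = torch.randn(2, 16, dim)
value_residual = None                                      # reset per recurrent pass, as in to_layers

for layer in layers:
    out, values = layer(x, value_residual = value_residual, return_values = True)
    x = x + out
    if add_value_residual and value_residual is None:
        value_residual = values                            # keep the first layer's values for later layers
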
32 changes: 24 additions & 8 deletions alphafold3_pytorch/attention.py
@@ -188,7 +188,8 @@ def __init__(
laser_softclamp_value = 15.,
enable_attn_softclamp = False,
attn_softclamp_value = 50.,
softmax_full_precision = False
softmax_full_precision = False,
accept_value_residual = False
):
super().__init__()
"""
@@ -237,6 +238,18 @@ def __init__(
if gate_output:
self.to_gates = nn.Sequential(LinearNoBias(dim, dim_inner), nn.Sigmoid())

# learned value residual mixing
# even greater improvements on top of value residual learning, discovered by open source community

self.accept_value_residual = accept_value_residual

if accept_value_residual:
self.to_value_residual_mix = nn.Sequential(
LinearNoBias(dim, heads),
Rearrange('b n h -> b h n 1'),
nn.Sigmoid()
)

@typecheck
def forward(
self,
@@ -246,28 +259,31 @@ def forward(
windowed_mask: Bool['b nw w (w*2)'] | None = None,
attn_bias: Float['... i j'] | Float['... nw w (w*2)'] | None = None,
return_values: bool = False,
value_residual: Float['b j dh'] | None = None,
value_residual: Float['b h j dh'] | None = None,

) -> (
Float['b i d'] |
tuple[Float['b i d'], Float['b j _']]
tuple[Float['b i d'], Float['b h j dh']]
):

q = self.to_q(seq)

context_seq = default(context, seq)
k, v = self.to_kv(context_seq).chunk(2, dim = -1)

# split heads

q, k, v = tuple(self.split_heads(t) for t in (q, k, v))

# handle value residual

orig_v = v

if exists(value_residual):
v = 0.5 * (v + value_residual)
assert not (self.accept_value_residual ^ exists(value_residual))

# split heads

q, k, v = tuple(self.split_heads(t) for t in (q, k, v))
if exists(value_residual):
mix = self.to_value_residual_mix(seq)
v = v.lerp(value_residual, mix)

# attention

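The attention.py hunk above is the core of the commit: where the previous version folded the residual in with a fixed v = 0.5 * (v + value_residual), a layer that accepts a value residual now learns a sigmoid gate per token and per head (a dim -> heads projection rearranged to 'b h n 1') and interpolates with v = v.lerp(value_residual, mix). The sketch below isolates just that mechanism under those assumptions; the class name and hyperparameters are illustrative, and the repo's real Attention additionally has output gating, windowing, attention softclamping, and memory key/values that are omitted here.

import torch
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange

class ToyValueResidualAttention(nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 16, accept_value_residual = False):
        super().__init__()
        dim_inner = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_q = nn.Linear(dim, dim_inner, bias = False)
        self.to_kv = nn.Linear(dim, dim_inner * 2, bias = False)
        self.to_out = nn.Linear(dim_inner, dim, bias = False)

        self.accept_value_residual = accept_value_residual
        if accept_value_residual:
            # one sigmoid gate per token and per head, broadcast across the head dimension
            self.to_value_residual_mix = nn.Sequential(
                nn.Linear(dim, heads, bias = False),
                Rearrange('b n h -> b h n 1'),
                nn.Sigmoid()
            )

    def forward(self, seq, value_residual = None, return_values = False):
        q = self.to_q(seq)
        k, v = self.to_kv(seq).chunk(2, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in (q, k, v))

        orig_v = v                                         # fresh values, handed to later layers

        # a layer either accepts a value residual or it does not, never optionally
        assert not (self.accept_value_residual ^ (value_residual is not None))

        if value_residual is not None:
            mix = self.to_value_residual_mix(seq)          # (b, heads, n, 1), in [0, 1]
            v = v.lerp(value_residual, mix)                # per token / head interpolation

        attn = (q @ k.transpose(-1, -2) * self.scale).softmax(dim = -1)
        out = rearrange(attn @ v, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return (out, orig_v) if return_values else out

# usage: the first layer supplies values, a later layer consumes them
first = ToyValueResidualAttention(dim = 64)
later = ToyValueResidualAttention(dim = 64, accept_value_residual = True)
seq = torch.randn(2, 128, 64)
out_first, values = first(seq, return_values = True)
out_later = later(seq, value_residual = values)
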
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.7.0"
version = "0.7.2"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" },
